Move token invalidation to auth interceptor

This commit is contained in:
Kirill Kamakin
2022-12-10 09:30:38 +01:00
parent 54d0c895a9
commit 9adf37ae33
5 changed files with 202 additions and 49 deletions

View File

@@ -3,4 +3,6 @@ package gq.kirmanak.mealient.datasource
interface AuthenticationProvider {
suspend fun getAuthHeader(): String?
suspend fun invalidateAuthHeader()
}

View File

@@ -1,9 +1,11 @@
package gq.kirmanak.mealient.datasource.impl
import androidx.annotation.VisibleForTesting
import gq.kirmanak.mealient.datasource.AuthenticationProvider
import kotlinx.coroutines.runBlocking
import okhttp3.Interceptor
import okhttp3.Response
import retrofit2.HttpException
import javax.inject.Inject
import javax.inject.Provider
import javax.inject.Singleton
@@ -14,19 +16,43 @@ class AuthInterceptor @Inject constructor(
) : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response {
val token = runBlocking { authenticationProvider.get().getAuthHeader() }
val request = if (token == null) {
chain.request()
val token = getAuthHeader()
return if (token == null) {
proceedWithAuthHeader(chain)
} else {
chain.request()
.newBuilder()
.header(HEADER_NAME, token)
.build()
try {
proceedWithAuthHeader(chain, token)
} catch (e: HttpException) {
if (e.code() in setOf(401, 403)) {
invalidateAuthHeader()
proceedWithAuthHeader(chain, getAuthHeader())
} else {
throw e
}
}
}
}
private fun proceedWithAuthHeader(chain: Interceptor.Chain, token: String? = null): Response {
val requestBuilder = chain.request().newBuilder()
val request = if (token == null) {
requestBuilder.removeHeader(HEADER_NAME).build()
} else {
requestBuilder.header(HEADER_NAME, token).build()
}
return chain.proceed(request)
}
private fun getAuthHeader() = runBlocking {
authenticationProvider.get().getAuthHeader()
}
private fun invalidateAuthHeader() = runBlocking {
authenticationProvider.get().invalidateAuthHeader()
}
companion object {
private const val HEADER_NAME = "Authorization"
@VisibleForTesting
const val HEADER_NAME = "Authorization"
}
}

View File

@@ -0,0 +1,164 @@
package gq.kirmanak.mealient.datasource
import com.google.common.truth.Truth.assertThat
import gq.kirmanak.mealient.datasource.impl.AuthInterceptor
import gq.kirmanak.mealient.test.BaseUnitTest
import io.mockk.CapturingSlot
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.slot
import io.mockk.verify
import okhttp3.Interceptor
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Before
import org.junit.Test
import retrofit2.HttpException
import retrofit2.Response as RetrofitResponse
class AuthInterceptorTest : BaseUnitTest() {
private lateinit var subject: AuthInterceptor
@MockK(relaxUnitFun = true)
lateinit var authenticationProvider: AuthenticationProvider
@MockK(relaxUnitFun = true)
lateinit var chain: Interceptor.Chain
@Before
override fun setUp() {
super.setUp()
subject = AuthInterceptor { authenticationProvider }
every { chain.request() } returns Request.Builder().url("http://localhost").build()
}
@Test
fun `when no header then still proceeds`() {
mockProceedCallAndCaptureRequest()
coEvery { authenticationProvider.getAuthHeader() } returns null
subject.intercept(chain)
verify { chain.proceed(any()) }
}
@Test
fun `when has header then adds header`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
}
@Test(expected = HttpException::class)
fun `when unauthorized but didn't have token then throws`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
mockUnauthorized()
subject.intercept(chain)
}
@Test(expected = HttpException::class)
fun `when unauthorized and had token then invalidates it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
mockUnauthorized()
subject.intercept(chain)
coVerify { authenticationProvider.invalidateAuthHeader() }
}
@Test(expected = HttpException::class)
fun `when not found and had token then throws exception`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
subject.intercept(chain)
}
@Test
fun `when not found and had token then does not invalidate it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
try {
subject.intercept(chain)
} catch (e: HttpException) {
coVerify(inverse = true) { authenticationProvider.invalidateAuthHeader() }
}
}
@Test
fun `when unauthorized and had token then calls again with new token`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token" andThen "newToken"
val requests = mutableListOf<Request>()
every {
chain.proceed(capture(requests))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
} andThenAnswer {
buildResponse(requests[1])
}
subject.intercept(chain)
coVerifySequence {
authenticationProvider.getAuthHeader()
authenticationProvider.invalidateAuthHeader()
authenticationProvider.getAuthHeader()
}
assertThat(requests[0].header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
assertThat(requests[1].header(AuthInterceptor.HEADER_NAME)).isEqualTo("newToken")
}
@Test
fun `when had token but now does not then removes it`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
val mockRequest = Request.Builder()
.url("http://localhost")
.header(AuthInterceptor.HEADER_NAME, "token")
.build()
every { chain.request() } returns mockRequest
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isNull()
}
private fun mockUnauthorized() {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
}
}
private fun mockProceedCallAndCaptureRequest(): CapturingSlot<Request> {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
buildResponse(requestSlot.captured)
}
return requestSlot
}
private fun buildResponse(
request: Request,
code: Int = 200,
protocol: Protocol = Protocol.HTTP_2,
message: String = if (code == 200) "OK" else "NOT OK",
) = Response.Builder()
.code(code)
.request(request)
.protocol(protocol)
.message(message)
.build()
}