Move token invalidation to auth interceptor

This commit is contained in:
Kirill Kamakin
2022-12-10 09:30:38 +01:00
parent 54d0c895a9
commit 9adf37ae33
5 changed files with 202 additions and 49 deletions

View File

@@ -2,7 +2,6 @@ package gq.kirmanak.mealient.data.network
import gq.kirmanak.mealient.data.add.AddRecipeDataSource
import gq.kirmanak.mealient.data.add.AddRecipeInfo
import gq.kirmanak.mealient.data.auth.AuthRepo
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
import gq.kirmanak.mealient.data.baseurl.ServerVersion
import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo
@@ -10,8 +9,6 @@ import gq.kirmanak.mealient.data.recipes.network.RecipeDataSource
import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo
import gq.kirmanak.mealient.data.share.ParseRecipeDataSource
import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo
import gq.kirmanak.mealient.datasource.NetworkError
import gq.kirmanak.mealient.datasource.runCatchingExceptCancel
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
import gq.kirmanak.mealient.extensions.toFullRecipeInfo
@@ -20,17 +17,14 @@ import gq.kirmanak.mealient.extensions.toV0Request
import gq.kirmanak.mealient.extensions.toV1CreateRequest
import gq.kirmanak.mealient.extensions.toV1Request
import gq.kirmanak.mealient.extensions.toV1UpdateRequest
import gq.kirmanak.mealient.logging.Logger
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class MealieDataSourceWrapper @Inject constructor(
private val serverInfoRepo: ServerInfoRepo,
private val authRepo: AuthRepo,
private val v0Source: MealieDataSourceV0,
private val v1Source: MealieDataSourceV1,
private val logger: Logger,
) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource {
override suspend fun addRecipe(
@@ -85,20 +79,8 @@ class MealieDataSourceWrapper @Inject constructor(
}
private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T {
val authHeader = authRepo.getAuthHeader()
val url = serverInfoRepo.requireUrl()
val version = serverInfoRepo.getVersion()
return runCatchingExceptCancel { block(url, version) }.getOrElse {
if (it is NetworkError.Unauthorized) {
logger.e { "Unauthorized, trying to invalidate token" }
authRepo.invalidateAuthHeader()
// Trying again with new authentication header
val newHeader = authRepo.getAuthHeader()
logger.e { "New token ${if (newHeader == authHeader) "matches" else "doesn't match"} old token" }
if (newHeader == authHeader) throw it else block(url, version)
} else {
throw it
}
}
return block(url, version)
}
}

View File

@@ -3,7 +3,6 @@ package gq.kirmanak.mealient.data.network
import com.google.common.truth.Truth.assertThat
import gq.kirmanak.mealient.data.auth.AuthRepo
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
import gq.kirmanak.mealient.datasource.NetworkError
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER
@@ -15,7 +14,6 @@ import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_INFO
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1
@@ -50,26 +48,7 @@ class MealieDataSourceWrapperTest : BaseUnitTest() {
@Before
override fun setUp() {
super.setUp()
subject = MealieDataSourceWrapper(serverInfoRepo, authRepo, v0Source, v1Source, logger)
}
@Test
fun `when makeCall fails with Unauthorized expect it to invalidate token`() = runTest {
val slug = "porridge"
coEvery {
v0Source.requestRecipeInfo(any(), any())
} throws NetworkError.Unauthorized(IOException()) andThen PORRIDGE_RECIPE_RESPONSE_V0
coEvery { serverInfoRepo.getVersion() } returns TEST_SERVER_VERSION_V0
coEvery { serverInfoRepo.requireUrl() } returns TEST_BASE_URL
coEvery { authRepo.getAuthHeader() } returns null andThen TEST_AUTH_HEADER
subject.requestRecipeInfo(slug)
coVerifySequence {
authRepo.getAuthHeader()
authRepo.invalidateAuthHeader()
authRepo.getAuthHeader()
}
subject = MealieDataSourceWrapper(serverInfoRepo, v0Source, v1Source)
}
@Test

View File

@@ -3,4 +3,6 @@ package gq.kirmanak.mealient.datasource
interface AuthenticationProvider {
suspend fun getAuthHeader(): String?
suspend fun invalidateAuthHeader()
}

View File

@@ -1,9 +1,11 @@
package gq.kirmanak.mealient.datasource.impl
import androidx.annotation.VisibleForTesting
import gq.kirmanak.mealient.datasource.AuthenticationProvider
import kotlinx.coroutines.runBlocking
import okhttp3.Interceptor
import okhttp3.Response
import retrofit2.HttpException
import javax.inject.Inject
import javax.inject.Provider
import javax.inject.Singleton
@@ -14,19 +16,43 @@ class AuthInterceptor @Inject constructor(
) : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response {
val token = runBlocking { authenticationProvider.get().getAuthHeader() }
val request = if (token == null) {
chain.request()
val token = getAuthHeader()
return if (token == null) {
proceedWithAuthHeader(chain)
} else {
chain.request()
.newBuilder()
.header(HEADER_NAME, token)
.build()
try {
proceedWithAuthHeader(chain, token)
} catch (e: HttpException) {
if (e.code() in setOf(401, 403)) {
invalidateAuthHeader()
proceedWithAuthHeader(chain, getAuthHeader())
} else {
throw e
}
}
}
}
private fun proceedWithAuthHeader(chain: Interceptor.Chain, token: String? = null): Response {
val requestBuilder = chain.request().newBuilder()
val request = if (token == null) {
requestBuilder.removeHeader(HEADER_NAME).build()
} else {
requestBuilder.header(HEADER_NAME, token).build()
}
return chain.proceed(request)
}
private fun getAuthHeader() = runBlocking {
authenticationProvider.get().getAuthHeader()
}
private fun invalidateAuthHeader() = runBlocking {
authenticationProvider.get().invalidateAuthHeader()
}
companion object {
private const val HEADER_NAME = "Authorization"
@VisibleForTesting
const val HEADER_NAME = "Authorization"
}
}

View File

@@ -0,0 +1,164 @@
package gq.kirmanak.mealient.datasource
import com.google.common.truth.Truth.assertThat
import gq.kirmanak.mealient.datasource.impl.AuthInterceptor
import gq.kirmanak.mealient.test.BaseUnitTest
import io.mockk.CapturingSlot
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.slot
import io.mockk.verify
import okhttp3.Interceptor
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Before
import org.junit.Test
import retrofit2.HttpException
import retrofit2.Response as RetrofitResponse
class AuthInterceptorTest : BaseUnitTest() {
private lateinit var subject: AuthInterceptor
@MockK(relaxUnitFun = true)
lateinit var authenticationProvider: AuthenticationProvider
@MockK(relaxUnitFun = true)
lateinit var chain: Interceptor.Chain
@Before
override fun setUp() {
super.setUp()
subject = AuthInterceptor { authenticationProvider }
every { chain.request() } returns Request.Builder().url("http://localhost").build()
}
@Test
fun `when no header then still proceeds`() {
mockProceedCallAndCaptureRequest()
coEvery { authenticationProvider.getAuthHeader() } returns null
subject.intercept(chain)
verify { chain.proceed(any()) }
}
@Test
fun `when has header then adds header`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
}
@Test(expected = HttpException::class)
fun `when unauthorized but didn't have token then throws`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
mockUnauthorized()
subject.intercept(chain)
}
@Test(expected = HttpException::class)
fun `when unauthorized and had token then invalidates it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
mockUnauthorized()
subject.intercept(chain)
coVerify { authenticationProvider.invalidateAuthHeader() }
}
@Test(expected = HttpException::class)
fun `when not found and had token then throws exception`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
subject.intercept(chain)
}
@Test
fun `when not found and had token then does not invalidate it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
try {
subject.intercept(chain)
} catch (e: HttpException) {
coVerify(inverse = true) { authenticationProvider.invalidateAuthHeader() }
}
}
@Test
fun `when unauthorized and had token then calls again with new token`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token" andThen "newToken"
val requests = mutableListOf<Request>()
every {
chain.proceed(capture(requests))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
} andThenAnswer {
buildResponse(requests[1])
}
subject.intercept(chain)
coVerifySequence {
authenticationProvider.getAuthHeader()
authenticationProvider.invalidateAuthHeader()
authenticationProvider.getAuthHeader()
}
assertThat(requests[0].header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
assertThat(requests[1].header(AuthInterceptor.HEADER_NAME)).isEqualTo("newToken")
}
@Test
fun `when had token but now does not then removes it`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
val mockRequest = Request.Builder()
.url("http://localhost")
.header(AuthInterceptor.HEADER_NAME, "token")
.build()
every { chain.request() } returns mockRequest
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isNull()
}
private fun mockUnauthorized() {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
}
}
private fun mockProceedCallAndCaptureRequest(): CapturingSlot<Request> {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
buildResponse(requestSlot.captured)
}
return requestSlot
}
private fun buildResponse(
request: Request,
code: Int = 200,
protocol: Protocol = Protocol.HTTP_2,
message: String = if (code == 200) "OK" else "NOT OK",
) = Response.Builder()
.code(code)
.request(request)
.protocol(protocol)
.message(message)
.build()
}