Move token invalidation to auth interceptor

This commit is contained in:
Kirill Kamakin
2022-12-10 09:30:38 +01:00
parent 54d0c895a9
commit 9adf37ae33
5 changed files with 202 additions and 49 deletions

View File

@@ -2,7 +2,6 @@ package gq.kirmanak.mealient.data.network
import gq.kirmanak.mealient.data.add.AddRecipeDataSource import gq.kirmanak.mealient.data.add.AddRecipeDataSource
import gq.kirmanak.mealient.data.add.AddRecipeInfo import gq.kirmanak.mealient.data.add.AddRecipeInfo
import gq.kirmanak.mealient.data.auth.AuthRepo
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
import gq.kirmanak.mealient.data.baseurl.ServerVersion import gq.kirmanak.mealient.data.baseurl.ServerVersion
import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo
@@ -10,8 +9,6 @@ import gq.kirmanak.mealient.data.recipes.network.RecipeDataSource
import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo
import gq.kirmanak.mealient.data.share.ParseRecipeDataSource import gq.kirmanak.mealient.data.share.ParseRecipeDataSource
import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo
import gq.kirmanak.mealient.datasource.NetworkError
import gq.kirmanak.mealient.datasource.runCatchingExceptCancel
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0 import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1 import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
import gq.kirmanak.mealient.extensions.toFullRecipeInfo import gq.kirmanak.mealient.extensions.toFullRecipeInfo
@@ -20,17 +17,14 @@ import gq.kirmanak.mealient.extensions.toV0Request
import gq.kirmanak.mealient.extensions.toV1CreateRequest import gq.kirmanak.mealient.extensions.toV1CreateRequest
import gq.kirmanak.mealient.extensions.toV1Request import gq.kirmanak.mealient.extensions.toV1Request
import gq.kirmanak.mealient.extensions.toV1UpdateRequest import gq.kirmanak.mealient.extensions.toV1UpdateRequest
import gq.kirmanak.mealient.logging.Logger
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Singleton import javax.inject.Singleton
@Singleton @Singleton
class MealieDataSourceWrapper @Inject constructor( class MealieDataSourceWrapper @Inject constructor(
private val serverInfoRepo: ServerInfoRepo, private val serverInfoRepo: ServerInfoRepo,
private val authRepo: AuthRepo,
private val v0Source: MealieDataSourceV0, private val v0Source: MealieDataSourceV0,
private val v1Source: MealieDataSourceV1, private val v1Source: MealieDataSourceV1,
private val logger: Logger,
) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource { ) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource {
override suspend fun addRecipe( override suspend fun addRecipe(
@@ -85,20 +79,8 @@ class MealieDataSourceWrapper @Inject constructor(
} }
private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T { private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T {
val authHeader = authRepo.getAuthHeader()
val url = serverInfoRepo.requireUrl() val url = serverInfoRepo.requireUrl()
val version = serverInfoRepo.getVersion() val version = serverInfoRepo.getVersion()
return runCatchingExceptCancel { block(url, version) }.getOrElse { return block(url, version)
if (it is NetworkError.Unauthorized) {
logger.e { "Unauthorized, trying to invalidate token" }
authRepo.invalidateAuthHeader()
// Trying again with new authentication header
val newHeader = authRepo.getAuthHeader()
logger.e { "New token ${if (newHeader == authHeader) "matches" else "doesn't match"} old token" }
if (newHeader == authHeader) throw it else block(url, version)
} else {
throw it
}
}
} }
} }

View File

@@ -3,7 +3,6 @@ package gq.kirmanak.mealient.data.network
import com.google.common.truth.Truth.assertThat import com.google.common.truth.Truth.assertThat
import gq.kirmanak.mealient.data.auth.AuthRepo import gq.kirmanak.mealient.data.auth.AuthRepo
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
import gq.kirmanak.mealient.datasource.NetworkError
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0 import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1 import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER
@@ -15,7 +14,6 @@ import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_INFO
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0 import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1 import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1 import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0 import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1 import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1
@@ -50,26 +48,7 @@ class MealieDataSourceWrapperTest : BaseUnitTest() {
@Before @Before
override fun setUp() { override fun setUp() {
super.setUp() super.setUp()
subject = MealieDataSourceWrapper(serverInfoRepo, authRepo, v0Source, v1Source, logger) subject = MealieDataSourceWrapper(serverInfoRepo, v0Source, v1Source)
}
@Test
fun `when makeCall fails with Unauthorized expect it to invalidate token`() = runTest {
val slug = "porridge"
coEvery {
v0Source.requestRecipeInfo(any(), any())
} throws NetworkError.Unauthorized(IOException()) andThen PORRIDGE_RECIPE_RESPONSE_V0
coEvery { serverInfoRepo.getVersion() } returns TEST_SERVER_VERSION_V0
coEvery { serverInfoRepo.requireUrl() } returns TEST_BASE_URL
coEvery { authRepo.getAuthHeader() } returns null andThen TEST_AUTH_HEADER
subject.requestRecipeInfo(slug)
coVerifySequence {
authRepo.getAuthHeader()
authRepo.invalidateAuthHeader()
authRepo.getAuthHeader()
}
} }
@Test @Test

View File

@@ -3,4 +3,6 @@ package gq.kirmanak.mealient.datasource
interface AuthenticationProvider { interface AuthenticationProvider {
suspend fun getAuthHeader(): String? suspend fun getAuthHeader(): String?
suspend fun invalidateAuthHeader()
} }

View File

@@ -1,9 +1,11 @@
package gq.kirmanak.mealient.datasource.impl package gq.kirmanak.mealient.datasource.impl
import androidx.annotation.VisibleForTesting
import gq.kirmanak.mealient.datasource.AuthenticationProvider import gq.kirmanak.mealient.datasource.AuthenticationProvider
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import okhttp3.Interceptor import okhttp3.Interceptor
import okhttp3.Response import okhttp3.Response
import retrofit2.HttpException
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Provider import javax.inject.Provider
import javax.inject.Singleton import javax.inject.Singleton
@@ -14,19 +16,43 @@ class AuthInterceptor @Inject constructor(
) : Interceptor { ) : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
val token = runBlocking { authenticationProvider.get().getAuthHeader() } val token = getAuthHeader()
val request = if (token == null) { return if (token == null) {
chain.request() proceedWithAuthHeader(chain)
} else { } else {
chain.request() try {
.newBuilder() proceedWithAuthHeader(chain, token)
.header(HEADER_NAME, token) } catch (e: HttpException) {
.build() if (e.code() in setOf(401, 403)) {
invalidateAuthHeader()
proceedWithAuthHeader(chain, getAuthHeader())
} else {
throw e
}
}
}
}
private fun proceedWithAuthHeader(chain: Interceptor.Chain, token: String? = null): Response {
val requestBuilder = chain.request().newBuilder()
val request = if (token == null) {
requestBuilder.removeHeader(HEADER_NAME).build()
} else {
requestBuilder.header(HEADER_NAME, token).build()
} }
return chain.proceed(request) return chain.proceed(request)
} }
private fun getAuthHeader() = runBlocking {
authenticationProvider.get().getAuthHeader()
}
private fun invalidateAuthHeader() = runBlocking {
authenticationProvider.get().invalidateAuthHeader()
}
companion object { companion object {
private const val HEADER_NAME = "Authorization" @VisibleForTesting
const val HEADER_NAME = "Authorization"
} }
} }

View File

@@ -0,0 +1,164 @@
package gq.kirmanak.mealient.datasource
import com.google.common.truth.Truth.assertThat
import gq.kirmanak.mealient.datasource.impl.AuthInterceptor
import gq.kirmanak.mealient.test.BaseUnitTest
import io.mockk.CapturingSlot
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.slot
import io.mockk.verify
import okhttp3.Interceptor
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Before
import org.junit.Test
import retrofit2.HttpException
import retrofit2.Response as RetrofitResponse
class AuthInterceptorTest : BaseUnitTest() {
private lateinit var subject: AuthInterceptor
@MockK(relaxUnitFun = true)
lateinit var authenticationProvider: AuthenticationProvider
@MockK(relaxUnitFun = true)
lateinit var chain: Interceptor.Chain
@Before
override fun setUp() {
super.setUp()
subject = AuthInterceptor { authenticationProvider }
every { chain.request() } returns Request.Builder().url("http://localhost").build()
}
@Test
fun `when no header then still proceeds`() {
mockProceedCallAndCaptureRequest()
coEvery { authenticationProvider.getAuthHeader() } returns null
subject.intercept(chain)
verify { chain.proceed(any()) }
}
@Test
fun `when has header then adds header`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
}
@Test(expected = HttpException::class)
fun `when unauthorized but didn't have token then throws`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
mockUnauthorized()
subject.intercept(chain)
}
@Test(expected = HttpException::class)
fun `when unauthorized and had token then invalidates it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
mockUnauthorized()
subject.intercept(chain)
coVerify { authenticationProvider.invalidateAuthHeader() }
}
@Test(expected = HttpException::class)
fun `when not found and had token then throws exception`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
subject.intercept(chain)
}
@Test
fun `when not found and had token then does not invalidate it`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token"
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
}
try {
subject.intercept(chain)
} catch (e: HttpException) {
coVerify(inverse = true) { authenticationProvider.invalidateAuthHeader() }
}
}
@Test
fun `when unauthorized and had token then calls again with new token`() {
coEvery { authenticationProvider.getAuthHeader() } returns "token" andThen "newToken"
val requests = mutableListOf<Request>()
every {
chain.proceed(capture(requests))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
} andThenAnswer {
buildResponse(requests[1])
}
subject.intercept(chain)
coVerifySequence {
authenticationProvider.getAuthHeader()
authenticationProvider.invalidateAuthHeader()
authenticationProvider.getAuthHeader()
}
assertThat(requests[0].header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
assertThat(requests[1].header(AuthInterceptor.HEADER_NAME)).isEqualTo("newToken")
}
@Test
fun `when had token but now does not then removes it`() {
coEvery { authenticationProvider.getAuthHeader() } returns null
val mockRequest = Request.Builder()
.url("http://localhost")
.header(AuthInterceptor.HEADER_NAME, "token")
.build()
every { chain.request() } returns mockRequest
val requestSlot = mockProceedCallAndCaptureRequest()
subject.intercept(chain)
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isNull()
}
private fun mockUnauthorized() {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
}
}
private fun mockProceedCallAndCaptureRequest(): CapturingSlot<Request> {
val requestSlot = slot<Request>()
every {
chain.proceed(capture(requestSlot))
} answers {
buildResponse(requestSlot.captured)
}
return requestSlot
}
private fun buildResponse(
request: Request,
code: Int = 200,
protocol: Protocol = Protocol.HTTP_2,
message: String = if (code == 200) "OK" else "NOT OK",
) = Response.Builder()
.code(code)
.request(request)
.protocol(protocol)
.message(message)
.build()
}