Move token invalidation to auth interceptor
This commit is contained in:
@@ -2,7 +2,6 @@ package gq.kirmanak.mealient.data.network
|
||||
|
||||
import gq.kirmanak.mealient.data.add.AddRecipeDataSource
|
||||
import gq.kirmanak.mealient.data.add.AddRecipeInfo
|
||||
import gq.kirmanak.mealient.data.auth.AuthRepo
|
||||
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
||||
import gq.kirmanak.mealient.data.baseurl.ServerVersion
|
||||
import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo
|
||||
@@ -10,8 +9,6 @@ import gq.kirmanak.mealient.data.recipes.network.RecipeDataSource
|
||||
import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo
|
||||
import gq.kirmanak.mealient.data.share.ParseRecipeDataSource
|
||||
import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo
|
||||
import gq.kirmanak.mealient.datasource.NetworkError
|
||||
import gq.kirmanak.mealient.datasource.runCatchingExceptCancel
|
||||
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
||||
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
||||
import gq.kirmanak.mealient.extensions.toFullRecipeInfo
|
||||
@@ -20,17 +17,14 @@ import gq.kirmanak.mealient.extensions.toV0Request
|
||||
import gq.kirmanak.mealient.extensions.toV1CreateRequest
|
||||
import gq.kirmanak.mealient.extensions.toV1Request
|
||||
import gq.kirmanak.mealient.extensions.toV1UpdateRequest
|
||||
import gq.kirmanak.mealient.logging.Logger
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class MealieDataSourceWrapper @Inject constructor(
|
||||
private val serverInfoRepo: ServerInfoRepo,
|
||||
private val authRepo: AuthRepo,
|
||||
private val v0Source: MealieDataSourceV0,
|
||||
private val v1Source: MealieDataSourceV1,
|
||||
private val logger: Logger,
|
||||
) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource {
|
||||
|
||||
override suspend fun addRecipe(
|
||||
@@ -85,20 +79,8 @@ class MealieDataSourceWrapper @Inject constructor(
|
||||
}
|
||||
|
||||
private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T {
|
||||
val authHeader = authRepo.getAuthHeader()
|
||||
val url = serverInfoRepo.requireUrl()
|
||||
val version = serverInfoRepo.getVersion()
|
||||
return runCatchingExceptCancel { block(url, version) }.getOrElse {
|
||||
if (it is NetworkError.Unauthorized) {
|
||||
logger.e { "Unauthorized, trying to invalidate token" }
|
||||
authRepo.invalidateAuthHeader()
|
||||
// Trying again with new authentication header
|
||||
val newHeader = authRepo.getAuthHeader()
|
||||
logger.e { "New token ${if (newHeader == authHeader) "matches" else "doesn't match"} old token" }
|
||||
if (newHeader == authHeader) throw it else block(url, version)
|
||||
} else {
|
||||
throw it
|
||||
}
|
||||
}
|
||||
return block(url, version)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package gq.kirmanak.mealient.data.network
|
||||
import com.google.common.truth.Truth.assertThat
|
||||
import gq.kirmanak.mealient.data.auth.AuthRepo
|
||||
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
||||
import gq.kirmanak.mealient.datasource.NetworkError
|
||||
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
||||
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
||||
import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER
|
||||
@@ -15,7 +14,6 @@ import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_INFO
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V0
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0
|
||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1
|
||||
@@ -50,26 +48,7 @@ class MealieDataSourceWrapperTest : BaseUnitTest() {
|
||||
@Before
|
||||
override fun setUp() {
|
||||
super.setUp()
|
||||
subject = MealieDataSourceWrapper(serverInfoRepo, authRepo, v0Source, v1Source, logger)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when makeCall fails with Unauthorized expect it to invalidate token`() = runTest {
|
||||
val slug = "porridge"
|
||||
coEvery {
|
||||
v0Source.requestRecipeInfo(any(), any())
|
||||
} throws NetworkError.Unauthorized(IOException()) andThen PORRIDGE_RECIPE_RESPONSE_V0
|
||||
coEvery { serverInfoRepo.getVersion() } returns TEST_SERVER_VERSION_V0
|
||||
coEvery { serverInfoRepo.requireUrl() } returns TEST_BASE_URL
|
||||
coEvery { authRepo.getAuthHeader() } returns null andThen TEST_AUTH_HEADER
|
||||
|
||||
subject.requestRecipeInfo(slug)
|
||||
|
||||
coVerifySequence {
|
||||
authRepo.getAuthHeader()
|
||||
authRepo.invalidateAuthHeader()
|
||||
authRepo.getAuthHeader()
|
||||
}
|
||||
subject = MealieDataSourceWrapper(serverInfoRepo, v0Source, v1Source)
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -3,4 +3,6 @@ package gq.kirmanak.mealient.datasource
|
||||
interface AuthenticationProvider {
|
||||
|
||||
suspend fun getAuthHeader(): String?
|
||||
|
||||
suspend fun invalidateAuthHeader()
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package gq.kirmanak.mealient.datasource.impl
|
||||
|
||||
import androidx.annotation.VisibleForTesting
|
||||
import gq.kirmanak.mealient.datasource.AuthenticationProvider
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
import retrofit2.HttpException
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Provider
|
||||
import javax.inject.Singleton
|
||||
@@ -14,19 +16,43 @@ class AuthInterceptor @Inject constructor(
|
||||
) : Interceptor {
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val token = runBlocking { authenticationProvider.get().getAuthHeader() }
|
||||
val request = if (token == null) {
|
||||
chain.request()
|
||||
val token = getAuthHeader()
|
||||
return if (token == null) {
|
||||
proceedWithAuthHeader(chain)
|
||||
} else {
|
||||
chain.request()
|
||||
.newBuilder()
|
||||
.header(HEADER_NAME, token)
|
||||
.build()
|
||||
try {
|
||||
proceedWithAuthHeader(chain, token)
|
||||
} catch (e: HttpException) {
|
||||
if (e.code() in setOf(401, 403)) {
|
||||
invalidateAuthHeader()
|
||||
proceedWithAuthHeader(chain, getAuthHeader())
|
||||
} else {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun proceedWithAuthHeader(chain: Interceptor.Chain, token: String? = null): Response {
|
||||
val requestBuilder = chain.request().newBuilder()
|
||||
val request = if (token == null) {
|
||||
requestBuilder.removeHeader(HEADER_NAME).build()
|
||||
} else {
|
||||
requestBuilder.header(HEADER_NAME, token).build()
|
||||
}
|
||||
return chain.proceed(request)
|
||||
}
|
||||
|
||||
private fun getAuthHeader() = runBlocking {
|
||||
authenticationProvider.get().getAuthHeader()
|
||||
}
|
||||
|
||||
private fun invalidateAuthHeader() = runBlocking {
|
||||
authenticationProvider.get().invalidateAuthHeader()
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val HEADER_NAME = "Authorization"
|
||||
@VisibleForTesting
|
||||
const val HEADER_NAME = "Authorization"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package gq.kirmanak.mealient.datasource
|
||||
|
||||
import com.google.common.truth.Truth.assertThat
|
||||
import gq.kirmanak.mealient.datasource.impl.AuthInterceptor
|
||||
import gq.kirmanak.mealient.test.BaseUnitTest
|
||||
import io.mockk.CapturingSlot
|
||||
import io.mockk.coEvery
|
||||
import io.mockk.coVerify
|
||||
import io.mockk.coVerifySequence
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.MockK
|
||||
import io.mockk.slot
|
||||
import io.mockk.verify
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Protocol
|
||||
import okhttp3.Request
|
||||
import okhttp3.Response
|
||||
import okhttp3.ResponseBody.Companion.toResponseBody
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import retrofit2.HttpException
|
||||
import retrofit2.Response as RetrofitResponse
|
||||
|
||||
class AuthInterceptorTest : BaseUnitTest() {
|
||||
|
||||
private lateinit var subject: AuthInterceptor
|
||||
|
||||
@MockK(relaxUnitFun = true)
|
||||
lateinit var authenticationProvider: AuthenticationProvider
|
||||
|
||||
@MockK(relaxUnitFun = true)
|
||||
lateinit var chain: Interceptor.Chain
|
||||
|
||||
@Before
|
||||
override fun setUp() {
|
||||
super.setUp()
|
||||
subject = AuthInterceptor { authenticationProvider }
|
||||
every { chain.request() } returns Request.Builder().url("http://localhost").build()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when no header then still proceeds`() {
|
||||
mockProceedCallAndCaptureRequest()
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||
subject.intercept(chain)
|
||||
verify { chain.proceed(any()) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when has header then adds header`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||
val requestSlot = mockProceedCallAndCaptureRequest()
|
||||
subject.intercept(chain)
|
||||
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
|
||||
}
|
||||
|
||||
@Test(expected = HttpException::class)
|
||||
fun `when unauthorized but didn't have token then throws`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||
mockUnauthorized()
|
||||
subject.intercept(chain)
|
||||
}
|
||||
|
||||
@Test(expected = HttpException::class)
|
||||
fun `when unauthorized and had token then invalidates it`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||
mockUnauthorized()
|
||||
subject.intercept(chain)
|
||||
coVerify { authenticationProvider.invalidateAuthHeader() }
|
||||
}
|
||||
|
||||
@Test(expected = HttpException::class)
|
||||
fun `when not found and had token then throws exception`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||
val requestSlot = slot<Request>()
|
||||
every {
|
||||
chain.proceed(capture(requestSlot))
|
||||
} answers {
|
||||
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
|
||||
}
|
||||
subject.intercept(chain)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when not found and had token then does not invalidate it`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||
val requestSlot = slot<Request>()
|
||||
every {
|
||||
chain.proceed(capture(requestSlot))
|
||||
} answers {
|
||||
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
|
||||
}
|
||||
try {
|
||||
subject.intercept(chain)
|
||||
} catch (e: HttpException) {
|
||||
coVerify(inverse = true) { authenticationProvider.invalidateAuthHeader() }
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when unauthorized and had token then calls again with new token`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns "token" andThen "newToken"
|
||||
val requests = mutableListOf<Request>()
|
||||
every {
|
||||
chain.proceed(capture(requests))
|
||||
} answers {
|
||||
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
|
||||
} andThenAnswer {
|
||||
buildResponse(requests[1])
|
||||
}
|
||||
subject.intercept(chain)
|
||||
coVerifySequence {
|
||||
authenticationProvider.getAuthHeader()
|
||||
authenticationProvider.invalidateAuthHeader()
|
||||
authenticationProvider.getAuthHeader()
|
||||
}
|
||||
assertThat(requests[0].header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
|
||||
assertThat(requests[1].header(AuthInterceptor.HEADER_NAME)).isEqualTo("newToken")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when had token but now does not then removes it`() {
|
||||
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||
val mockRequest = Request.Builder()
|
||||
.url("http://localhost")
|
||||
.header(AuthInterceptor.HEADER_NAME, "token")
|
||||
.build()
|
||||
every { chain.request() } returns mockRequest
|
||||
val requestSlot = mockProceedCallAndCaptureRequest()
|
||||
subject.intercept(chain)
|
||||
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isNull()
|
||||
}
|
||||
|
||||
private fun mockUnauthorized() {
|
||||
val requestSlot = slot<Request>()
|
||||
every {
|
||||
chain.proceed(capture(requestSlot))
|
||||
} answers {
|
||||
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
|
||||
}
|
||||
}
|
||||
|
||||
private fun mockProceedCallAndCaptureRequest(): CapturingSlot<Request> {
|
||||
val requestSlot = slot<Request>()
|
||||
every {
|
||||
chain.proceed(capture(requestSlot))
|
||||
} answers {
|
||||
buildResponse(requestSlot.captured)
|
||||
}
|
||||
return requestSlot
|
||||
}
|
||||
|
||||
private fun buildResponse(
|
||||
request: Request,
|
||||
code: Int = 200,
|
||||
protocol: Protocol = Protocol.HTTP_2,
|
||||
message: String = if (code == 200) "OK" else "NOT OK",
|
||||
) = Response.Builder()
|
||||
.code(code)
|
||||
.request(request)
|
||||
.protocol(protocol)
|
||||
.message(message)
|
||||
.build()
|
||||
}
|
||||
Reference in New Issue
Block a user