Move token invalidation to auth interceptor
This commit is contained in:
@@ -2,7 +2,6 @@ package gq.kirmanak.mealient.data.network
|
|||||||
|
|
||||||
import gq.kirmanak.mealient.data.add.AddRecipeDataSource
|
import gq.kirmanak.mealient.data.add.AddRecipeDataSource
|
||||||
import gq.kirmanak.mealient.data.add.AddRecipeInfo
|
import gq.kirmanak.mealient.data.add.AddRecipeInfo
|
||||||
import gq.kirmanak.mealient.data.auth.AuthRepo
|
|
||||||
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
||||||
import gq.kirmanak.mealient.data.baseurl.ServerVersion
|
import gq.kirmanak.mealient.data.baseurl.ServerVersion
|
||||||
import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo
|
import gq.kirmanak.mealient.data.recipes.network.FullRecipeInfo
|
||||||
@@ -10,8 +9,6 @@ import gq.kirmanak.mealient.data.recipes.network.RecipeDataSource
|
|||||||
import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo
|
import gq.kirmanak.mealient.data.recipes.network.RecipeSummaryInfo
|
||||||
import gq.kirmanak.mealient.data.share.ParseRecipeDataSource
|
import gq.kirmanak.mealient.data.share.ParseRecipeDataSource
|
||||||
import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo
|
import gq.kirmanak.mealient.data.share.ParseRecipeURLInfo
|
||||||
import gq.kirmanak.mealient.datasource.NetworkError
|
|
||||||
import gq.kirmanak.mealient.datasource.runCatchingExceptCancel
|
|
||||||
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
||||||
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
||||||
import gq.kirmanak.mealient.extensions.toFullRecipeInfo
|
import gq.kirmanak.mealient.extensions.toFullRecipeInfo
|
||||||
@@ -20,17 +17,14 @@ import gq.kirmanak.mealient.extensions.toV0Request
|
|||||||
import gq.kirmanak.mealient.extensions.toV1CreateRequest
|
import gq.kirmanak.mealient.extensions.toV1CreateRequest
|
||||||
import gq.kirmanak.mealient.extensions.toV1Request
|
import gq.kirmanak.mealient.extensions.toV1Request
|
||||||
import gq.kirmanak.mealient.extensions.toV1UpdateRequest
|
import gq.kirmanak.mealient.extensions.toV1UpdateRequest
|
||||||
import gq.kirmanak.mealient.logging.Logger
|
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
|
|
||||||
@Singleton
|
@Singleton
|
||||||
class MealieDataSourceWrapper @Inject constructor(
|
class MealieDataSourceWrapper @Inject constructor(
|
||||||
private val serverInfoRepo: ServerInfoRepo,
|
private val serverInfoRepo: ServerInfoRepo,
|
||||||
private val authRepo: AuthRepo,
|
|
||||||
private val v0Source: MealieDataSourceV0,
|
private val v0Source: MealieDataSourceV0,
|
||||||
private val v1Source: MealieDataSourceV1,
|
private val v1Source: MealieDataSourceV1,
|
||||||
private val logger: Logger,
|
|
||||||
) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource {
|
) : AddRecipeDataSource, RecipeDataSource, ParseRecipeDataSource {
|
||||||
|
|
||||||
override suspend fun addRecipe(
|
override suspend fun addRecipe(
|
||||||
@@ -85,20 +79,8 @@ class MealieDataSourceWrapper @Inject constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T {
|
private suspend inline fun <T> makeCall(block: (String, ServerVersion) -> T): T {
|
||||||
val authHeader = authRepo.getAuthHeader()
|
|
||||||
val url = serverInfoRepo.requireUrl()
|
val url = serverInfoRepo.requireUrl()
|
||||||
val version = serverInfoRepo.getVersion()
|
val version = serverInfoRepo.getVersion()
|
||||||
return runCatchingExceptCancel { block(url, version) }.getOrElse {
|
return block(url, version)
|
||||||
if (it is NetworkError.Unauthorized) {
|
|
||||||
logger.e { "Unauthorized, trying to invalidate token" }
|
|
||||||
authRepo.invalidateAuthHeader()
|
|
||||||
// Trying again with new authentication header
|
|
||||||
val newHeader = authRepo.getAuthHeader()
|
|
||||||
logger.e { "New token ${if (newHeader == authHeader) "matches" else "doesn't match"} old token" }
|
|
||||||
if (newHeader == authHeader) throw it else block(url, version)
|
|
||||||
} else {
|
|
||||||
throw it
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3,7 +3,6 @@ package gq.kirmanak.mealient.data.network
|
|||||||
import com.google.common.truth.Truth.assertThat
|
import com.google.common.truth.Truth.assertThat
|
||||||
import gq.kirmanak.mealient.data.auth.AuthRepo
|
import gq.kirmanak.mealient.data.auth.AuthRepo
|
||||||
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
import gq.kirmanak.mealient.data.baseurl.ServerInfoRepo
|
||||||
import gq.kirmanak.mealient.datasource.NetworkError
|
|
||||||
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
import gq.kirmanak.mealient.datasource.v0.MealieDataSourceV0
|
||||||
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
import gq.kirmanak.mealient.datasource.v1.MealieDataSourceV1
|
||||||
import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER
|
import gq.kirmanak.mealient.test.AuthImplTestData.TEST_AUTH_HEADER
|
||||||
@@ -15,7 +14,6 @@ import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_INFO
|
|||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_ADD_RECIPE_REQUEST_V0
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_CREATE_RECIPE_REQUEST_V1
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_FULL_RECIPE_INFO
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V0
|
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_RESPONSE_V1
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V0
|
||||||
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1
|
import gq.kirmanak.mealient.test.RecipeImplTestData.PORRIDGE_RECIPE_SUMMARY_RESPONSE_V1
|
||||||
@@ -50,26 +48,7 @@ class MealieDataSourceWrapperTest : BaseUnitTest() {
|
|||||||
@Before
|
@Before
|
||||||
override fun setUp() {
|
override fun setUp() {
|
||||||
super.setUp()
|
super.setUp()
|
||||||
subject = MealieDataSourceWrapper(serverInfoRepo, authRepo, v0Source, v1Source, logger)
|
subject = MealieDataSourceWrapper(serverInfoRepo, v0Source, v1Source)
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun `when makeCall fails with Unauthorized expect it to invalidate token`() = runTest {
|
|
||||||
val slug = "porridge"
|
|
||||||
coEvery {
|
|
||||||
v0Source.requestRecipeInfo(any(), any())
|
|
||||||
} throws NetworkError.Unauthorized(IOException()) andThen PORRIDGE_RECIPE_RESPONSE_V0
|
|
||||||
coEvery { serverInfoRepo.getVersion() } returns TEST_SERVER_VERSION_V0
|
|
||||||
coEvery { serverInfoRepo.requireUrl() } returns TEST_BASE_URL
|
|
||||||
coEvery { authRepo.getAuthHeader() } returns null andThen TEST_AUTH_HEADER
|
|
||||||
|
|
||||||
subject.requestRecipeInfo(slug)
|
|
||||||
|
|
||||||
coVerifySequence {
|
|
||||||
authRepo.getAuthHeader()
|
|
||||||
authRepo.invalidateAuthHeader()
|
|
||||||
authRepo.getAuthHeader()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
@@ -3,4 +3,6 @@ package gq.kirmanak.mealient.datasource
|
|||||||
interface AuthenticationProvider {
|
interface AuthenticationProvider {
|
||||||
|
|
||||||
suspend fun getAuthHeader(): String?
|
suspend fun getAuthHeader(): String?
|
||||||
|
|
||||||
|
suspend fun invalidateAuthHeader()
|
||||||
}
|
}
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
package gq.kirmanak.mealient.datasource.impl
|
package gq.kirmanak.mealient.datasource.impl
|
||||||
|
|
||||||
|
import androidx.annotation.VisibleForTesting
|
||||||
import gq.kirmanak.mealient.datasource.AuthenticationProvider
|
import gq.kirmanak.mealient.datasource.AuthenticationProvider
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import okhttp3.Interceptor
|
import okhttp3.Interceptor
|
||||||
import okhttp3.Response
|
import okhttp3.Response
|
||||||
|
import retrofit2.HttpException
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
import javax.inject.Provider
|
import javax.inject.Provider
|
||||||
import javax.inject.Singleton
|
import javax.inject.Singleton
|
||||||
@@ -14,19 +16,43 @@ class AuthInterceptor @Inject constructor(
|
|||||||
) : Interceptor {
|
) : Interceptor {
|
||||||
|
|
||||||
override fun intercept(chain: Interceptor.Chain): Response {
|
override fun intercept(chain: Interceptor.Chain): Response {
|
||||||
val token = runBlocking { authenticationProvider.get().getAuthHeader() }
|
val token = getAuthHeader()
|
||||||
val request = if (token == null) {
|
return if (token == null) {
|
||||||
chain.request()
|
proceedWithAuthHeader(chain)
|
||||||
} else {
|
} else {
|
||||||
chain.request()
|
try {
|
||||||
.newBuilder()
|
proceedWithAuthHeader(chain, token)
|
||||||
.header(HEADER_NAME, token)
|
} catch (e: HttpException) {
|
||||||
.build()
|
if (e.code() in setOf(401, 403)) {
|
||||||
|
invalidateAuthHeader()
|
||||||
|
proceedWithAuthHeader(chain, getAuthHeader())
|
||||||
|
} else {
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun proceedWithAuthHeader(chain: Interceptor.Chain, token: String? = null): Response {
|
||||||
|
val requestBuilder = chain.request().newBuilder()
|
||||||
|
val request = if (token == null) {
|
||||||
|
requestBuilder.removeHeader(HEADER_NAME).build()
|
||||||
|
} else {
|
||||||
|
requestBuilder.header(HEADER_NAME, token).build()
|
||||||
}
|
}
|
||||||
return chain.proceed(request)
|
return chain.proceed(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun getAuthHeader() = runBlocking {
|
||||||
|
authenticationProvider.get().getAuthHeader()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun invalidateAuthHeader() = runBlocking {
|
||||||
|
authenticationProvider.get().invalidateAuthHeader()
|
||||||
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val HEADER_NAME = "Authorization"
|
@VisibleForTesting
|
||||||
|
const val HEADER_NAME = "Authorization"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,164 @@
|
|||||||
|
package gq.kirmanak.mealient.datasource
|
||||||
|
|
||||||
|
import com.google.common.truth.Truth.assertThat
|
||||||
|
import gq.kirmanak.mealient.datasource.impl.AuthInterceptor
|
||||||
|
import gq.kirmanak.mealient.test.BaseUnitTest
|
||||||
|
import io.mockk.CapturingSlot
|
||||||
|
import io.mockk.coEvery
|
||||||
|
import io.mockk.coVerify
|
||||||
|
import io.mockk.coVerifySequence
|
||||||
|
import io.mockk.every
|
||||||
|
import io.mockk.impl.annotations.MockK
|
||||||
|
import io.mockk.slot
|
||||||
|
import io.mockk.verify
|
||||||
|
import okhttp3.Interceptor
|
||||||
|
import okhttp3.Protocol
|
||||||
|
import okhttp3.Request
|
||||||
|
import okhttp3.Response
|
||||||
|
import okhttp3.ResponseBody.Companion.toResponseBody
|
||||||
|
import org.junit.Before
|
||||||
|
import org.junit.Test
|
||||||
|
import retrofit2.HttpException
|
||||||
|
import retrofit2.Response as RetrofitResponse
|
||||||
|
|
||||||
|
class AuthInterceptorTest : BaseUnitTest() {
|
||||||
|
|
||||||
|
private lateinit var subject: AuthInterceptor
|
||||||
|
|
||||||
|
@MockK(relaxUnitFun = true)
|
||||||
|
lateinit var authenticationProvider: AuthenticationProvider
|
||||||
|
|
||||||
|
@MockK(relaxUnitFun = true)
|
||||||
|
lateinit var chain: Interceptor.Chain
|
||||||
|
|
||||||
|
@Before
|
||||||
|
override fun setUp() {
|
||||||
|
super.setUp()
|
||||||
|
subject = AuthInterceptor { authenticationProvider }
|
||||||
|
every { chain.request() } returns Request.Builder().url("http://localhost").build()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when no header then still proceeds`() {
|
||||||
|
mockProceedCallAndCaptureRequest()
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||||
|
subject.intercept(chain)
|
||||||
|
verify { chain.proceed(any()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when has header then adds header`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||||
|
val requestSlot = mockProceedCallAndCaptureRequest()
|
||||||
|
subject.intercept(chain)
|
||||||
|
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = HttpException::class)
|
||||||
|
fun `when unauthorized but didn't have token then throws`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||||
|
mockUnauthorized()
|
||||||
|
subject.intercept(chain)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = HttpException::class)
|
||||||
|
fun `when unauthorized and had token then invalidates it`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||||
|
mockUnauthorized()
|
||||||
|
subject.intercept(chain)
|
||||||
|
coVerify { authenticationProvider.invalidateAuthHeader() }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = HttpException::class)
|
||||||
|
fun `when not found and had token then throws exception`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||||
|
val requestSlot = slot<Request>()
|
||||||
|
every {
|
||||||
|
chain.proceed(capture(requestSlot))
|
||||||
|
} answers {
|
||||||
|
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
|
||||||
|
}
|
||||||
|
subject.intercept(chain)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when not found and had token then does not invalidate it`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns "token"
|
||||||
|
val requestSlot = slot<Request>()
|
||||||
|
every {
|
||||||
|
chain.proceed(capture(requestSlot))
|
||||||
|
} answers {
|
||||||
|
throw HttpException(RetrofitResponse.error<String>(404, "".toResponseBody()))
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
subject.intercept(chain)
|
||||||
|
} catch (e: HttpException) {
|
||||||
|
coVerify(inverse = true) { authenticationProvider.invalidateAuthHeader() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when unauthorized and had token then calls again with new token`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns "token" andThen "newToken"
|
||||||
|
val requests = mutableListOf<Request>()
|
||||||
|
every {
|
||||||
|
chain.proceed(capture(requests))
|
||||||
|
} answers {
|
||||||
|
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
|
||||||
|
} andThenAnswer {
|
||||||
|
buildResponse(requests[1])
|
||||||
|
}
|
||||||
|
subject.intercept(chain)
|
||||||
|
coVerifySequence {
|
||||||
|
authenticationProvider.getAuthHeader()
|
||||||
|
authenticationProvider.invalidateAuthHeader()
|
||||||
|
authenticationProvider.getAuthHeader()
|
||||||
|
}
|
||||||
|
assertThat(requests[0].header(AuthInterceptor.HEADER_NAME)).isEqualTo("token")
|
||||||
|
assertThat(requests[1].header(AuthInterceptor.HEADER_NAME)).isEqualTo("newToken")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `when had token but now does not then removes it`() {
|
||||||
|
coEvery { authenticationProvider.getAuthHeader() } returns null
|
||||||
|
val mockRequest = Request.Builder()
|
||||||
|
.url("http://localhost")
|
||||||
|
.header(AuthInterceptor.HEADER_NAME, "token")
|
||||||
|
.build()
|
||||||
|
every { chain.request() } returns mockRequest
|
||||||
|
val requestSlot = mockProceedCallAndCaptureRequest()
|
||||||
|
subject.intercept(chain)
|
||||||
|
assertThat(requestSlot.captured.header(AuthInterceptor.HEADER_NAME)).isNull()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun mockUnauthorized() {
|
||||||
|
val requestSlot = slot<Request>()
|
||||||
|
every {
|
||||||
|
chain.proceed(capture(requestSlot))
|
||||||
|
} answers {
|
||||||
|
throw HttpException(RetrofitResponse.error<String>(401, "".toResponseBody()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun mockProceedCallAndCaptureRequest(): CapturingSlot<Request> {
|
||||||
|
val requestSlot = slot<Request>()
|
||||||
|
every {
|
||||||
|
chain.proceed(capture(requestSlot))
|
||||||
|
} answers {
|
||||||
|
buildResponse(requestSlot.captured)
|
||||||
|
}
|
||||||
|
return requestSlot
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun buildResponse(
|
||||||
|
request: Request,
|
||||||
|
code: Int = 200,
|
||||||
|
protocol: Protocol = Protocol.HTTP_2,
|
||||||
|
message: String = if (code == 200) "OK" else "NOT OK",
|
||||||
|
) = Response.Builder()
|
||||||
|
.code(code)
|
||||||
|
.request(request)
|
||||||
|
.protocol(protocol)
|
||||||
|
.message(message)
|
||||||
|
.build()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user