diff --git a/app/src/filters/auth/OptionalAuthFilter.kt b/app/src/filters/auth/OptionalAuthFilter.kt index c6c2ced..8c17c6f 100644 --- a/app/src/filters/auth/OptionalAuthFilter.kt +++ b/app/src/filters/auth/OptionalAuthFilter.kt @@ -1,13 +1,14 @@ package be.simplenotes.app.filters.auth import be.simplenotes.app.filters.auth.JwtSource.Cookie -import be.simplenotes.domain.security.JwtPayloadExtractor +import be.simplenotes.domain.security.SimpleJwt +import be.simplenotes.types.LoggedInUser import org.http4k.core.Filter import org.http4k.core.HttpHandler import org.http4k.core.with class OptionalAuthFilter( - private val extractor: JwtPayloadExtractor, + private val simpleJwt: SimpleJwt, private val lens: OptionalAuthLens, private val source: JwtSource = Cookie, ) : Filter { @@ -17,6 +18,6 @@ class OptionalAuthFilter( Cookie -> it.bearerTokenCookie() } - next(it.with(lens of token?.let { extractor(it) })) + next(it.with(lens of token?.let { simpleJwt.extract(it) })) } } diff --git a/app/src/filters/auth/RequiredAuthFilter.kt b/app/src/filters/auth/RequiredAuthFilter.kt index d70c7c6..4dbf5ef 100644 --- a/app/src/filters/auth/RequiredAuthFilter.kt +++ b/app/src/filters/auth/RequiredAuthFilter.kt @@ -1,7 +1,8 @@ package be.simplenotes.app.filters.auth import be.simplenotes.app.extensions.redirect -import be.simplenotes.domain.security.JwtPayloadExtractor +import be.simplenotes.domain.security.SimpleJwt +import be.simplenotes.types.LoggedInUser import org.http4k.core.Filter import org.http4k.core.HttpHandler import org.http4k.core.Response @@ -9,7 +10,7 @@ import org.http4k.core.Status.Companion.UNAUTHORIZED import org.http4k.core.with class RequiredAuthFilter( - private val extractor: JwtPayloadExtractor, + private val simpleJwt: SimpleJwt, private val lens: RequiredAuthLens, private val source: JwtSource = JwtSource.Cookie, private val redirect: Boolean = true, @@ -19,7 +20,7 @@ class RequiredAuthFilter( JwtSource.Header -> it.bearerTokenHeader() JwtSource.Cookie -> it.bearerTokenCookie() } - val jwtPayload = token?.let { extractor(token) } + val jwtPayload = token?.let { simpleJwt.extract(token) } if (jwtPayload != null) next(it.with(lens of jwtPayload)) else { diff --git a/app/src/modules/AuthModule.kt b/app/src/modules/AuthModule.kt index 42c856d..a9b8aed 100644 --- a/app/src/modules/AuthModule.kt +++ b/app/src/modules/AuthModule.kt @@ -1,7 +1,8 @@ package be.simplenotes.app.modules import be.simplenotes.app.filters.auth.* -import be.simplenotes.domain.security.JwtPayloadExtractor +import be.simplenotes.domain.security.SimpleJwt +import be.simplenotes.types.LoggedInUser import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Primary import org.http4k.core.RequestContexts @@ -21,21 +22,21 @@ class AuthModule { fun requiredAuthLens(ctx: RequestContexts): RequiredAuthLens = RequestContextKey.required(ctx) @Singleton - fun optionalAuth(extractor: JwtPayloadExtractor, @Named("optional") lens: OptionalAuthLens) = - OptionalAuthFilter(extractor, lens) + fun optionalAuth(simpleJwt: SimpleJwt, @Named("optional") lens: OptionalAuthLens) = + OptionalAuthFilter(simpleJwt, lens) @Primary @Singleton - fun requiredAuth(extractor: JwtPayloadExtractor, @Named("required") lens: RequiredAuthLens) = - RequiredAuthFilter(extractor, lens) + fun requiredAuth(simpleJwt: SimpleJwt, @Named("required") lens: RequiredAuthLens) = + RequiredAuthFilter(simpleJwt, lens) @Singleton @Named("api") internal fun apiAuthFilter( - jwtPayloadExtractor: JwtPayloadExtractor, + simpleJwt: SimpleJwt, @Named("required") lens: RequiredAuthLens, ) = RequiredAuthFilter( - extractor = jwtPayloadExtractor, + simpleJwt = simpleJwt, lens = lens, source = JwtSource.Header, redirect = false diff --git a/app/test/filters/RequiredAuthFilterTest.kt b/app/test/filters/RequiredAuthFilterTest.kt index 93331fb..01db86c 100644 --- a/app/test/filters/RequiredAuthFilterTest.kt +++ b/app/test/filters/RequiredAuthFilterTest.kt @@ -6,6 +6,7 @@ import be.simplenotes.app.filters.auth.RequiredAuthFilter import be.simplenotes.app.filters.auth.RequiredAuthLens import be.simplenotes.config.JwtConfig import be.simplenotes.domain.security.SimpleJwt +import be.simplenotes.domain.security.UserJwtMapper import be.simplenotes.types.LoggedInUser import com.natpryce.hamkrest.assertion.assertThat import io.micronaut.context.BeanContext @@ -32,7 +33,7 @@ internal class RequiredAuthFilterTest { // region setup private val jwtConfig = JwtConfig("secret", 1, TimeUnit.HOURS) - private val simpleJwt = SimpleJwt(jwtConfig) + private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper()) private val beanCtx = BeanContext.build() .registerSingleton(jwtConfig) diff --git a/domain/src/security/JwtMapper.kt b/domain/src/security/JwtMapper.kt new file mode 100644 index 0000000..3a992e0 --- /dev/null +++ b/domain/src/security/JwtMapper.kt @@ -0,0 +1,30 @@ +package be.simplenotes.domain.security + +import be.simplenotes.types.LoggedInUser +import com.auth0.jwt.JWTCreator +import com.auth0.jwt.interfaces.DecodedJWT +import javax.inject.Singleton + +interface JwtMapper { + fun extract(decodedJWT: DecodedJWT): T? + fun build(builder: JWTCreator.Builder, value: T) +} + +@Singleton +class UserJwtMapper : JwtMapper { + private val userIdField = "i" + private val usernameField = "u" + + override fun extract(decodedJWT: DecodedJWT): LoggedInUser? { + val id = decodedJWT.getClaim(userIdField).asInt() ?: null + val username = decodedJWT.getClaim(usernameField).asString() ?: null + return if (id != null && username != null) + LoggedInUser(id, username) + else null + } + + override fun build(builder: JWTCreator.Builder, value: LoggedInUser) { + builder.withClaim(userIdField, value.userId) + .withClaim(usernameField, value.username) + } +} diff --git a/domain/src/security/JwtPayloadExtractor.kt b/domain/src/security/JwtPayloadExtractor.kt deleted file mode 100644 index f46d935..0000000 --- a/domain/src/security/JwtPayloadExtractor.kt +++ /dev/null @@ -1,19 +0,0 @@ -package be.simplenotes.domain.security - -import be.simplenotes.types.LoggedInUser -import com.auth0.jwt.exceptions.JWTVerificationException -import javax.inject.Singleton - -@Singleton -class JwtPayloadExtractor(private val jwt: SimpleJwt) { - operator fun invoke(token: String): LoggedInUser? = try { - val decodedJWT = jwt.verifier.verify(token) - val id = decodedJWT.getClaim(userIdField).asInt() ?: null - val username = decodedJWT.getClaim(usernameField).asString() ?: null - id?.let { username?.let { LoggedInUser(id, username) } } - } catch (e: JWTVerificationException) { - null - } catch (e: IllegalArgumentException) { - null - } -} diff --git a/domain/src/security/SimpleJwt.kt b/domain/src/security/SimpleJwt.kt index 50e12b7..b10810e 100644 --- a/domain/src/security/SimpleJwt.kt +++ b/domain/src/security/SimpleJwt.kt @@ -1,28 +1,33 @@ package be.simplenotes.domain.security import be.simplenotes.config.JwtConfig -import be.simplenotes.types.LoggedInUser import com.auth0.jwt.JWT import com.auth0.jwt.JWTVerifier import com.auth0.jwt.algorithms.Algorithm +import com.auth0.jwt.exceptions.JWTVerificationException import java.util.* import java.util.concurrent.TimeUnit import javax.inject.Singleton -internal const val userIdField = "i" -internal const val usernameField = "u" - @Singleton -class SimpleJwt(jwtConfig: JwtConfig) { +class SimpleJwt(jwtConfig: JwtConfig, private val mapper: JwtMapper) { private val validityInMs = TimeUnit.MILLISECONDS.convert(jwtConfig.validity, jwtConfig.timeUnit) private val algorithm = Algorithm.HMAC256(jwtConfig.secret) + private val verifier: JWTVerifier = JWT.require(algorithm).build() - val verifier: JWTVerifier = JWT.require(algorithm).build() - fun sign(loggedInUser: LoggedInUser): String = JWT.create() - .withClaim(userIdField, loggedInUser.userId) - .withClaim(usernameField, loggedInUser.username) + fun sign(value: T): String = JWT.create() + .apply { mapper.build(this, value) } .withExpiresAt(getExpiration()) .sign(algorithm) + fun extract(token: String): T? = try { + val decodedJWT = verifier.verify(token) + mapper.extract(decodedJWT) + } catch (e: JWTVerificationException) { + null + } catch (e: IllegalArgumentException) { + null + } + private fun getExpiration() = Date(System.currentTimeMillis() + validityInMs) } diff --git a/domain/src/usecases/users/login/LoginUseCaseImpl.kt b/domain/src/usecases/users/login/LoginUseCaseImpl.kt index 6cc88c6..17326cd 100644 --- a/domain/src/usecases/users/login/LoginUseCaseImpl.kt +++ b/domain/src/usecases/users/login/LoginUseCaseImpl.kt @@ -16,7 +16,7 @@ import javax.inject.Singleton internal class LoginUseCaseImpl( private val userRepository: UserRepository, private val passwordHash: PasswordHash, - private val jwt: SimpleJwt + private val jwt: SimpleJwt ) : LoginUseCase { override fun login(form: LoginForm) = either.eager { val user = !UserValidations.validateLogin(form) diff --git a/domain/test/security/LoggedInUserExtractorTest.kt b/domain/test/security/LoggedInUserExtractorTest.kt index 8c4255a..600255e 100644 --- a/domain/test/security/LoggedInUserExtractorTest.kt +++ b/domain/test/security/LoggedInUserExtractorTest.kt @@ -16,14 +16,13 @@ import java.util.stream.Stream internal class LoggedInUserExtractorTest { private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS) - private val simpleJwt = SimpleJwt(jwtConfig) - private val jwtPayloadExtractor = JwtPayloadExtractor(simpleJwt) + private val mapper = UserJwtMapper() + private val simpleJwt = SimpleJwt(jwtConfig, mapper) private fun createToken(username: String? = null, id: Int? = null, secret: String = jwtConfig.secret): Token { val algo = Algorithm.HMAC256(secret) return JWT.create().apply { - username?.let { withClaim(usernameField, it) } - id?.let { withClaim(userIdField, it) } + if (username != null && id != null) mapper.build(this, LoggedInUser(id, username)) }.sign(algo) } @@ -40,12 +39,12 @@ internal class LoggedInUserExtractorTest { @ParameterizedTest(name = "[{index}] token `{0}` should be invalid") @MethodSource("invalidTokens") fun `parse invalid tokens`(token: String) { - assertThat(jwtPayloadExtractor(token), absent()) + assertThat(simpleJwt.extract(token), absent()) } @Test fun `parse valid token`() { val token = createToken(username = "someone", id = 1) - assertThat(jwtPayloadExtractor(token), equalTo(LoggedInUser(1, "someone"))) + assertThat(simpleJwt.extract(token), equalTo(LoggedInUser(1, "someone"))) } } diff --git a/domain/test/usecases/users/login/LoginUseCaseImplTest.kt b/domain/test/usecases/users/login/LoginUseCaseImplTest.kt index 27f4ee4..180c5c4 100644 --- a/domain/test/usecases/users/login/LoginUseCaseImplTest.kt +++ b/domain/test/usecases/users/login/LoginUseCaseImplTest.kt @@ -3,6 +3,7 @@ package be.simplenotes.domain.usecases.users.login import be.simplenotes.config.JwtConfig import be.simplenotes.domain.security.BcryptPasswordHash import be.simplenotes.domain.security.SimpleJwt +import be.simplenotes.domain.security.UserJwtMapper import be.simplenotes.domain.testutils.isLeftOfType import be.simplenotes.domain.testutils.isRight import be.simplenotes.persistance.repositories.UserRepository @@ -18,7 +19,7 @@ internal class LoginUseCaseImplTest { private val mockUserRepository = mockk() private val passwordHash = BcryptPasswordHash(test = true) private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS) - private val simpleJwt = SimpleJwt(jwtConfig) + private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper()) private val loginUseCase = LoginUseCaseImpl(mockUserRepository, passwordHash, simpleJwt) @BeforeEach