Refactor jwt

This commit is contained in:
Hubert Van De Walle 2020-11-11 23:19:14 +01:00
parent 8439782430
commit 90701dcdce
10 changed files with 70 additions and 50 deletions

View File

@ -1,13 +1,14 @@
package be.simplenotes.app.filters.auth package be.simplenotes.app.filters.auth
import be.simplenotes.app.filters.auth.JwtSource.Cookie import be.simplenotes.app.filters.auth.JwtSource.Cookie
import be.simplenotes.domain.security.JwtPayloadExtractor import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.types.LoggedInUser
import org.http4k.core.Filter import org.http4k.core.Filter
import org.http4k.core.HttpHandler import org.http4k.core.HttpHandler
import org.http4k.core.with import org.http4k.core.with
class OptionalAuthFilter( class OptionalAuthFilter(
private val extractor: JwtPayloadExtractor, private val simpleJwt: SimpleJwt<LoggedInUser>,
private val lens: OptionalAuthLens, private val lens: OptionalAuthLens,
private val source: JwtSource = Cookie, private val source: JwtSource = Cookie,
) : Filter { ) : Filter {
@ -17,6 +18,6 @@ class OptionalAuthFilter(
Cookie -> it.bearerTokenCookie() Cookie -> it.bearerTokenCookie()
} }
next(it.with(lens of token?.let { extractor(it) })) next(it.with(lens of token?.let { simpleJwt.extract(it) }))
} }
} }

View File

@ -1,7 +1,8 @@
package be.simplenotes.app.filters.auth package be.simplenotes.app.filters.auth
import be.simplenotes.app.extensions.redirect import be.simplenotes.app.extensions.redirect
import be.simplenotes.domain.security.JwtPayloadExtractor import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.types.LoggedInUser
import org.http4k.core.Filter import org.http4k.core.Filter
import org.http4k.core.HttpHandler import org.http4k.core.HttpHandler
import org.http4k.core.Response import org.http4k.core.Response
@ -9,7 +10,7 @@ import org.http4k.core.Status.Companion.UNAUTHORIZED
import org.http4k.core.with import org.http4k.core.with
class RequiredAuthFilter( class RequiredAuthFilter(
private val extractor: JwtPayloadExtractor, private val simpleJwt: SimpleJwt<LoggedInUser>,
private val lens: RequiredAuthLens, private val lens: RequiredAuthLens,
private val source: JwtSource = JwtSource.Cookie, private val source: JwtSource = JwtSource.Cookie,
private val redirect: Boolean = true, private val redirect: Boolean = true,
@ -19,7 +20,7 @@ class RequiredAuthFilter(
JwtSource.Header -> it.bearerTokenHeader() JwtSource.Header -> it.bearerTokenHeader()
JwtSource.Cookie -> it.bearerTokenCookie() JwtSource.Cookie -> it.bearerTokenCookie()
} }
val jwtPayload = token?.let { extractor(token) } val jwtPayload = token?.let { simpleJwt.extract(token) }
if (jwtPayload != null) next(it.with(lens of jwtPayload)) if (jwtPayload != null) next(it.with(lens of jwtPayload))
else { else {

View File

@ -1,7 +1,8 @@
package be.simplenotes.app.modules package be.simplenotes.app.modules
import be.simplenotes.app.filters.auth.* import be.simplenotes.app.filters.auth.*
import be.simplenotes.domain.security.JwtPayloadExtractor import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.types.LoggedInUser
import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Primary import io.micronaut.context.annotation.Primary
import org.http4k.core.RequestContexts import org.http4k.core.RequestContexts
@ -21,21 +22,21 @@ class AuthModule {
fun requiredAuthLens(ctx: RequestContexts): RequiredAuthLens = RequestContextKey.required(ctx) fun requiredAuthLens(ctx: RequestContexts): RequiredAuthLens = RequestContextKey.required(ctx)
@Singleton @Singleton
fun optionalAuth(extractor: JwtPayloadExtractor, @Named("optional") lens: OptionalAuthLens) = fun optionalAuth(simpleJwt: SimpleJwt<LoggedInUser>, @Named("optional") lens: OptionalAuthLens) =
OptionalAuthFilter(extractor, lens) OptionalAuthFilter(simpleJwt, lens)
@Primary @Primary
@Singleton @Singleton
fun requiredAuth(extractor: JwtPayloadExtractor, @Named("required") lens: RequiredAuthLens) = fun requiredAuth(simpleJwt: SimpleJwt<LoggedInUser>, @Named("required") lens: RequiredAuthLens) =
RequiredAuthFilter(extractor, lens) RequiredAuthFilter(simpleJwt, lens)
@Singleton @Singleton
@Named("api") @Named("api")
internal fun apiAuthFilter( internal fun apiAuthFilter(
jwtPayloadExtractor: JwtPayloadExtractor, simpleJwt: SimpleJwt<LoggedInUser>,
@Named("required") lens: RequiredAuthLens, @Named("required") lens: RequiredAuthLens,
) = RequiredAuthFilter( ) = RequiredAuthFilter(
extractor = jwtPayloadExtractor, simpleJwt = simpleJwt,
lens = lens, lens = lens,
source = JwtSource.Header, source = JwtSource.Header,
redirect = false redirect = false

View File

@ -6,6 +6,7 @@ import be.simplenotes.app.filters.auth.RequiredAuthFilter
import be.simplenotes.app.filters.auth.RequiredAuthLens import be.simplenotes.app.filters.auth.RequiredAuthLens
import be.simplenotes.config.JwtConfig import be.simplenotes.config.JwtConfig
import be.simplenotes.domain.security.SimpleJwt import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.domain.security.UserJwtMapper
import be.simplenotes.types.LoggedInUser import be.simplenotes.types.LoggedInUser
import com.natpryce.hamkrest.assertion.assertThat import com.natpryce.hamkrest.assertion.assertThat
import io.micronaut.context.BeanContext import io.micronaut.context.BeanContext
@ -32,7 +33,7 @@ internal class RequiredAuthFilterTest {
// region setup // region setup
private val jwtConfig = JwtConfig("secret", 1, TimeUnit.HOURS) private val jwtConfig = JwtConfig("secret", 1, TimeUnit.HOURS)
private val simpleJwt = SimpleJwt(jwtConfig) private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper())
private val beanCtx = BeanContext.build() private val beanCtx = BeanContext.build()
.registerSingleton(jwtConfig) .registerSingleton(jwtConfig)

View File

@ -0,0 +1,30 @@
package be.simplenotes.domain.security
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.JWTCreator
import com.auth0.jwt.interfaces.DecodedJWT
import javax.inject.Singleton
interface JwtMapper<T> {
fun extract(decodedJWT: DecodedJWT): T?
fun build(builder: JWTCreator.Builder, value: T)
}
@Singleton
class UserJwtMapper : JwtMapper<LoggedInUser> {
private val userIdField = "i"
private val usernameField = "u"
override fun extract(decodedJWT: DecodedJWT): LoggedInUser? {
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
val username = decodedJWT.getClaim(usernameField).asString() ?: null
return if (id != null && username != null)
LoggedInUser(id, username)
else null
}
override fun build(builder: JWTCreator.Builder, value: LoggedInUser) {
builder.withClaim(userIdField, value.userId)
.withClaim(usernameField, value.username)
}
}

View File

@ -1,19 +0,0 @@
package be.simplenotes.domain.security
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.exceptions.JWTVerificationException
import javax.inject.Singleton
@Singleton
class JwtPayloadExtractor(private val jwt: SimpleJwt) {
operator fun invoke(token: String): LoggedInUser? = try {
val decodedJWT = jwt.verifier.verify(token)
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
val username = decodedJWT.getClaim(usernameField).asString() ?: null
id?.let { username?.let { LoggedInUser(id, username) } }
} catch (e: JWTVerificationException) {
null
} catch (e: IllegalArgumentException) {
null
}
}

View File

@ -1,28 +1,33 @@
package be.simplenotes.domain.security package be.simplenotes.domain.security
import be.simplenotes.config.JwtConfig import be.simplenotes.config.JwtConfig
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.JWT import com.auth0.jwt.JWT
import com.auth0.jwt.JWTVerifier import com.auth0.jwt.JWTVerifier
import com.auth0.jwt.algorithms.Algorithm import com.auth0.jwt.algorithms.Algorithm
import com.auth0.jwt.exceptions.JWTVerificationException
import java.util.* import java.util.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.inject.Singleton import javax.inject.Singleton
internal const val userIdField = "i"
internal const val usernameField = "u"
@Singleton @Singleton
class SimpleJwt(jwtConfig: JwtConfig) { class SimpleJwt<T>(jwtConfig: JwtConfig, private val mapper: JwtMapper<T>) {
private val validityInMs = TimeUnit.MILLISECONDS.convert(jwtConfig.validity, jwtConfig.timeUnit) private val validityInMs = TimeUnit.MILLISECONDS.convert(jwtConfig.validity, jwtConfig.timeUnit)
private val algorithm = Algorithm.HMAC256(jwtConfig.secret) private val algorithm = Algorithm.HMAC256(jwtConfig.secret)
private val verifier: JWTVerifier = JWT.require(algorithm).build()
val verifier: JWTVerifier = JWT.require(algorithm).build() fun sign(value: T): String = JWT.create()
fun sign(loggedInUser: LoggedInUser): String = JWT.create() .apply { mapper.build(this, value) }
.withClaim(userIdField, loggedInUser.userId)
.withClaim(usernameField, loggedInUser.username)
.withExpiresAt(getExpiration()) .withExpiresAt(getExpiration())
.sign(algorithm) .sign(algorithm)
fun extract(token: String): T? = try {
val decodedJWT = verifier.verify(token)
mapper.extract(decodedJWT)
} catch (e: JWTVerificationException) {
null
} catch (e: IllegalArgumentException) {
null
}
private fun getExpiration() = Date(System.currentTimeMillis() + validityInMs) private fun getExpiration() = Date(System.currentTimeMillis() + validityInMs)
} }

View File

@ -16,7 +16,7 @@ import javax.inject.Singleton
internal class LoginUseCaseImpl( internal class LoginUseCaseImpl(
private val userRepository: UserRepository, private val userRepository: UserRepository,
private val passwordHash: PasswordHash, private val passwordHash: PasswordHash,
private val jwt: SimpleJwt private val jwt: SimpleJwt<LoggedInUser>
) : LoginUseCase { ) : LoginUseCase {
override fun login(form: LoginForm) = either.eager<LoginError, Token> { override fun login(form: LoginForm) = either.eager<LoginError, Token> {
val user = !UserValidations.validateLogin(form) val user = !UserValidations.validateLogin(form)

View File

@ -16,14 +16,13 @@ import java.util.stream.Stream
internal class LoggedInUserExtractorTest { internal class LoggedInUserExtractorTest {
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS) private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
private val simpleJwt = SimpleJwt(jwtConfig) private val mapper = UserJwtMapper()
private val jwtPayloadExtractor = JwtPayloadExtractor(simpleJwt) private val simpleJwt = SimpleJwt(jwtConfig, mapper)
private fun createToken(username: String? = null, id: Int? = null, secret: String = jwtConfig.secret): Token { private fun createToken(username: String? = null, id: Int? = null, secret: String = jwtConfig.secret): Token {
val algo = Algorithm.HMAC256(secret) val algo = Algorithm.HMAC256(secret)
return JWT.create().apply { return JWT.create().apply {
username?.let { withClaim(usernameField, it) } if (username != null && id != null) mapper.build(this, LoggedInUser(id, username))
id?.let { withClaim(userIdField, it) }
}.sign(algo) }.sign(algo)
} }
@ -40,12 +39,12 @@ internal class LoggedInUserExtractorTest {
@ParameterizedTest(name = "[{index}] token `{0}` should be invalid") @ParameterizedTest(name = "[{index}] token `{0}` should be invalid")
@MethodSource("invalidTokens") @MethodSource("invalidTokens")
fun `parse invalid tokens`(token: String) { fun `parse invalid tokens`(token: String) {
assertThat(jwtPayloadExtractor(token), absent()) assertThat(simpleJwt.extract(token), absent())
} }
@Test @Test
fun `parse valid token`() { fun `parse valid token`() {
val token = createToken(username = "someone", id = 1) val token = createToken(username = "someone", id = 1)
assertThat(jwtPayloadExtractor(token), equalTo(LoggedInUser(1, "someone"))) assertThat(simpleJwt.extract(token), equalTo(LoggedInUser(1, "someone")))
} }
} }

View File

@ -3,6 +3,7 @@ package be.simplenotes.domain.usecases.users.login
import be.simplenotes.config.JwtConfig import be.simplenotes.config.JwtConfig
import be.simplenotes.domain.security.BcryptPasswordHash import be.simplenotes.domain.security.BcryptPasswordHash
import be.simplenotes.domain.security.SimpleJwt import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.domain.security.UserJwtMapper
import be.simplenotes.domain.testutils.isLeftOfType import be.simplenotes.domain.testutils.isLeftOfType
import be.simplenotes.domain.testutils.isRight import be.simplenotes.domain.testutils.isRight
import be.simplenotes.persistance.repositories.UserRepository import be.simplenotes.persistance.repositories.UserRepository
@ -18,7 +19,7 @@ internal class LoginUseCaseImplTest {
private val mockUserRepository = mockk<UserRepository>() private val mockUserRepository = mockk<UserRepository>()
private val passwordHash = BcryptPasswordHash(test = true) private val passwordHash = BcryptPasswordHash(test = true)
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS) private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
private val simpleJwt = SimpleJwt(jwtConfig) private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper())
private val loginUseCase = LoginUseCaseImpl(mockUserRepository, passwordHash, simpleJwt) private val loginUseCase = LoginUseCaseImpl(mockUserRepository, passwordHash, simpleJwt)
@BeforeEach @BeforeEach