Refactor jwt
This commit is contained in:
parent
8439782430
commit
90701dcdce
@ -1,13 +1,14 @@
|
||||
package be.simplenotes.app.filters.auth
|
||||
|
||||
import be.simplenotes.app.filters.auth.JwtSource.Cookie
|
||||
import be.simplenotes.domain.security.JwtPayloadExtractor
|
||||
import be.simplenotes.domain.security.SimpleJwt
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import org.http4k.core.Filter
|
||||
import org.http4k.core.HttpHandler
|
||||
import org.http4k.core.with
|
||||
|
||||
class OptionalAuthFilter(
|
||||
private val extractor: JwtPayloadExtractor,
|
||||
private val simpleJwt: SimpleJwt<LoggedInUser>,
|
||||
private val lens: OptionalAuthLens,
|
||||
private val source: JwtSource = Cookie,
|
||||
) : Filter {
|
||||
@ -17,6 +18,6 @@ class OptionalAuthFilter(
|
||||
Cookie -> it.bearerTokenCookie()
|
||||
}
|
||||
|
||||
next(it.with(lens of token?.let { extractor(it) }))
|
||||
next(it.with(lens of token?.let { simpleJwt.extract(it) }))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
package be.simplenotes.app.filters.auth
|
||||
|
||||
import be.simplenotes.app.extensions.redirect
|
||||
import be.simplenotes.domain.security.JwtPayloadExtractor
|
||||
import be.simplenotes.domain.security.SimpleJwt
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import org.http4k.core.Filter
|
||||
import org.http4k.core.HttpHandler
|
||||
import org.http4k.core.Response
|
||||
@ -9,7 +10,7 @@ import org.http4k.core.Status.Companion.UNAUTHORIZED
|
||||
import org.http4k.core.with
|
||||
|
||||
class RequiredAuthFilter(
|
||||
private val extractor: JwtPayloadExtractor,
|
||||
private val simpleJwt: SimpleJwt<LoggedInUser>,
|
||||
private val lens: RequiredAuthLens,
|
||||
private val source: JwtSource = JwtSource.Cookie,
|
||||
private val redirect: Boolean = true,
|
||||
@ -19,7 +20,7 @@ class RequiredAuthFilter(
|
||||
JwtSource.Header -> it.bearerTokenHeader()
|
||||
JwtSource.Cookie -> it.bearerTokenCookie()
|
||||
}
|
||||
val jwtPayload = token?.let { extractor(token) }
|
||||
val jwtPayload = token?.let { simpleJwt.extract(token) }
|
||||
|
||||
if (jwtPayload != null) next(it.with(lens of jwtPayload))
|
||||
else {
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
package be.simplenotes.app.modules
|
||||
|
||||
import be.simplenotes.app.filters.auth.*
|
||||
import be.simplenotes.domain.security.JwtPayloadExtractor
|
||||
import be.simplenotes.domain.security.SimpleJwt
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import io.micronaut.context.annotation.Factory
|
||||
import io.micronaut.context.annotation.Primary
|
||||
import org.http4k.core.RequestContexts
|
||||
@ -21,21 +22,21 @@ class AuthModule {
|
||||
fun requiredAuthLens(ctx: RequestContexts): RequiredAuthLens = RequestContextKey.required(ctx)
|
||||
|
||||
@Singleton
|
||||
fun optionalAuth(extractor: JwtPayloadExtractor, @Named("optional") lens: OptionalAuthLens) =
|
||||
OptionalAuthFilter(extractor, lens)
|
||||
fun optionalAuth(simpleJwt: SimpleJwt<LoggedInUser>, @Named("optional") lens: OptionalAuthLens) =
|
||||
OptionalAuthFilter(simpleJwt, lens)
|
||||
|
||||
@Primary
|
||||
@Singleton
|
||||
fun requiredAuth(extractor: JwtPayloadExtractor, @Named("required") lens: RequiredAuthLens) =
|
||||
RequiredAuthFilter(extractor, lens)
|
||||
fun requiredAuth(simpleJwt: SimpleJwt<LoggedInUser>, @Named("required") lens: RequiredAuthLens) =
|
||||
RequiredAuthFilter(simpleJwt, lens)
|
||||
|
||||
@Singleton
|
||||
@Named("api")
|
||||
internal fun apiAuthFilter(
|
||||
jwtPayloadExtractor: JwtPayloadExtractor,
|
||||
simpleJwt: SimpleJwt<LoggedInUser>,
|
||||
@Named("required") lens: RequiredAuthLens,
|
||||
) = RequiredAuthFilter(
|
||||
extractor = jwtPayloadExtractor,
|
||||
simpleJwt = simpleJwt,
|
||||
lens = lens,
|
||||
source = JwtSource.Header,
|
||||
redirect = false
|
||||
|
||||
@ -6,6 +6,7 @@ import be.simplenotes.app.filters.auth.RequiredAuthFilter
|
||||
import be.simplenotes.app.filters.auth.RequiredAuthLens
|
||||
import be.simplenotes.config.JwtConfig
|
||||
import be.simplenotes.domain.security.SimpleJwt
|
||||
import be.simplenotes.domain.security.UserJwtMapper
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import com.natpryce.hamkrest.assertion.assertThat
|
||||
import io.micronaut.context.BeanContext
|
||||
@ -32,7 +33,7 @@ internal class RequiredAuthFilterTest {
|
||||
|
||||
// region setup
|
||||
private val jwtConfig = JwtConfig("secret", 1, TimeUnit.HOURS)
|
||||
private val simpleJwt = SimpleJwt(jwtConfig)
|
||||
private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper())
|
||||
|
||||
private val beanCtx = BeanContext.build()
|
||||
.registerSingleton(jwtConfig)
|
||||
|
||||
30
domain/src/security/JwtMapper.kt
Normal file
30
domain/src/security/JwtMapper.kt
Normal file
@ -0,0 +1,30 @@
|
||||
package be.simplenotes.domain.security
|
||||
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import com.auth0.jwt.JWTCreator
|
||||
import com.auth0.jwt.interfaces.DecodedJWT
|
||||
import javax.inject.Singleton
|
||||
|
||||
interface JwtMapper<T> {
|
||||
fun extract(decodedJWT: DecodedJWT): T?
|
||||
fun build(builder: JWTCreator.Builder, value: T)
|
||||
}
|
||||
|
||||
@Singleton
|
||||
class UserJwtMapper : JwtMapper<LoggedInUser> {
|
||||
private val userIdField = "i"
|
||||
private val usernameField = "u"
|
||||
|
||||
override fun extract(decodedJWT: DecodedJWT): LoggedInUser? {
|
||||
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
|
||||
val username = decodedJWT.getClaim(usernameField).asString() ?: null
|
||||
return if (id != null && username != null)
|
||||
LoggedInUser(id, username)
|
||||
else null
|
||||
}
|
||||
|
||||
override fun build(builder: JWTCreator.Builder, value: LoggedInUser) {
|
||||
builder.withClaim(userIdField, value.userId)
|
||||
.withClaim(usernameField, value.username)
|
||||
}
|
||||
}
|
||||
@ -1,19 +0,0 @@
|
||||
package be.simplenotes.domain.security
|
||||
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import com.auth0.jwt.exceptions.JWTVerificationException
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class JwtPayloadExtractor(private val jwt: SimpleJwt) {
|
||||
operator fun invoke(token: String): LoggedInUser? = try {
|
||||
val decodedJWT = jwt.verifier.verify(token)
|
||||
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
|
||||
val username = decodedJWT.getClaim(usernameField).asString() ?: null
|
||||
id?.let { username?.let { LoggedInUser(id, username) } }
|
||||
} catch (e: JWTVerificationException) {
|
||||
null
|
||||
} catch (e: IllegalArgumentException) {
|
||||
null
|
||||
}
|
||||
}
|
||||
@ -1,28 +1,33 @@
|
||||
package be.simplenotes.domain.security
|
||||
|
||||
import be.simplenotes.config.JwtConfig
|
||||
import be.simplenotes.types.LoggedInUser
|
||||
import com.auth0.jwt.JWT
|
||||
import com.auth0.jwt.JWTVerifier
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
import com.auth0.jwt.exceptions.JWTVerificationException
|
||||
import java.util.*
|
||||
import java.util.concurrent.TimeUnit
|
||||
import javax.inject.Singleton
|
||||
|
||||
internal const val userIdField = "i"
|
||||
internal const val usernameField = "u"
|
||||
|
||||
@Singleton
|
||||
class SimpleJwt(jwtConfig: JwtConfig) {
|
||||
class SimpleJwt<T>(jwtConfig: JwtConfig, private val mapper: JwtMapper<T>) {
|
||||
private val validityInMs = TimeUnit.MILLISECONDS.convert(jwtConfig.validity, jwtConfig.timeUnit)
|
||||
private val algorithm = Algorithm.HMAC256(jwtConfig.secret)
|
||||
private val verifier: JWTVerifier = JWT.require(algorithm).build()
|
||||
|
||||
val verifier: JWTVerifier = JWT.require(algorithm).build()
|
||||
fun sign(loggedInUser: LoggedInUser): String = JWT.create()
|
||||
.withClaim(userIdField, loggedInUser.userId)
|
||||
.withClaim(usernameField, loggedInUser.username)
|
||||
fun sign(value: T): String = JWT.create()
|
||||
.apply { mapper.build(this, value) }
|
||||
.withExpiresAt(getExpiration())
|
||||
.sign(algorithm)
|
||||
|
||||
fun extract(token: String): T? = try {
|
||||
val decodedJWT = verifier.verify(token)
|
||||
mapper.extract(decodedJWT)
|
||||
} catch (e: JWTVerificationException) {
|
||||
null
|
||||
} catch (e: IllegalArgumentException) {
|
||||
null
|
||||
}
|
||||
|
||||
private fun getExpiration() = Date(System.currentTimeMillis() + validityInMs)
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@ import javax.inject.Singleton
|
||||
internal class LoginUseCaseImpl(
|
||||
private val userRepository: UserRepository,
|
||||
private val passwordHash: PasswordHash,
|
||||
private val jwt: SimpleJwt
|
||||
private val jwt: SimpleJwt<LoggedInUser>
|
||||
) : LoginUseCase {
|
||||
override fun login(form: LoginForm) = either.eager<LoginError, Token> {
|
||||
val user = !UserValidations.validateLogin(form)
|
||||
|
||||
@ -16,14 +16,13 @@ import java.util.stream.Stream
|
||||
|
||||
internal class LoggedInUserExtractorTest {
|
||||
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
|
||||
private val simpleJwt = SimpleJwt(jwtConfig)
|
||||
private val jwtPayloadExtractor = JwtPayloadExtractor(simpleJwt)
|
||||
private val mapper = UserJwtMapper()
|
||||
private val simpleJwt = SimpleJwt(jwtConfig, mapper)
|
||||
|
||||
private fun createToken(username: String? = null, id: Int? = null, secret: String = jwtConfig.secret): Token {
|
||||
val algo = Algorithm.HMAC256(secret)
|
||||
return JWT.create().apply {
|
||||
username?.let { withClaim(usernameField, it) }
|
||||
id?.let { withClaim(userIdField, it) }
|
||||
if (username != null && id != null) mapper.build(this, LoggedInUser(id, username))
|
||||
}.sign(algo)
|
||||
}
|
||||
|
||||
@ -40,12 +39,12 @@ internal class LoggedInUserExtractorTest {
|
||||
@ParameterizedTest(name = "[{index}] token `{0}` should be invalid")
|
||||
@MethodSource("invalidTokens")
|
||||
fun `parse invalid tokens`(token: String) {
|
||||
assertThat(jwtPayloadExtractor(token), absent())
|
||||
assertThat(simpleJwt.extract(token), absent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `parse valid token`() {
|
||||
val token = createToken(username = "someone", id = 1)
|
||||
assertThat(jwtPayloadExtractor(token), equalTo(LoggedInUser(1, "someone")))
|
||||
assertThat(simpleJwt.extract(token), equalTo(LoggedInUser(1, "someone")))
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ package be.simplenotes.domain.usecases.users.login
|
||||
import be.simplenotes.config.JwtConfig
|
||||
import be.simplenotes.domain.security.BcryptPasswordHash
|
||||
import be.simplenotes.domain.security.SimpleJwt
|
||||
import be.simplenotes.domain.security.UserJwtMapper
|
||||
import be.simplenotes.domain.testutils.isLeftOfType
|
||||
import be.simplenotes.domain.testutils.isRight
|
||||
import be.simplenotes.persistance.repositories.UserRepository
|
||||
@ -18,7 +19,7 @@ internal class LoginUseCaseImplTest {
|
||||
private val mockUserRepository = mockk<UserRepository>()
|
||||
private val passwordHash = BcryptPasswordHash(test = true)
|
||||
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
|
||||
private val simpleJwt = SimpleJwt(jwtConfig)
|
||||
private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper())
|
||||
private val loginUseCase = LoginUseCaseImpl(mockUserRepository, passwordHash, simpleJwt)
|
||||
|
||||
@BeforeEach
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user