Refactor jwt

This commit is contained in:
2020-11-11 23:19:14 +01:00
parent 8439782430
commit 90701dcdce
10 changed files with 70 additions and 50 deletions
+30
View File
@@ -0,0 +1,30 @@
package be.simplenotes.domain.security
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.JWTCreator
import com.auth0.jwt.interfaces.DecodedJWT
import javax.inject.Singleton
interface JwtMapper<T> {
fun extract(decodedJWT: DecodedJWT): T?
fun build(builder: JWTCreator.Builder, value: T)
}
@Singleton
class UserJwtMapper : JwtMapper<LoggedInUser> {
private val userIdField = "i"
private val usernameField = "u"
override fun extract(decodedJWT: DecodedJWT): LoggedInUser? {
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
val username = decodedJWT.getClaim(usernameField).asString() ?: null
return if (id != null && username != null)
LoggedInUser(id, username)
else null
}
override fun build(builder: JWTCreator.Builder, value: LoggedInUser) {
builder.withClaim(userIdField, value.userId)
.withClaim(usernameField, value.username)
}
}
@@ -1,19 +0,0 @@
package be.simplenotes.domain.security
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.exceptions.JWTVerificationException
import javax.inject.Singleton
@Singleton
class JwtPayloadExtractor(private val jwt: SimpleJwt) {
operator fun invoke(token: String): LoggedInUser? = try {
val decodedJWT = jwt.verifier.verify(token)
val id = decodedJWT.getClaim(userIdField).asInt() ?: null
val username = decodedJWT.getClaim(usernameField).asString() ?: null
id?.let { username?.let { LoggedInUser(id, username) } }
} catch (e: JWTVerificationException) {
null
} catch (e: IllegalArgumentException) {
null
}
}
+14 -9
View File
@@ -1,28 +1,33 @@
package be.simplenotes.domain.security
import be.simplenotes.config.JwtConfig
import be.simplenotes.types.LoggedInUser
import com.auth0.jwt.JWT
import com.auth0.jwt.JWTVerifier
import com.auth0.jwt.algorithms.Algorithm
import com.auth0.jwt.exceptions.JWTVerificationException
import java.util.*
import java.util.concurrent.TimeUnit
import javax.inject.Singleton
internal const val userIdField = "i"
internal const val usernameField = "u"
@Singleton
class SimpleJwt(jwtConfig: JwtConfig) {
class SimpleJwt<T>(jwtConfig: JwtConfig, private val mapper: JwtMapper<T>) {
private val validityInMs = TimeUnit.MILLISECONDS.convert(jwtConfig.validity, jwtConfig.timeUnit)
private val algorithm = Algorithm.HMAC256(jwtConfig.secret)
private val verifier: JWTVerifier = JWT.require(algorithm).build()
val verifier: JWTVerifier = JWT.require(algorithm).build()
fun sign(loggedInUser: LoggedInUser): String = JWT.create()
.withClaim(userIdField, loggedInUser.userId)
.withClaim(usernameField, loggedInUser.username)
fun sign(value: T): String = JWT.create()
.apply { mapper.build(this, value) }
.withExpiresAt(getExpiration())
.sign(algorithm)
fun extract(token: String): T? = try {
val decodedJWT = verifier.verify(token)
mapper.extract(decodedJWT)
} catch (e: JWTVerificationException) {
null
} catch (e: IllegalArgumentException) {
null
}
private fun getExpiration() = Date(System.currentTimeMillis() + validityInMs)
}
@@ -16,7 +16,7 @@ import javax.inject.Singleton
internal class LoginUseCaseImpl(
private val userRepository: UserRepository,
private val passwordHash: PasswordHash,
private val jwt: SimpleJwt
private val jwt: SimpleJwt<LoggedInUser>
) : LoginUseCase {
override fun login(form: LoginForm) = either.eager<LoginError, Token> {
val user = !UserValidations.validateLogin(form)
@@ -16,14 +16,13 @@ import java.util.stream.Stream
internal class LoggedInUserExtractorTest {
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
private val simpleJwt = SimpleJwt(jwtConfig)
private val jwtPayloadExtractor = JwtPayloadExtractor(simpleJwt)
private val mapper = UserJwtMapper()
private val simpleJwt = SimpleJwt(jwtConfig, mapper)
private fun createToken(username: String? = null, id: Int? = null, secret: String = jwtConfig.secret): Token {
val algo = Algorithm.HMAC256(secret)
return JWT.create().apply {
username?.let { withClaim(usernameField, it) }
id?.let { withClaim(userIdField, it) }
if (username != null && id != null) mapper.build(this, LoggedInUser(id, username))
}.sign(algo)
}
@@ -40,12 +39,12 @@ internal class LoggedInUserExtractorTest {
@ParameterizedTest(name = "[{index}] token `{0}` should be invalid")
@MethodSource("invalidTokens")
fun `parse invalid tokens`(token: String) {
assertThat(jwtPayloadExtractor(token), absent())
assertThat(simpleJwt.extract(token), absent())
}
@Test
fun `parse valid token`() {
val token = createToken(username = "someone", id = 1)
assertThat(jwtPayloadExtractor(token), equalTo(LoggedInUser(1, "someone")))
assertThat(simpleJwt.extract(token), equalTo(LoggedInUser(1, "someone")))
}
}
@@ -3,6 +3,7 @@ package be.simplenotes.domain.usecases.users.login
import be.simplenotes.config.JwtConfig
import be.simplenotes.domain.security.BcryptPasswordHash
import be.simplenotes.domain.security.SimpleJwt
import be.simplenotes.domain.security.UserJwtMapper
import be.simplenotes.domain.testutils.isLeftOfType
import be.simplenotes.domain.testutils.isRight
import be.simplenotes.persistance.repositories.UserRepository
@@ -18,7 +19,7 @@ internal class LoginUseCaseImplTest {
private val mockUserRepository = mockk<UserRepository>()
private val passwordHash = BcryptPasswordHash(test = true)
private val jwtConfig = JwtConfig("a secret", 1, TimeUnit.HOURS)
private val simpleJwt = SimpleJwt(jwtConfig)
private val simpleJwt = SimpleJwt(jwtConfig, UserJwtMapper())
private val loginUseCase = LoginUseCaseImpl(mockUserRepository, passwordHash, simpleJwt)
@BeforeEach