diff --git a/app/pom.xml b/app/pom.xml index beb6b42..b39affb 100644 --- a/app/pom.xml +++ b/app/pom.xml @@ -91,6 +91,11 @@ http4k-testing-hamkrest test + + + me.liuwj.ktorm + ktorm-core + diff --git a/app/src/main/kotlin/filters/AuthFilter.kt b/app/src/main/kotlin/filters/AuthFilter.kt index 1a86dec..98f64f0 100644 --- a/app/src/main/kotlin/filters/AuthFilter.kt +++ b/app/src/main/kotlin/filters/AuthFilter.kt @@ -19,25 +19,23 @@ class AuthFilter( private val ctx: RequestContexts, private val source: JwtSource = JwtSource.Cookie, private val redirect: Boolean = true, -) { - operator fun invoke() = Filter { next -> - { - val token = when (source) { - JwtSource.Header -> it.bearerTokenHeader() - JwtSource.Cookie -> it.bearerTokenCookie() +) : Filter { + override fun invoke(next: HttpHandler): HttpHandler = { + val token = when (source) { + JwtSource.Header -> it.bearerTokenHeader() + JwtSource.Cookie -> it.bearerTokenCookie() + } + val jwtPayload = token?.let { token -> extractor(token) } + when { + jwtPayload != null -> { + ctx[it][authKey] = jwtPayload + next(it) } - val jwtPayload = token?.let { token -> extractor(token) } - when { - jwtPayload != null -> { - ctx[it][authKey] = jwtPayload - next(it) - } - authType == AuthType.Required -> { - if (redirect) Response.redirect("/login") - else Response(UNAUTHORIZED) - } - else -> next(it) + authType == AuthType.Required -> { + if (redirect) Response.redirect("/login") + else Response(UNAUTHORIZED) } + else -> next(it) } } } diff --git a/app/src/main/kotlin/filters/ErrorFilter.kt b/app/src/main/kotlin/filters/ErrorFilter.kt index 9fd0b97..21f8ece 100644 --- a/app/src/main/kotlin/filters/ErrorFilter.kt +++ b/app/src/main/kotlin/filters/ErrorFilter.kt @@ -2,32 +2,44 @@ package be.simplenotes.app.filters import be.simplenotes.app.extensions.html import be.simplenotes.app.views.ErrorView +import be.simplenotes.app.views.ErrorView.Type.* import org.http4k.core.* +import org.http4k.core.Status.Companion.INTERNAL_SERVER_ERROR +import org.http4k.core.Status.Companion.NOT_FOUND +import org.http4k.core.Status.Companion.NOT_IMPLEMENTED +import org.http4k.core.Status.Companion.SERVICE_UNAVAILABLE import org.slf4j.LoggerFactory import java.sql.SQLTransientException -class ErrorFilter(private val errorView: ErrorView) { +class ErrorFilter(private val errorView: ErrorView) : Filter { private val logger = LoggerFactory.getLogger(javaClass) - operator fun invoke(): Filter = Filter { next -> - { - try { - val response = next(it) - if (response.status == Status.NOT_FOUND) Response(Status.NOT_FOUND) - .html(errorView.error(ErrorView.Type.NotFound)) - else response - } catch (e: Exception) { - logger.error(e.stackTraceToString()) - if (e is SQLTransientException) - Response(Status.SERVICE_UNAVAILABLE).html(errorView.error(ErrorView.Type.SqlTransientError)) - .noCache() - else - Response(Status.INTERNAL_SERVER_ERROR).html(errorView.error(ErrorView.Type.Other)).noCache() - } catch (e: NotImplementedError) { - logger.error(e.stackTraceToString()) - Response(Status.NOT_IMPLEMENTED).html(errorView.error(ErrorView.Type.Other)).noCache() - } + private fun errorResponse(status: Status): Response { + val type = when (status) { + SERVICE_UNAVAILABLE -> SqlTransientError + NOT_FOUND -> NotFound + NOT_IMPLEMENTED -> Other + else -> Other + } + + return Response(status).html(errorView.error(type)).noCache() + } + + override fun invoke(next: HttpHandler): HttpHandler = { request -> + try { + val response = next(request) + if (response.status == NOT_FOUND) errorResponse(NOT_FOUND) + else response + } catch (e: SQLTransientException) { + logger.error(e.stackTraceToString()) + errorResponse(SERVICE_UNAVAILABLE) + } catch (e: Exception) { + logger.error(e.stackTraceToString()) + errorResponse(INTERNAL_SERVER_ERROR) + } catch (e: NotImplementedError) { + logger.error(e.stackTraceToString()) + errorResponse(NOT_IMPLEMENTED) } } } diff --git a/app/src/main/kotlin/filters/ImmutableFilter.kt b/app/src/main/kotlin/filters/ImmutableFilter.kt index e8a3a33..fd82ab5 100644 --- a/app/src/main/kotlin/filters/ImmutableFilter.kt +++ b/app/src/main/kotlin/filters/ImmutableFilter.kt @@ -2,13 +2,10 @@ package be.simplenotes.app.filters import org.http4k.core.Filter import org.http4k.core.HttpHandler -import org.http4k.core.Method import org.http4k.core.Request -object ImmutableFilter { - operator fun invoke() = Filter { next: HttpHandler -> - { request: Request -> - next(request).header("Cache-Control", "public, max-age=31536000, immutable") - } +object ImmutableFilter : Filter { + override fun invoke(next: HttpHandler) = { request: Request -> + next(request).header("Cache-Control", "public, max-age=31536000, immutable") } } diff --git a/app/src/main/kotlin/filters/SecurityFilter.kt b/app/src/main/kotlin/filters/SecurityFilter.kt index d44d821..dc986e9 100644 --- a/app/src/main/kotlin/filters/SecurityFilter.kt +++ b/app/src/main/kotlin/filters/SecurityFilter.kt @@ -4,17 +4,15 @@ import org.http4k.core.Filter import org.http4k.core.HttpHandler import org.http4k.core.Request -object SecurityFilter { - operator fun invoke() = Filter { next: HttpHandler -> - { request: Request -> - val response = next(request) - .header("X-Content-Type-Options", "nosniff") +object SecurityFilter : Filter { + override fun invoke(next: HttpHandler): HttpHandler = { request: Request -> + val response = next(request) + .header("X-Content-Type-Options", "nosniff") - if (response.header("Content-Type")?.contains("text/html") == true) { - response - .header("Content-Security-Policy", "default-src 'self'") - .header("Referrer-Policy", "no-referrer") - } else response - } + if (response.header("Content-Type")?.contains("text/html") == true) { + response + .header("Content-Security-Policy", "default-src 'self'") + .header("Referrer-Policy", "no-referrer") + } else response } } diff --git a/app/src/main/kotlin/filters/TransactionFilter.kt b/app/src/main/kotlin/filters/TransactionFilter.kt new file mode 100644 index 0000000..952aa84 --- /dev/null +++ b/app/src/main/kotlin/filters/TransactionFilter.kt @@ -0,0 +1,13 @@ +package be.simplenotes.app.filters + +import me.liuwj.ktorm.database.Database +import org.http4k.core.Filter +import org.http4k.core.HttpHandler + +class TransactionFilter(private val db: Database) : Filter { + override fun invoke(next: HttpHandler): HttpHandler = { request -> + db.useTransaction { + next(request) + } + } +} diff --git a/app/src/main/kotlin/modules/ApiModule.kt b/app/src/main/kotlin/modules/ApiModule.kt index 031d30d..1fbf74a 100644 --- a/app/src/main/kotlin/modules/ApiModule.kt +++ b/app/src/main/kotlin/modules/ApiModule.kt @@ -5,19 +5,20 @@ import be.simplenotes.app.api.ApiUserController import be.simplenotes.app.filters.AuthFilter import be.simplenotes.app.filters.AuthType import be.simplenotes.app.filters.JwtSource +import org.http4k.core.Filter import org.koin.core.qualifier.named import org.koin.dsl.module val apiModule = module { single { ApiUserController(get(), get()) } single { ApiNoteController(get(), get()) } - single(named("apiAuthFilter")) { + single(named("apiAuthFilter")) { AuthFilter( extractor = get(), authType = AuthType.Required, ctx = get(), source = JwtSource.Header, redirect = false - )() + ) } } diff --git a/app/src/main/kotlin/modules/ServerModule.kt b/app/src/main/kotlin/modules/ServerModule.kt index e40ca16..dd91d9a 100644 --- a/app/src/main/kotlin/modules/ServerModule.kt +++ b/app/src/main/kotlin/modules/ServerModule.kt @@ -4,12 +4,14 @@ import be.simplenotes.app.Server import be.simplenotes.app.filters.AuthFilter import be.simplenotes.app.filters.AuthType import be.simplenotes.app.filters.ErrorFilter +import be.simplenotes.app.filters.TransactionFilter import be.simplenotes.app.routes.Router import be.simplenotes.app.utils.StaticFileResolver import be.simplenotes.app.utils.StaticFileResolverImpl import be.simplenotes.app.views.ErrorView import be.simplenotes.shared.config.ServerConfig import org.eclipse.jetty.server.ServerConnector +import org.http4k.core.Filter import org.http4k.core.RequestContexts import org.http4k.routing.RoutingHttpHandler import org.http4k.server.ConnectorBuilder @@ -45,14 +47,16 @@ val serverModule = module { get(), requiredAuth = get(AuthType.Required.qualifier), optionalAuth = get(AuthType.Optional.qualifier), - errorFilter = get(named("ErrorFilter")), apiAuth = get(named("apiAuthFilter")), - get() + get(), + get(), + get(), )() } single { RequestContexts() } - single(AuthType.Optional.qualifier) { AuthFilter(get(), AuthType.Optional, get())() } - single(AuthType.Required.qualifier) { AuthFilter(get(), AuthType.Required, get())() } - single(named("ErrorFilter")) { ErrorFilter(get())() } + single(AuthType.Optional.qualifier) { AuthFilter(get(), AuthType.Optional, get()) } + single(AuthType.Required.qualifier) { AuthFilter(get(), AuthType.Required, get()) } + single { ErrorFilter(get()) } + single { TransactionFilter(get()) } single { ErrorView(get()) } } diff --git a/app/src/main/kotlin/routes/Router.kt b/app/src/main/kotlin/routes/Router.kt index cf684e2..4d27e91 100644 --- a/app/src/main/kotlin/routes/Router.kt +++ b/app/src/main/kotlin/routes/Router.kt @@ -6,15 +6,14 @@ import be.simplenotes.app.controllers.BaseController import be.simplenotes.app.controllers.NoteController import be.simplenotes.app.controllers.SettingsController import be.simplenotes.app.controllers.UserController -import be.simplenotes.app.filters.ImmutableFilter -import be.simplenotes.app.filters.SecurityFilter -import be.simplenotes.app.filters.jwtPayload +import be.simplenotes.app.filters.* import be.simplenotes.domain.security.JwtPayload import org.http4k.core.* import org.http4k.core.Method.* -import org.http4k.filter.ResponseFilters +import org.http4k.filter.ResponseFilters.GZip import org.http4k.filter.ServerFilters.InitialiseRequestContext import org.http4k.routing.* +import org.http4k.routing.ResourceLoader.Companion.Classpath class Router( private val baseController: BaseController, @@ -25,24 +24,19 @@ class Router( private val apiNoteController: ApiNoteController, private val requiredAuth: Filter, private val optionalAuth: Filter, - private val errorFilter: Filter, private val apiAuth: Filter, + private val errorFilter: ErrorFilter, + private val transactionFilter: TransactionFilter, private val contexts: RequestContexts, ) { operator fun invoke(): RoutingHttpHandler { - val resourceLoader = ResourceLoader.Classpath(("/static")) - val basicRoutes = routes( - ImmutableFilter().then(static(resourceLoader, "woff2" to ContentType("font/woff2"))), - ) + val basicRoutes = ImmutableFilter.then(static(Classpath("/static"), "woff2" to ContentType("font/woff2"))) - infix fun PathMethod.public(handler: PublicHandler) = this to { handler(it, it.jwtPayload(contexts)) } - infix fun PathMethod.protected(handler: ProtectedHandler) = this to { handler(it, it.jwtPayload(contexts)!!) } - - val publicRoutes: RoutingHttpHandler = routes( + val publicRoutes = routes( "/" bind GET public baseController::index, "/register" bind GET public userController::register, - "/register" bind POST public userController::register, + "/register" bind POST `public transactional` userController::register, "/login" bind GET public userController::login, "/login" bind POST public userController::login, "/logout" bind POST to userController::logout, @@ -51,18 +45,18 @@ class Router( val protectedRoutes = routes( "/settings" bind GET protected settingsController::settings, - "/settings" bind POST protected settingsController::settings, + "/settings" bind POST transactional settingsController::settings, "/export" bind POST protected settingsController::export, "/notes" bind GET protected noteController::list, "/notes" bind POST protected noteController::search, "/notes/new" bind GET protected noteController::new, - "/notes/new" bind POST protected noteController::new, + "/notes/new" bind POST transactional noteController::new, "/notes/trash" bind GET protected noteController::trash, "/notes/{uuid}" bind GET protected noteController::note, - "/notes/{uuid}" bind POST protected noteController::note, + "/notes/{uuid}" bind POST transactional noteController::note, "/notes/{uuid}/edit" bind GET protected noteController::edit, - "/notes/{uuid}/edit" bind POST protected noteController::edit, - "/notes/deleted/{uuid}" bind POST protected noteController::deleted, + "/notes/{uuid}/edit" bind POST transactional noteController::edit, + "/notes/deleted/{uuid}" bind POST transactional noteController::deleted, ) val apiRoutes = routes( @@ -71,10 +65,10 @@ class Router( val protectedApiRoutes = routes( "/api/notes" bind GET protected apiNoteController::notes, - "/api/notes" bind POST protected apiNoteController::createNote, - "/api/notes/search" bind POST protected apiNoteController::search, + "/api/notes" bind POST transactional apiNoteController::createNote, + "/api/notes/search" bind POST transactional apiNoteController::search, "/api/notes/{uuid}" bind GET protected apiNoteController::note, - "/api/notes/{uuid}" bind PUT protected apiNoteController::update, + "/api/notes/{uuid}" bind PUT transactional apiNoteController::update, ) val routes = routes( @@ -87,11 +81,23 @@ class Router( val globalFilters = errorFilter .then(InitialiseRequestContext(contexts)) - .then(SecurityFilter()) - .then(ResponseFilters.GZip()) + .then(SecurityFilter) + .then(GZip()) return globalFilters.then(routes) } + + private inline infix fun PathMethod.public(crossinline handler: PublicHandler) = + this to { handler(it, it.jwtPayload(contexts)) } + + private inline infix fun PathMethod.protected(crossinline handler: ProtectedHandler) = + this to { handler(it, it.jwtPayload(contexts)!!) } + + private inline infix fun PathMethod.transactional(crossinline handler: ProtectedHandler) = + this to transactionFilter.then { handler(it, it.jwtPayload(contexts)!!) } + + private inline infix fun PathMethod.`public transactional`(crossinline handler: PublicHandler) = + this to transactionFilter.then { handler(it, it.jwtPayload(contexts)) } } private typealias PublicHandler = (Request, JwtPayload?) -> Response diff --git a/app/src/test/kotlin/filters/AuthFilterTest.kt b/app/src/test/kotlin/filters/AuthFilterTest.kt index edf5f16..d3ce561 100644 --- a/app/src/test/kotlin/filters/AuthFilterTest.kt +++ b/app/src/test/kotlin/filters/AuthFilterTest.kt @@ -27,8 +27,8 @@ internal class AuthFilterTest { private val simpleJwt = SimpleJwt(jwtConfig) private val extractor = JwtPayloadExtractor(simpleJwt) private val ctx = RequestContexts() - private val requiredAuth = AuthFilter(extractor, AuthType.Required, ctx)() - private val optionalAuth = AuthFilter(extractor, AuthType.Optional, ctx)() + private val requiredAuth = AuthFilter(extractor, AuthType.Required, ctx) + private val optionalAuth = AuthFilter(extractor, AuthType.Optional, ctx) private val echoJwtPayloadHandler = { request: Request -> Response(OK).body(request.jwtPayload(ctx).toString()) } diff --git a/persistance/pom.xml b/persistance/pom.xml index 2de73cd..4d7baa1 100644 --- a/persistance/pom.xml +++ b/persistance/pom.xml @@ -63,12 +63,10 @@ me.liuwj.ktorm ktorm-core - 3.0.0 me.liuwj.ktorm ktorm-support-mysql - 3.0.0 diff --git a/persistance/src/main/kotlin/notes/NoteRepositoryImpl.kt b/persistance/src/main/kotlin/notes/NoteRepositoryImpl.kt index 2b2edc5..8bfab7e 100644 --- a/persistance/src/main/kotlin/notes/NoteRepositoryImpl.kt +++ b/persistance/src/main/kotlin/notes/NoteRepositoryImpl.kt @@ -1,6 +1,9 @@ package be.simplenotes.persistance.notes -import be.simplenotes.domain.model.* +import be.simplenotes.domain.model.ExportedNote +import be.simplenotes.domain.model.Note +import be.simplenotes.domain.model.PersistedNote +import be.simplenotes.domain.model.PersistedNoteMetadata import be.simplenotes.domain.usecases.repositories.NoteRepository import me.liuwj.ktorm.database.Database import me.liuwj.ktorm.dsl.* @@ -59,14 +62,12 @@ internal class NoteRepositoryImpl(private val db: Database) : NoteRepository { val entity = note.toEntity(uuid, userId).apply { this.updatedAt = LocalDateTime.now() } - db.useTransaction { - db.notes.add(entity) - db.batchInsert(Tags) { - note.meta.tags.forEach { tagName -> - item { - it.noteUuid to uuid - it.name to tagName - } + db.notes.add(entity) + db.batchInsert(Tags) { + note.meta.tags.forEach { tagName -> + item { + it.noteUuid to uuid + it.name to tagName } } } @@ -90,64 +91,56 @@ internal class NoteRepositoryImpl(private val db: Database) : NoteRepository { } override fun update(userId: Int, uuid: UUID, note: Note): PersistedNote? { - db.useTransaction { - - val now = LocalDateTime.now() - val count = db.update(Notes) { - it.title to note.meta.title - it.markdown to note.markdown - it.html to note.html - it.updatedAt to now - where { (it.uuid eq uuid) and (it.userId eq userId) and (it.deleted eq false) } - } - - if (count == 0) return null - - // delete all tags - db.delete(Tags) { - it.noteUuid eq uuid - } - - // put new ones - note.meta.tags.forEach { tagName -> - db.insert(Tags) { - it.name to tagName - it.noteUuid to uuid - } - } - - return PersistedNote( - meta = note.meta, - markdown = note.markdown, - html = note.html, - updatedAt = now, - uuid = uuid, - public = false, // TODO - ) + val now = LocalDateTime.now() + val count = db.update(Notes) { + it.title to note.meta.title + it.markdown to note.markdown + it.html to note.html + it.updatedAt to now + where { (it.uuid eq uuid) and (it.userId eq userId) and (it.deleted eq false) } } + + if (count == 0) return null + + // delete all tags + db.delete(Tags) { + it.noteUuid eq uuid + } + + // put new ones + note.meta.tags.forEach { tagName -> + db.insert(Tags) { + it.name to tagName + it.noteUuid to uuid + } + } + + return PersistedNote( + meta = note.meta, + markdown = note.markdown, + html = note.html, + updatedAt = now, + uuid = uuid, + public = false, // TODO + ) } override fun delete(userId: Int, uuid: UUID, permanent: Boolean): Boolean { return if (!permanent) { - db.useTransaction { - db.update(Notes) { - it.deleted to true - it.updatedAt to LocalDateTime.now() - where { it.userId eq userId and (it.uuid eq uuid) } - } + db.update(Notes) { + it.deleted to true + it.updatedAt to LocalDateTime.now() + where { it.userId eq userId and (it.uuid eq uuid) } } == 1 - } else db.useTransaction { + } else db.delete(Notes) { it.uuid eq uuid and (it.userId eq userId) } == 1 - } } override fun restore(userId: Int, uuid: UUID): Boolean { - return db.useTransaction { - db.update(Notes) { - it.deleted to false - where { (it.userId eq userId) and (it.uuid eq uuid) } - } == 1 - } + return db.update(Notes) { + it.deleted to false + where { (it.userId eq userId) and (it.uuid eq uuid) } + } == 1 } override fun getTags(userId: Int): List = diff --git a/persistance/src/main/kotlin/users/UserRepositoryImpl.kt b/persistance/src/main/kotlin/users/UserRepositoryImpl.kt index 189eb59..d53f2ea 100644 --- a/persistance/src/main/kotlin/users/UserRepositoryImpl.kt +++ b/persistance/src/main/kotlin/users/UserRepositoryImpl.kt @@ -3,21 +3,20 @@ package be.simplenotes.persistance.users import be.simplenotes.domain.model.PersistedUser import be.simplenotes.domain.model.User import be.simplenotes.domain.usecases.repositories.UserRepository -import me.liuwj.ktorm.database.* +import me.liuwj.ktorm.database.Database import me.liuwj.ktorm.dsl.* -import me.liuwj.ktorm.entity.* +import me.liuwj.ktorm.entity.any +import me.liuwj.ktorm.entity.find import java.sql.SQLIntegrityConstraintViolationException internal class UserRepositoryImpl(private val db: Database) : UserRepository { override fun create(user: User): PersistedUser? { return try { - db.useTransaction { - val id = db.insertAndGenerateKey(Users) { - it.username to user.username - it.password to user.password - } as Int - PersistedUser(user.username, user.password, id) - } + val id = db.insertAndGenerateKey(Users) { + it.username to user.username + it.password to user.password + } as Int + PersistedUser(user.username, user.password, id) } catch (e: SQLIntegrityConstraintViolationException) { null } @@ -27,6 +26,6 @@ internal class UserRepositoryImpl(private val db: Database) : UserRepository { override fun find(id: Int) = db.users.find { it.id eq id }?.toPersistedUser() override fun exists(username: String) = db.users.any { it.username eq username } override fun exists(id: Int) = db.users.any { it.id eq id } - override fun delete(id: Int) = db.useTransaction { db.delete(Users) { it.id eq id } == 1 } + override fun delete(id: Int) = db.delete(Users) { it.id eq id } == 1 override fun findAll() = db.from(Users).select(Users.id).map { it[Users.id]!! } } diff --git a/pom.xml b/pom.xml index 34fa9e7..9e66c3b 100644 --- a/pom.xml +++ b/pom.xml @@ -170,6 +170,17 @@ 1.7.25 + + me.liuwj.ktorm + ktorm-core + 3.0.0 + + + me.liuwj.ktorm + ktorm-support-mysql + 3.0.0 + + org.junit.jupiter