diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index ba47459..c4dfeb0 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -54,7 +54,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { // Set up API routes apiGroup := r.Group("/api") - controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService) + controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, appConfigService) controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService) controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService) controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService, emailService) diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index 4f59c43..6c47312 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -14,8 +14,8 @@ import ( func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService, appConfigService *service.AppConfigService) { uc := UserController{ - UserService: userService, - AppConfigService: appConfigService, + userService: userService, + appConfigService: appConfigService, } group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler) @@ -32,8 +32,8 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt } type UserController struct { - UserService *service.UserService - AppConfigService *service.AppConfigService + userService *service.UserService + appConfigService *service.AppConfigService } func (uc *UserController) listUsersHandler(c *gin.Context) { @@ -44,7 +44,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { return } - users, pagination, err := uc.UserService.ListUsers(searchTerm, sortedPaginationRequest) + users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest) if err != nil { c.Error(err) return @@ -63,7 +63,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { } func (uc *UserController) getUserHandler(c *gin.Context) { - user, err := uc.UserService.GetUser(c.Param("id")) + user, err := uc.userService.GetUser(c.Param("id")) if err != nil { c.Error(err) return @@ -79,7 +79,7 @@ func (uc *UserController) getUserHandler(c *gin.Context) { } func (uc *UserController) getCurrentUserHandler(c *gin.Context) { - user, err := uc.UserService.GetUser(c.GetString("userID")) + user, err := uc.userService.GetUser(c.GetString("userID")) if err != nil { c.Error(err) return @@ -95,7 +95,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) { } func (uc *UserController) deleteUserHandler(c *gin.Context) { - if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { + if err := uc.userService.DeleteUser(c.Param("id")); err != nil { c.Error(err) return } @@ -110,7 +110,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) { return } - user, err := uc.UserService.CreateUser(input) + user, err := uc.userService.CreateUser(input) if err != nil { c.Error(err) return @@ -130,7 +130,7 @@ func (uc *UserController) updateUserHandler(c *gin.Context) { } func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { - if uc.AppConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" { + if uc.appConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" { c.Error(&common.AccountEditNotAllowedError{}) return } @@ -144,7 +144,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { return } - token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent()) + token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent()) if err != nil { c.Error(err) return @@ -154,7 +154,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { } func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { - user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) + user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token")) if err != nil { c.Error(err) return @@ -166,12 +166,12 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { return } - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", true, true) + utils.AddAccessTokenCookie(c, uc.appConfigService.DbConfig.SessionDuration.Value, token) c.JSON(http.StatusOK, userDto) } func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { - user, token, err := uc.UserService.SetupInitialAdmin() + user, token, err := uc.userService.SetupInitialAdmin() if err != nil { c.Error(err) return @@ -183,7 +183,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { return } - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", true, true) + utils.AddAccessTokenCookie(c, uc.appConfigService.DbConfig.SessionDuration.Value, token) c.JSON(http.StatusOK, userDto) } @@ -201,7 +201,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { userID = c.Param("id") } - user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser) + user, err := uc.userService.UpdateUser(userID, input, updateOwnUser) if err != nil { c.Error(err) return diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go index 894a316..fbf5ce5 100644 --- a/backend/internal/controller/webauthn_controller.go +++ b/backend/internal/controller/webauthn_controller.go @@ -5,6 +5,7 @@ import ( "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" + "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" "time" @@ -13,8 +14,8 @@ import ( "golang.org/x/time/rate" ) -func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService) { - wc := &WebauthnController{webAuthnService: webauthnService} +func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, appConfigService *service.AppConfigService) { + wc := &WebauthnController{webAuthnService: webauthnService, appConfigService: appConfigService} group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler) group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler) @@ -29,7 +30,8 @@ func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware } type WebauthnController struct { - webAuthnService *service.WebAuthnService + webAuthnService *service.WebAuthnService + appConfigService *service.AppConfigService } func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { @@ -103,7 +105,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { return } - c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", true, true) + utils.AddAccessTokenCookie(c, wc.appConfigService.DbConfig.SessionDuration.Value, token) c.JSON(http.StatusOK, userDto) } @@ -163,6 +165,6 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { } func (wc *WebauthnController) logoutHandler(c *gin.Context) { - c.SetCookie("access_token", "", 0, "/", "", true, true) + utils.AddAccessTokenCookie(c, "0", "") c.Status(http.StatusNoContent) } diff --git a/backend/internal/utils/cookie_util.go b/backend/internal/utils/cookie_util.go new file mode 100644 index 0000000..b045a98 --- /dev/null +++ b/backend/internal/utils/cookie_util.go @@ -0,0 +1,12 @@ +package utils + +import ( + "github.com/gin-gonic/gin" + "strconv" +) + +func AddAccessTokenCookie(c *gin.Context, sessionDurationInMinutes string, token string) { + sessionDurationInMinutesParsed, _ := strconv.Atoi(sessionDurationInMinutes) + maxAge := sessionDurationInMinutesParsed * 60 + c.SetCookie("access_token", token, maxAge, "/", "", true, true) +}