diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index ac54bd6..d61c3b9 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -27,12 +27,6 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { r := gin.Default() r.Use(gin.Logger()) - // Add middleware - r.Use( - middleware.NewCorsMiddleware().Add(), - middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60), - ) - // Initialize services webauthnService := service.NewWebAuthnService(db, appConfigService) jwtService := service.NewJwtService(appConfigService) @@ -40,8 +34,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { oidcService := service.NewOidcService(db, jwtService) testService := service.NewTestService(db, appConfigService) + // Add global middleware + r.Use(middleware.NewCorsMiddleware().Add()) + r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60)) + r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false)) + // Initialize middleware - jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService) + jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false) fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() // Set up API routes @@ -49,7 +48,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService) controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService) controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService) - controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService) + controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService) // Add test controller in non-production environments if common.EnvConfig.AppEnv != "production" { diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 7747f97..966a21d 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -13,6 +13,7 @@ var ( ErrOidcMissingClientCredentials = errors.New("client id or secret not provided") ErrOidcClientSecretInvalid = errors.New("invalid client secret") ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code") + ErrOidcInvalidCallbackURL = errors.New("invalid callback URL") ErrFileTypeNotSupported = errors.New("file type not supported") ErrInvalidCredentials = errors.New("no user found with provided credentials") ) diff --git a/backend/internal/controller/app_config_controller.go b/backend/internal/controller/app_config_controller.go index 11b382b..7df077a 100644 --- a/backend/internal/controller/app_config_controller.go +++ b/backend/internal/controller/app_config_controller.go @@ -5,19 +5,19 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" - "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" ) -func NewApplicationConfigurationController( +func NewAppConfigController( group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, appConfigService *service.AppConfigService) { - acc := &ApplicationConfigurationController{ + acc := &AppConfigController{ appConfigService: appConfigService, } group.GET("/application-configuration", acc.listApplicationConfigurationHandler) @@ -32,86 +32,104 @@ func NewApplicationConfigurationController( group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler) } -type ApplicationConfigurationController struct { +type AppConfigController struct { appConfigService *service.AppConfigService } -func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) { +func (acc *AppConfigController) listApplicationConfigurationHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListApplicationConfiguration(false) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(200, configuration) + var configVariablesDto []dto.PublicAppConfigVariableDto + if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(200, configVariablesDto) } -func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) { +func (acc *AppConfigController) listAllApplicationConfigurationHandler(c *gin.Context) { configuration, err := acc.appConfigService.ListApplicationConfiguration(true) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(200, configuration) + var configVariablesDto []dto.AppConfigVariableDto + if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(200, configVariablesDto) } -func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) { - var input model.AppConfigUpdateDto +func (acc *AppConfigController) updateApplicationConfigurationHandler(c *gin.Context) { + var input dto.AppConfigUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, savedConfigVariables) + var configVariablesDto []dto.AppConfigVariableDto + if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, configVariablesDto) } -func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) { +func (acc *AppConfigController) getLogoHandler(c *gin.Context) { imageType := acc.appConfigService.DbConfig.LogoImageType.Value acc.getImage(c, "logo", imageType) } -func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) { +func (acc *AppConfigController) getFaviconHandler(c *gin.Context) { acc.getImage(c, "favicon", "ico") } -func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) { +func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) { imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value acc.getImage(c, "background", imageType) } -func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) { +func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { imageType := acc.appConfigService.DbConfig.LogoImageType.Value acc.updateImage(c, "logo", imageType) } -func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) { +func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } fileType := utils.GetFileExtension(file.Filename) if fileType != "ico" { - utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico") + utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico") return } acc.updateImage(c, "favicon", "ico") } -func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) { +func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) { imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value acc.updateImage(c, "background", imageType) } -func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) { +func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) { imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) mimeType := utils.GetImageMimeType(imageType) @@ -119,19 +137,19 @@ func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name str c.File(imagePath) } -func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) { +func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) { file, err := c.FormFile("file") if err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) if err != nil { if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 6896f1c..f2eaf3b 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -4,8 +4,8 @@ import ( "errors" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" - "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/utils" "net/http" @@ -40,18 +40,18 @@ type OidcController struct { } func (oc *OidcController) authorizeHandler(c *gin.Context) { - var parsedBody model.AuthorizeRequest - if err := c.ShouldBindJSON(&parsedBody); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + var input dto.AuthorizeOidcClientDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.ControllerError(c, err) return } - code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID")) + code, err := oc.oidcService.Authorize(input, c.GetString("userID")) if err != nil { if errors.Is(err, common.ErrOidcMissingAuthorization) { - utils.HandlerError(c, http.StatusForbidden, err.Error()) + utils.CustomControllerError(c, http.StatusForbidden, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } @@ -60,15 +60,15 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { } func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { - var parsedBody model.AuthorizeNewClientDto - if err := c.ShouldBindJSON(&parsedBody); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + var input dto.AuthorizeOidcClientDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.ControllerError(c, err) return } - code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID")) + code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -76,35 +76,35 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { } func (oc *OidcController) createIDTokenHandler(c *gin.Context) { - var body model.OidcIdTokenDto + var input dto.OidcIdTokenDto - if err := c.ShouldBind(&body); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + if err := c.ShouldBind(&input); err != nil { + utils.ControllerError(c, err) return } - clientID := body.ClientID - clientSecret := body.ClientSecret + clientID := input.ClientID + clientSecret := input.ClientSecret // Client id and secret can also be passed over the Authorization header if clientID == "" || clientSecret == "" { var ok bool clientID, clientSecret, ok = c.Request.BasicAuth() if !ok { - utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided") + utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided") return } } - idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret) + idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret) if err != nil { if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || errors.Is(err, common.ErrOidcMissingClientCredentials) || errors.Is(err, common.ErrOidcClientSecretInvalid) || errors.Is(err, common.ErrOidcInvalidAuthorizationCode) { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } @@ -116,14 +116,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { token := strings.Split(c.GetHeader("Authorization"), " ")[1] jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token) if err != nil { - utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) + utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) return } userID := jwtClaims.Subject clientId := jwtClaims.Audience[0] claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -134,11 +134,28 @@ func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") client, err := oc.oidcService.GetClient(clientId) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, client) + // Return a different DTO based on the user's role + if c.GetBool("userIsAdmin") { + clientDto := dto.OidcClientDto{} + err = dto.MapStruct(client, &clientDto) + if err == nil { + c.JSON(http.StatusOK, clientDto) + return + } + } else { + clientDto := dto.PublicOidcClientDto{} + err = dto.MapStruct(client, &clientDto) + if err == nil { + c.JSON(http.StatusOK, clientDto) + return + } + } + + utils.ControllerError(c, err) } func (oc *OidcController) listClientsHandler(c *gin.Context) { @@ -148,36 +165,48 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) { clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) + return + } + + var clientsDto []dto.OidcClientDto + if err := dto.MapStructList(clients, &clientsDto); err != nil { + utils.ControllerError(c, err) return } c.JSON(http.StatusOK, gin.H{ - "data": clients, + "data": clientsDto, "pagination": pagination, }) } func (oc *OidcController) createClientHandler(c *gin.Context) { - var input model.OidcClientCreateDto + var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusCreated, client) + var clientDto dto.OidcClientDto + if err := dto.MapStruct(client, &clientDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusCreated, clientDto) } func (oc *OidcController) deleteClientHandler(c *gin.Context) { err := oc.oidcService.DeleteClient(c.Param("id")) if err != nil { - utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") + utils.ControllerError(c, err) return } @@ -185,25 +214,31 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) { } func (oc *OidcController) updateClientHandler(c *gin.Context) { - var input model.OidcClientCreateDto + var input dto.OidcClientCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } client, err := oc.oidcService.UpdateClient(c.Param("id"), input) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusNoContent, client) + var clientDto dto.OidcClientDto + if err := dto.MapStruct(client, &clientDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, clientDto) } func (oc *OidcController) createClientSecretHandler(c *gin.Context) { secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -213,7 +248,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) { imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -224,16 +259,16 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { file, err := c.FormFile("file") if err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) if err != nil { if errors.Is(err, common.ErrFileTypeNotSupported) { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } @@ -244,7 +279,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { err := oc.oidcService.DeleteClientLogo(c.Param("id")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go index 4353cc8..7fb3081 100644 --- a/backend/internal/controller/test_controller.go +++ b/backend/internal/controller/test_controller.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/utils" + "net/http" ) func NewTestController(group *gin.RouterGroup, testService *service.TestService) { @@ -18,19 +19,19 @@ type TestController struct { func (tc *TestController) resetAndSeedHandler(c *gin.Context) { if err := tc.TestService.ResetDatabase(); err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } if err := tc.TestService.ResetApplicationImages(); err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } if err := tc.TestService.SeedDatabase(); err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(200, gin.H{"message": "Database reset and seeded"}) + c.Status(http.StatusNoContent) } diff --git a/backend/internal/controller/user_controller.go b/backend/internal/controller/user_controller.go index 30e29f7..69043c4 100644 --- a/backend/internal/controller/user_controller.go +++ b/backend/internal/controller/user_controller.go @@ -4,8 +4,8 @@ import ( "errors" "github.com/gin-gonic/gin" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" - "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/time/rate" @@ -43,12 +43,18 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) + return + } + + var usersDto []dto.UserDto + if err := dto.MapStructList(users, &usersDto); err != nil { + utils.ControllerError(c, err) return } c.JSON(http.StatusOK, gin.H{ - "data": users, + "data": usersDto, "pagination": pagination, }) } @@ -56,25 +62,38 @@ func (uc *UserController) listUsersHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.Param("id")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, user) + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, userDto) } func (uc *UserController) getCurrentUserHandler(c *gin.Context) { user, err := uc.UserService.GetUser(c.GetString("userID")) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, user) + + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, userDto) } func (uc *UserController) deleteUserHandler(c *gin.Context) { if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -82,22 +101,29 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) { } func (uc *UserController) createUserHandler(c *gin.Context) { - var user model.User - if err := c.ShouldBindJSON(&user); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + var input dto.UserCreateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.ControllerError(c, err) return } - if err := uc.UserService.CreateUser(&user); err != nil { + user, err := uc.UserService.CreateUser(input) + if err != nil { if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.HandlerError(c, http.StatusConflict, err.Error()) + utils.CustomControllerError(c, http.StatusConflict, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } - c.JSON(http.StatusCreated, user) + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusCreated, userDto) } func (uc *UserController) updateUserHandler(c *gin.Context) { @@ -109,15 +135,15 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { } func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { - var input model.OneTimeAccessTokenCreateDto + var input dto.OneTimeAccessTokenCreateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -128,9 +154,9 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) if err != nil { if errors.Is(err, common.ErrTokenInvalidOrExpired) { - utils.HandlerError(c, http.StatusUnauthorized, err.Error()) + utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } @@ -143,21 +169,27 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { user, token, err := uc.UserService.SetupInitialAdmin() if err != nil { if errors.Is(err, common.ErrSetupAlreadyCompleted) { - utils.HandlerError(c, http.StatusBadRequest, err.Error()) + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) + return + } + c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, user) + c.JSON(http.StatusOK, userDto) } func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { - var updatedUser model.User - if err := c.ShouldBindJSON(&updatedUser); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + var input dto.UserCreateDto + if err := c.ShouldBindJSON(&input); err != nil { + utils.ControllerError(c, err) return } @@ -168,15 +200,21 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { userID = c.Param("id") } - user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser) + user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser) if err != nil { if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { - utils.HandlerError(c, http.StatusConflict, err.Error()) + utils.CustomControllerError(c, http.StatusConflict, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } - c.JSON(http.StatusOK, user) + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, userDto) } diff --git a/backend/internal/controller/webauthn_controller.go b/backend/internal/controller/webauthn_controller.go index 1c7acf7..7617150 100644 --- a/backend/internal/controller/webauthn_controller.go +++ b/backend/internal/controller/webauthn_controller.go @@ -3,9 +3,8 @@ package controller import ( "errors" "github.com/go-webauthn/webauthn/protocol" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" - "github.com/stonith404/pocket-id/backend/internal/model" - "log" "net/http" "time" @@ -40,8 +39,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { userID := c.GetString("userID") options, err := wc.webAuthnService.BeginRegistration(userID) if err != nil { - utils.UnknownHandlerError(c, err) - log.Println(err) + utils.ControllerError(c, err) return } @@ -52,24 +50,30 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") + utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") return } userID := c.GetString("userID") credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, credential) + var credentialDto dto.WebauthnCredentialDto + if err := dto.MapStruct(credential, &credentialDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, credentialDto) } func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { options, err := wc.webAuthnService.BeginLogin() if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -80,13 +84,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { sessionID, err := c.Cookie("session_id") if err != nil { - utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") + utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing") return } credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) if err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } @@ -94,32 +98,44 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData) if err != nil { if errors.Is(err, common.ErrInvalidCredentials) { - utils.HandlerError(c, http.StatusUnauthorized, err.Error()) + utils.CustomControllerError(c, http.StatusUnauthorized, err.Error()) } else { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) } return } - token, err := wc.jwtService.GenerateAccessToken(*user) + token, err := wc.jwtService.GenerateAccessToken(user) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) + return + } + + var userDto dto.UserDto + if err := dto.MapStruct(user, &userDto); err != nil { + utils.ControllerError(c, err) return } c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) - c.JSON(http.StatusOK, user) + c.JSON(http.StatusOK, userDto) } func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { userID := c.GetString("userID") credentials, err := wc.webAuthnService.ListCredentials(userID) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.JSON(http.StatusOK, credentials) + var credentialDtos []dto.WebauthnCredentialDto + if err := dto.MapStructList(credentials, &credentialDtos); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, credentialDtos) } func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { @@ -128,7 +144,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { err := wc.webAuthnService.DeleteCredential(userID, credentialID) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } @@ -139,19 +155,25 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) { userID := c.GetString("userID") credentialID := c.Param("id") - var input model.WebauthnCredentialUpdateDto + var input dto.WebauthnCredentialUpdateDto if err := c.ShouldBindJSON(&input); err != nil { - utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) + utils.ControllerError(c, err) return } - err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) + credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } - c.Status(http.StatusNoContent) + var credentialDto dto.WebauthnCredentialDto + if err := dto.MapStruct(credential, &credentialDto); err != nil { + utils.ControllerError(c, err) + return + } + + c.JSON(http.StatusOK, credentialDto) } func (wc *WebauthnController) logoutHandler(c *gin.Context) { diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index efe1d8f..d2a285f 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -21,7 +21,7 @@ type WellKnownController struct { func (wkc *WellKnownController) jwksHandler(c *gin.Context) { jwk, err := wkc.jwtService.GetJWK() if err != nil { - utils.UnknownHandlerError(c, err) + utils.ControllerError(c, err) return } diff --git a/backend/internal/dto/app_config_dto.go b/backend/internal/dto/app_config_dto.go new file mode 100644 index 0000000..2b2b2bc --- /dev/null +++ b/backend/internal/dto/app_config_dto.go @@ -0,0 +1,17 @@ +package dto + +type PublicAppConfigVariableDto struct { + Key string `json:"key"` + Type string `json:"type"` + Value string `json:"value"` +} + +type AppConfigVariableDto struct { + PublicAppConfigVariableDto + IsPublic bool `json:"isPublic"` +} + +type AppConfigUpdateDto struct { + AppName string `json:"appName" binding:"required,min=1,max=30"` + SessionDuration string `json:"sessionDuration" binding:"required"` +} diff --git a/backend/internal/dto/dto_mapper.go b/backend/internal/dto/dto_mapper.go new file mode 100644 index 0000000..7769451 --- /dev/null +++ b/backend/internal/dto/dto_mapper.go @@ -0,0 +1,79 @@ +package dto + +import ( + "errors" + "reflect" +) + +// MapStructList maps a list of source structs to a list of destination structs +func MapStructList[S any, D any](source []S, destination *[]D) error { + for _, item := range source { + var destItem D + if err := MapStruct(item, &destItem); err != nil { + return err + } + *destination = append(*destination, destItem) + } + return nil +} + +// MapStruct maps a source struct to a destination struct +func MapStruct[S any, D any](source S, destination *D) error { + // Ensure destination is a non-nil pointer + destValue := reflect.ValueOf(destination) + if destValue.Kind() != reflect.Ptr || destValue.IsNil() { + return errors.New("destination must be a non-nil pointer to a struct") + } + + // Ensure source is a struct + sourceValue := reflect.ValueOf(source) + if sourceValue.Kind() != reflect.Struct { + return errors.New("source must be a struct") + } + + return mapStructInternal(sourceValue, destValue.Elem()) +} + +func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error { + // Loop through the fields of the destination struct + for i := 0; i < destVal.NumField(); i++ { + destField := destVal.Field(i) + destFieldType := destVal.Type().Field(i) + + if destFieldType.Anonymous { + // Recursively handle embedded structs + if err := mapStructInternal(sourceVal, destField); err != nil { + return err + } + continue + } + + sourceField := sourceVal.FieldByName(destFieldType.Name) + + // If the source field is valid and can be assigned to the destination field + if sourceField.IsValid() && destField.CanSet() { + // Handle direct assignment for simple types + if sourceField.Type() == destField.Type() { + destField.Set(sourceField) + } else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice { + // Handle slices + if sourceField.Type().Elem() == destField.Type().Elem() { + newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap()) + + for j := 0; j < sourceField.Len(); j++ { + newSlice.Index(j).Set(sourceField.Index(j)) + } + + destField.Set(newSlice) + } + } else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct { + // Recursively map nested structs + if err := mapStructInternal(sourceField, destField); err != nil { + return err + } + } + } + } + + return nil +} diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go new file mode 100644 index 0000000..3f01f62 --- /dev/null +++ b/backend/internal/dto/oidc_dto.go @@ -0,0 +1,31 @@ +package dto + +type PublicOidcClientDto struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type OidcClientDto struct { + PublicOidcClientDto + HasLogo bool `json:"hasLogo"` + CallbackURLs []string `json:"callbackURLs"` + CreatedBy UserDto `json:"createdBy"` +} + +type OidcClientCreateDto struct { + Name string `json:"name" binding:"required,max=50"` + CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"` +} + +type AuthorizeOidcClientDto struct { + ClientID string `json:"clientID" binding:"required"` + Scope string `json:"scope" binding:"required"` + Nonce string `json:"nonce"` +} + +type OidcIdTokenDto struct { + GrantType string `form:"grant_type" binding:"required"` + Code string `form:"code" binding:"required"` + ClientID string `form:"client_id"` + ClientSecret string `form:"client_secret"` +} diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go new file mode 100644 index 0000000..9f3429d --- /dev/null +++ b/backend/internal/dto/user_dto.go @@ -0,0 +1,30 @@ +package dto + +import "time" + +type UserDto struct { + ID string `json:"id"` + Username string `json:"username"` + Email string `json:"email" ` + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + IsAdmin bool `json:"isAdmin"` +} + +type UserCreateDto struct { + Username string `json:"username" binding:"required,username,min=3,max=20"` + Email string `json:"email" binding:"required,email"` + FirstName string `json:"firstName" binding:"required,min=3,max=30"` + LastName string `json:"lastName" binding:"required,min=3,max=30"` + IsAdmin bool `json:"isAdmin"` +} + +type OneTimeAccessTokenCreateDto struct { + UserID string `json:"userId" binding:"required"` + ExpiresAt time.Time `json:"expiresAt" binding:"required"` +} + +type LoginUserDto struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go new file mode 100644 index 0000000..57d1c75 --- /dev/null +++ b/backend/internal/dto/validations.go @@ -0,0 +1,39 @@ +package dto + +import ( + "github.com/gin-gonic/gin/binding" + "github.com/go-playground/validator/v10" + "log" + "net/url" + "regexp" +) + +var validateUrlList validator.Func = func(fl validator.FieldLevel) bool { + urls := fl.Field().Interface().([]string) + for _, u := range urls { + _, err := url.ParseRequestURI(u) + if err != nil { + return false + } + } + return true +} + +var validateUsername validator.Func = func(fl validator.FieldLevel) bool { + regex := "^[a-z0-9_]*$" + matched, _ := regexp.MatchString(regex, fl.Field().String()) + return matched +} + +func init() { + if v, ok := binding.Validator.Engine().(*validator.Validate); ok { + if err := v.RegisterValidation("urlList", validateUrlList); err != nil { + log.Fatalf("Failed to register custom validation: %v", err) + } + } + if v, ok := binding.Validator.Engine().(*validator.Validate); ok { + if err := v.RegisterValidation("username", validateUsername); err != nil { + log.Fatalf("Failed to register custom validation: %v", err) + } + } +} diff --git a/backend/internal/dto/webauthn_dto.go b/backend/internal/dto/webauthn_dto.go new file mode 100644 index 0000000..ca8c869 --- /dev/null +++ b/backend/internal/dto/webauthn_dto.go @@ -0,0 +1,23 @@ +package dto + +import ( + "github.com/go-webauthn/webauthn/protocol" + "time" +) + +type WebauthnCredentialDto struct { + ID string `json:"id"` + Name string `json:"name"` + CredentialID string `json:"credentialID"` + AttestationType string `json:"attestationType"` + Transport []protocol.AuthenticatorTransport `json:"transport"` + + BackupEligible bool `json:"backupEligible"` + BackupState bool `json:"backupState"` + + CreatedAt time.Time `json:"createdAt"` +} + +type WebauthnCredentialUpdateDto struct { + Name string `json:"name" binding:"required,min=1,max=30"` +} diff --git a/backend/internal/middleware/file_size_limit.go b/backend/internal/middleware/file_size_limit.go index d300f6f..7503acb 100644 --- a/backend/internal/middleware/file_size_limit.go +++ b/backend/internal/middleware/file_size_limit.go @@ -17,7 +17,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc { return func(c *gin.Context) { c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) if err := c.Request.ParseMultipartForm(maxSize); err != nil { - utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize))) + utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize))) c.Abort() return } diff --git a/backend/internal/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go index e7ebd6e..9416d5a 100644 --- a/backend/internal/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -9,11 +9,12 @@ import ( ) type JwtAuthMiddleware struct { - jwtService *service.JwtService + jwtService *service.JwtService + ignoreUnauthenticated bool } -func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware { - return &JwtAuthMiddleware{jwtService: jwtService} +func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware { + return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated} } func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { @@ -24,23 +25,29 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ") if len(authorizationHeaderSplitted) == 2 { token = authorizationHeaderSplitted[1] + } else if m.ignoreUnauthenticated { + c.Next() + return } else { - utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in") + utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") c.Abort() return } } claims, err := m.jwtService.VerifyAccessToken(token) - if err != nil { - utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in") + if err != nil && m.ignoreUnauthenticated { + c.Next() + return + } else if err != nil { + utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in") c.Abort() return } // Check if the user is an admin if adminOnly && !claims.IsAdmin { - utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource") + utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource") c.Abort() return } diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go index 36aba16..494ee06 100644 --- a/backend/internal/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -33,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc { limiter := getLimiter(ip, limit, burst) if !limiter.Allow() { - utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.") + utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.") c.Abort() return } diff --git a/backend/internal/model/app_config.go b/backend/internal/model/app_config.go index 6caa47e..6a7583b 100644 --- a/backend/internal/model/app_config.go +++ b/backend/internal/model/app_config.go @@ -1,11 +1,11 @@ package model type AppConfigVariable struct { - Key string `gorm:"primaryKey;not null" json:"key"` - Type string `json:"type"` - IsPublic bool `json:"-"` - IsInternal bool `json:"-"` - Value string `json:"value"` + Key string `gorm:"primaryKey;not null"` + Type string + IsPublic bool + IsInternal bool + Value string } type AppConfig struct { @@ -14,8 +14,3 @@ type AppConfig struct { LogoImageType AppConfigVariable SessionDuration AppConfigVariable } - -type AppConfigUpdateDto struct { - AppName string `json:"appName" binding:"required"` - SessionDuration string `json:"sessionDuration" binding:"required"` -} diff --git a/backend/internal/model/base.go b/backend/internal/model/base.go index dc1d402..68f0da2 100644 --- a/backend/internal/model/base.go +++ b/backend/internal/model/base.go @@ -8,8 +8,8 @@ import ( // Base contains common columns for all tables. type Base struct { - ID string `gorm:"primaryKey;not null" json:"id"` - CreatedAt time.Time `json:"createdAt"` + ID string `gorm:"primaryKey;not null"` + CreatedAt time.Time } func (b *Base) BeforeCreate(_ *gorm.DB) (err error) { diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 96f29bf..8af9756 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -1,38 +1,22 @@ package model import ( + "database/sql/driver" + "encoding/json" + "errors" "gorm.io/gorm" "time" ) type UserAuthorizedOidcClient struct { Scope string - UserID string `json:"userId" gorm:"primary_key;"` + UserID string `gorm:"primary_key;"` User User - ClientID string `json:"clientId" gorm:"primary_key;"` + ClientID string `gorm:"primary_key;"` Client OidcClient } -type OidcClient struct { - Base - - Name string `json:"name"` - Secret string `json:"-"` - CallbackURL string `json:"callbackURL"` - ImageType *string `json:"-"` - HasLogo bool `gorm:"-" json:"hasLogo"` - - CreatedByID string - CreatedBy User -} - -func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { - // Compute HasLogo field - c.HasLogo = c.ImageType != nil && *c.ImageType != "" - return nil -} - type OidcAuthorizationCode struct { Base @@ -47,26 +31,38 @@ type OidcAuthorizationCode struct { ClientID string } -type OidcClientCreateDto struct { - Name string `json:"name" binding:"required"` - CallbackURL string `json:"callbackURL" binding:"required"` +type OidcClient struct { + Base + + Name string + Secret string + CallbackURLs CallbackURLs + ImageType *string + HasLogo bool `gorm:"-"` + + CreatedByID string + CreatedBy User } -type AuthorizeNewClientDto struct { - ClientID string `json:"clientID" binding:"required"` - Scope string `json:"scope" binding:"required"` - Nonce string `json:"nonce"` +func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { + // Compute HasLogo field + c.HasLogo = c.ImageType != nil && *c.ImageType != "" + return nil } -type OidcIdTokenDto struct { - GrantType string `form:"grant_type" binding:"required"` - Code string `form:"code" binding:"required"` - ClientID string `form:"client_id"` - ClientSecret string `form:"client_secret"` +type CallbackURLs []string + +func (s *CallbackURLs) Scan(value interface{}) error { + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, s) + case string: + return json.Unmarshal([]byte(v), s) + default: + return errors.New("type assertion to []byte or string failed") + } } -type AuthorizeRequest struct { - ClientID string `json:"clientID" binding:"required"` - Scope string `json:"scope" binding:"required"` - Nonce string `json:"nonce"` +func (atl CallbackURLs) Value() (driver.Value, error) { + return json.Marshal(atl) } diff --git a/backend/internal/model/user.go b/backend/internal/model/user.go index 36feb8f..2d53783 100644 --- a/backend/internal/model/user.go +++ b/backend/internal/model/user.go @@ -9,13 +9,13 @@ import ( type User struct { Base - Username string `json:"username"` - Email string `json:"email" ` - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - IsAdmin bool `json:"isAdmin"` + Username string + Email string + FirstName string + LastName string + IsAdmin bool - Credentials []WebauthnCredential `json:"-"` + Credentials []WebauthnCredential } func (u User) WebAuthnID() []byte { return []byte(u.ID) } @@ -59,19 +59,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential type OneTimeAccessToken struct { Base - Token string `json:"token"` - ExpiresAt time.Time `json:"expiresAt"` + Token string + ExpiresAt time.Time - UserID string `json:"userId"` + UserID string User User } - -type OneTimeAccessTokenCreateDto struct { - UserID string `json:"userId" binding:"required"` - ExpiresAt time.Time `json:"expiresAt" binding:"required"` -} - -type LoginUserDto struct { - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` -} diff --git a/backend/internal/model/webauthn.go b/backend/internal/model/webauthn.go index 379077e..b785643 100644 --- a/backend/internal/model/webauthn.go +++ b/backend/internal/model/webauthn.go @@ -19,11 +19,11 @@ type WebauthnSession struct { type WebauthnCredential struct { Base - Name string `json:"name"` - CredentialID string `json:"credentialID"` - PublicKey []byte `json:"-"` - AttestationType string `json:"attestationType"` - Transport AuthenticatorTransportList `json:"-"` + Name string + CredentialID string + PublicKey []byte + AttestationType string + Transport AuthenticatorTransportList BackupEligible bool `json:"backupEligible"` BackupState bool `json:"backupState"` @@ -32,15 +32,15 @@ type WebauthnCredential struct { } type PublicKeyCredentialCreationOptions struct { - Response protocol.PublicKeyCredentialCreationOptions `json:"response"` - SessionID string `json:"session_id"` - Timeout time.Duration `json:"timeout"` + Response protocol.PublicKeyCredentialCreationOptions + SessionID string + Timeout time.Duration } type PublicKeyCredentialRequestOptions struct { - Response protocol.PublicKeyCredentialRequestOptions `json:"response"` - SessionID string `json:"session_id"` - Timeout time.Duration `json:"timeout"` + Response protocol.PublicKeyCredentialRequestOptions + SessionID string + Timeout time.Duration } type AuthenticatorTransportList []protocol.AuthenticatorTransport @@ -58,7 +58,3 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error { func (atl AuthenticatorTransportList) Value() (driver.Value, error) { return json.Marshal(atl) } - -type WebauthnCredentialUpdateDto struct { - Name string `json:"name"` -} diff --git a/backend/internal/service/app_config_service.go b/backend/internal/service/app_config_service.go index 8ce3434..7d77eea 100644 --- a/backend/internal/service/app_config_service.go +++ b/backend/internal/service/app_config_service.go @@ -3,6 +3,7 @@ package service import ( "fmt" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/utils" "gorm.io/gorm" @@ -54,7 +55,7 @@ var defaultDbConfig = model.AppConfig{ }, } -func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { +func (s *AppConfigService) UpdateApplicationConfiguration(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { var savedConfigVariables []model.AppConfigVariable tx := s.db.Begin() diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 7a02ce3..f34677d 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/utils" "golang.org/x/crypto/bcrypt" @@ -26,7 +27,7 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService { } } -func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) { +func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) { var userAuthorizedOIDCClient model.UserAuthorizedOidcClient s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID) @@ -37,7 +38,7 @@ func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (stri return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) } -func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) { +func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) { userAuthorizedClient := model.UserAuthorizedOidcClient{ UserID: userID, ClientID: req.ClientID, @@ -101,18 +102,18 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin return idToken, accessToken, nil } -func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) { +func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) { var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return nil, err + if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil { + return model.OidcClient{}, err } - return &client, nil + return client, nil } func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) { var clients []model.OidcClient - query := s.db.Model(&model.OidcClient{}) + query := s.db.Preload("CreatedBy").Model(&model.OidcClient{}) if searchTerm != "" { searchPattern := "%" + searchTerm + "%" query = query.Where("name LIKE ?", searchPattern) @@ -126,34 +127,34 @@ func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([] return clients, pagination, nil } -func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) { +func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { client := model.OidcClient{ - Name: input.Name, - CallbackURL: input.CallbackURL, - CreatedByID: userID, + Name: input.Name, + CallbackURLs: input.CallbackURLs, + CreatedByID: userID, } if err := s.db.Create(&client).Error; err != nil { - return nil, err + return model.OidcClient{}, err } - return &client, nil + return client, nil } -func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) { +func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) { var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return nil, err + if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil { + return model.OidcClient{}, err } client.Name = input.Name - client.CallbackURL = input.CallbackURL + client.CallbackURLs = input.CallbackURLs if err := s.db.Save(&client).Error; err != nil { - return nil, err + return model.OidcClient{}, err } - return &client, nil + return client, nil } func (s *OidcService) DeleteClient(clientID string) error { diff --git a/backend/internal/service/test_service.go b/backend/internal/service/test_service.go index fe71be4..4e43a9c 100644 --- a/backend/internal/service/test_service.go +++ b/backend/internal/service/test_service.go @@ -61,20 +61,20 @@ func (s *TestService) SeedDatabase() error { Base: model.Base{ ID: "3654a746-35d4-4321-ac61-0bdcff2b4055", }, - Name: "Nextcloud", - Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY - CallbackURL: "http://nextcloud/auth/callback", - ImageType: utils.StringPointer("png"), - CreatedByID: users[0].ID, + Name: "Nextcloud", + Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY + CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"}, + ImageType: utils.StringPointer("png"), + CreatedByID: users[0].ID, }, { Base: model.Base{ ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018", }, - Name: "Immich", - Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x - CallbackURL: "http://immich/auth/callback", - CreatedByID: users[0].ID, + Name: "Immich", + Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x + CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"}, + CreatedByID: users[0].ID, }, } for _, client := range oidcClients { diff --git a/backend/internal/service/user_sevice.go b/backend/internal/service/user_sevice.go index 0c5c5a2..e4be4e2 100644 --- a/backend/internal/service/user_sevice.go +++ b/backend/internal/service/user_sevice.go @@ -3,6 +3,7 @@ package service import ( "errors" "github.com/stonith404/pocket-id/backend/internal/common" + "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/utils" "gorm.io/gorm" @@ -46,17 +47,24 @@ func (s *UserService) DeleteUser(userID string) error { return s.db.Delete(&user).Error } -func (s *UserService) CreateUser(user *model.User) error { +func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) { + user := &model.User{ + FirstName: input.FirstName, + LastName: input.LastName, + Email: input.Email, + Username: input.Username, + IsAdmin: input.IsAdmin, + } if err := s.db.Create(user).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return s.checkDuplicatedFields(*user) + return model.User{}, s.checkDuplicatedFields(*user) } - return err + return model.User{}, err } - return nil + return *user, nil } -func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) { +func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool) (model.User, error) { var user model.User if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { return model.User{}, err diff --git a/backend/internal/service/webauthn_service.go b/backend/internal/service/webauthn_service.go index fdf4dbc..68b51f8 100644 --- a/backend/internal/service/webauthn_service.go +++ b/backend/internal/service/webauthn_service.go @@ -67,10 +67,10 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred }, nil } -func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) { +func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) { var storedSession model.WebauthnSession if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { - return nil, err + return model.WebauthnCredential{}, err } session := webauthn.SessionData{ @@ -81,12 +81,12 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R var user model.User if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { - return nil, err + return model.WebauthnCredential{}, err } credential, err := s.webAuthn.FinishRegistration(&user, session, r) if err != nil { - return nil, err + return model.WebauthnCredential{}, err } credentialToStore := model.WebauthnCredential{ @@ -100,10 +100,10 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R BackupState: credential.Flags.BackupState, } if err := s.db.Create(&credentialToStore).Error; err != nil { - return nil, err + return model.WebauthnCredential{}, err } - return &credentialToStore, nil + return credentialToStore, nil } func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { @@ -129,10 +129,10 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions }, nil } -func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) { +func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (model.User, error) { var storedSession model.WebauthnSession if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { - return nil, err + return model.User{}, err } session := webauthn.SessionData{ @@ -149,14 +149,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert }, session, credentialAssertionData) if err != nil { - return nil, common.ErrInvalidCredentials + return model.User{}, err } if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { - return nil, err + return model.User{}, err } - return user, nil + return *user, nil } func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { @@ -180,17 +180,17 @@ func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error { return nil } -func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error { +func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) { var credential model.WebauthnCredential if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { - return err + return credential, err } credential.Name = name if err := s.db.Save(&credential).Error; err != nil { - return err + return credential, err } - return nil + return credential, nil } diff --git a/backend/internal/utils/controller_error_util.go b/backend/internal/utils/controller_error_util.go new file mode 100644 index 0000000..db9eb3d --- /dev/null +++ b/backend/internal/utils/controller_error_util.go @@ -0,0 +1,75 @@ +package utils + +import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/go-playground/validator/v10" + "gorm.io/gorm" + "log" + "net/http" + "strings" +) + +import ( + "fmt" +) + +func ControllerError(c *gin.Context, err error) { + // Check for record not found errors + if errors.Is(err, gorm.ErrRecordNotFound) { + CustomControllerError(c, http.StatusNotFound, "Record not found") + return + } + + // Check for validation errors + var validationErrors validator.ValidationErrors + if errors.As(err, &validationErrors) { + message := handleValidationError(validationErrors) + CustomControllerError(c, http.StatusBadRequest, message) + return + + } + + log.Println(err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) +} + +func handleValidationError(validationErrors validator.ValidationErrors) string { + var errorMessages []string + + for _, ve := range validationErrors { + fieldName := ve.Field() + var errorMessage string + switch ve.Tag() { + case "required": + errorMessage = fmt.Sprintf("%s is required", fieldName) + case "email": + errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName) + case "username": + errorMessage = fmt.Sprintf("%s must contain only lowercase letters, numbers, and underscores", fieldName) + case "url": + errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName) + case "min": + errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param()) + case "max": + errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param()) + case "urlList": + errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName) + default: + errorMessage = fmt.Sprintf("%s is invalid", fieldName) + } + + errorMessages = append(errorMessages, errorMessage) + } + + // Join all the error messages into a single string + combinedErrors := strings.Join(errorMessages, ", ") + + return combinedErrors +} + +func CustomControllerError(c *gin.Context, statusCode int, message string) { + // Capitalize the first letter of the message + message = strings.ToUpper(message[:1]) + message[1:] + c.JSON(statusCode, gin.H{"error": message}) +} diff --git a/backend/internal/utils/handler_error_util.go b/backend/internal/utils/handler_error_util.go deleted file mode 100644 index e644501..0000000 --- a/backend/internal/utils/handler_error_util.go +++ /dev/null @@ -1,27 +0,0 @@ -package utils - -import ( - "errors" - "github.com/gin-gonic/gin" - "gorm.io/gorm" - "log" - "net/http" - "strings" -) - -func UnknownHandlerError(c *gin.Context, err error) { - if errors.Is(err, gorm.ErrRecordNotFound) { - HandlerError(c, http.StatusNotFound, "Record not found") - return - } else { - log.Println(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"}) - } - -} - -func HandlerError(c *gin.Context, statusCode int, message string) { - // Capitalize the first letter of the message - message = strings.ToUpper(message[:1]) + message[1:] - c.JSON(statusCode, gin.H{"error": message}) -}