refactor: use dtos in controllers

This commit is contained in:
Elias Schneider
2024-08-23 17:04:19 +02:00
parent 9f49e5577e
commit ae7aeb0945
29 changed files with 679 additions and 304 deletions

View File

@@ -27,12 +27,6 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r := gin.Default() r := gin.Default()
r.Use(gin.Logger()) r.Use(gin.Logger())
// Add middleware
r.Use(
middleware.NewCorsMiddleware().Add(),
middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60),
)
// Initialize services // Initialize services
webauthnService := service.NewWebAuthnService(db, appConfigService) webauthnService := service.NewWebAuthnService(db, appConfigService)
jwtService := service.NewJwtService(appConfigService) jwtService := service.NewJwtService(appConfigService)
@@ -40,8 +34,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
oidcService := service.NewOidcService(db, jwtService) oidcService := service.NewOidcService(db, jwtService)
testService := service.NewTestService(db, appConfigService) testService := service.NewTestService(db, appConfigService)
// Add global middleware
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
// Initialize middleware // Initialize middleware
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService) jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes // Set up API routes
@@ -49,7 +48,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService) controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService) controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService) controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService) controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
// Add test controller in non-production environments // Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" { if common.EnvConfig.AppEnv != "production" {

View File

@@ -13,6 +13,7 @@ var (
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided") ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
ErrOidcClientSecretInvalid = errors.New("invalid client secret") ErrOidcClientSecretInvalid = errors.New("invalid client secret")
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code") ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
ErrFileTypeNotSupported = errors.New("file type not supported") ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials") ErrInvalidCredentials = errors.New("no user found with provided credentials")
) )

View File

@@ -5,19 +5,19 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"net/http" "net/http"
) )
func NewApplicationConfigurationController( func NewAppConfigController(
group *gin.RouterGroup, group *gin.RouterGroup,
jwtAuthMiddleware *middleware.JwtAuthMiddleware, jwtAuthMiddleware *middleware.JwtAuthMiddleware,
appConfigService *service.AppConfigService) { appConfigService *service.AppConfigService) {
acc := &ApplicationConfigurationController{ acc := &AppConfigController{
appConfigService: appConfigService, appConfigService: appConfigService,
} }
group.GET("/application-configuration", acc.listApplicationConfigurationHandler) group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
@@ -32,86 +32,104 @@ func NewApplicationConfigurationController(
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler) group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
} }
type ApplicationConfigurationController struct { type AppConfigController struct {
appConfigService *service.AppConfigService appConfigService *service.AppConfigService
} }
func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) { func (acc *AppConfigController) listApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(false) configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(200, configuration) var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(200, configVariablesDto)
} }
func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) { func (acc *AppConfigController) listAllApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(true) configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(200, configuration) var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(200, configVariablesDto)
} }
func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) { func (acc *AppConfigController) updateApplicationConfigurationHandler(c *gin.Context) {
var input model.AppConfigUpdateDto var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input) savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, savedConfigVariables) var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, configVariablesDto)
} }
func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) { func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.getImage(c, "logo", imageType) acc.getImage(c, "logo", imageType)
} }
func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) { func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico") acc.getImage(c, "favicon", "ico")
} }
func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.getImage(c, "background", imageType) acc.getImage(c, "background", imageType)
} }
func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.updateImage(c, "logo", imageType) acc.updateImage(c, "logo", imageType)
} }
func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) { func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
fileType := utils.GetFileExtension(file.Filename) fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" { if fileType != "ico" {
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico") utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
return return
} }
acc.updateImage(c, "favicon", "ico") acc.updateImage(c, "favicon", "ico")
} }
func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.updateImage(c, "background", imageType) acc.updateImage(c, "background", imageType)
} }
func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) { func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
mimeType := utils.GetImageMimeType(imageType) mimeType := utils.GetImageMimeType(imageType)
@@ -119,19 +137,19 @@ func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name str
c.File(imagePath) c.File(imagePath)
} }
func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) { func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
if err != nil { if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) { if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error()) utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }

View File

@@ -4,8 +4,8 @@ import (
"errors" "errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"net/http" "net/http"
@@ -40,18 +40,18 @@ type OidcController struct {
} }
func (oc *OidcController) authorizeHandler(c *gin.Context) { func (oc *OidcController) authorizeHandler(c *gin.Context) {
var parsedBody model.AuthorizeRequest var input dto.AuthorizeOidcClientDto
if err := c.ShouldBindJSON(&parsedBody); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID")) code, err := oc.oidcService.Authorize(input, c.GetString("userID"))
if err != nil { if err != nil {
if errors.Is(err, common.ErrOidcMissingAuthorization) { if errors.Is(err, common.ErrOidcMissingAuthorization) {
utils.HandlerError(c, http.StatusForbidden, err.Error()) utils.CustomControllerError(c, http.StatusForbidden, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
@@ -60,15 +60,15 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
} }
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
var parsedBody model.AuthorizeNewClientDto var input dto.AuthorizeOidcClientDto
if err := c.ShouldBindJSON(&parsedBody); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID")) code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -76,35 +76,35 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
} }
func (oc *OidcController) createIDTokenHandler(c *gin.Context) { func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
var body model.OidcIdTokenDto var input dto.OidcIdTokenDto
if err := c.ShouldBind(&body); err != nil { if err := c.ShouldBind(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
clientID := body.ClientID clientID := input.ClientID
clientSecret := body.ClientSecret clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header // Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" { if clientID == "" || clientSecret == "" {
var ok bool var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth() clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok { if !ok {
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided") utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided")
return return
} }
} }
idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret) idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
if err != nil { if err != nil {
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) || if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
errors.Is(err, common.ErrOidcMissingClientCredentials) || errors.Is(err, common.ErrOidcMissingClientCredentials) ||
errors.Is(err, common.ErrOidcClientSecretInvalid) || errors.Is(err, common.ErrOidcClientSecretInvalid) ||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) { errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
utils.HandlerError(c, http.StatusBadRequest, err.Error()) utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
@@ -116,14 +116,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
token := strings.Split(c.GetHeader("Authorization"), " ")[1] token := strings.Split(c.GetHeader("Authorization"), " ")[1]
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token) jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error()) utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
return return
} }
userID := jwtClaims.Subject userID := jwtClaims.Subject
clientId := jwtClaims.Audience[0] clientId := jwtClaims.Audience[0]
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -134,11 +134,28 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id") clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId) client, err := oc.oidcService.GetClient(clientId)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, client) // Return a different DTO based on the user's role
if c.GetBool("userIsAdmin") {
clientDto := dto.OidcClientDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
} else {
clientDto := dto.PublicOidcClientDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
}
utils.ControllerError(c, err)
} }
func (oc *OidcController) listClientsHandler(c *gin.Context) { func (oc *OidcController) listClientsHandler(c *gin.Context) {
@@ -148,36 +165,48 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize) clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return
}
var clientsDto []dto.OidcClientDto
if err := dto.MapStructList(clients, &clientsDto); err != nil {
utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": clients, "data": clientsDto,
"pagination": pagination, "pagination": pagination,
}) })
} }
func (oc *OidcController) createClientHandler(c *gin.Context) { func (oc *OidcController) createClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusCreated, client) var clientDto dto.OidcClientDto
if err := dto.MapStruct(client, &clientDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, clientDto)
} }
func (oc *OidcController) deleteClientHandler(c *gin.Context) { func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id")) err := oc.oidcService.DeleteClient(c.Param("id"))
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found") utils.ControllerError(c, err)
return return
} }
@@ -185,25 +214,31 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
} }
func (oc *OidcController) updateClientHandler(c *gin.Context) { func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
client, err := oc.oidcService.UpdateClient(c.Param("id"), input) client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusNoContent, client) var clientDto dto.OidcClientDto
if err := dto.MapStruct(client, &clientDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, clientDto)
} }
func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -213,7 +248,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -224,16 +259,16 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
if err != nil { if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) { if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error()) utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
@@ -244,7 +279,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id")) err := oc.oidcService.DeleteClientLogo(c.Param("id"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
) )
func NewTestController(group *gin.RouterGroup, testService *service.TestService) { func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
@@ -18,19 +19,19 @@ type TestController struct {
func (tc *TestController) resetAndSeedHandler(c *gin.Context) { func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
if err := tc.TestService.ResetDatabase(); err != nil { if err := tc.TestService.ResetDatabase(); err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
if err := tc.TestService.ResetApplicationImages(); err != nil { if err := tc.TestService.ResetApplicationImages(); err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
if err := tc.TestService.SeedDatabase(); err != nil { if err := tc.TestService.SeedDatabase(); err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(200, gin.H{"message": "Database reset and seeded"}) c.Status(http.StatusNoContent)
} }

View File

@@ -4,8 +4,8 @@ import (
"errors" "errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service" "github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/time/rate" "golang.org/x/time/rate"
@@ -43,12 +43,18 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize) users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return
}
var usersDto []dto.UserDto
if err := dto.MapStructList(users, &usersDto); err != nil {
utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": users, "data": usersDto,
"pagination": pagination, "pagination": pagination,
}) })
} }
@@ -56,25 +62,38 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
func (uc *UserController) getUserHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.Param("id")) user, err := uc.UserService.GetUser(c.Param("id"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, user) var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
} }
func (uc *UserController) getCurrentUserHandler(c *gin.Context) { func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.GetString("userID")) user, err := uc.UserService.GetUser(c.GetString("userID"))
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, user)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
} }
func (uc *UserController) deleteUserHandler(c *gin.Context) { func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil { if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -82,22 +101,29 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
} }
func (uc *UserController) createUserHandler(c *gin.Context) { func (uc *UserController) createUserHandler(c *gin.Context) {
var user model.User var input dto.UserCreateDto
if err := c.ShouldBindJSON(&user); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
if err := uc.UserService.CreateUser(&user); err != nil { user, err := uc.UserService.CreateUser(input)
if err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error()) utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
c.JSON(http.StatusCreated, user) var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, userDto)
} }
func (uc *UserController) updateUserHandler(c *gin.Context) { func (uc *UserController) updateUserHandler(c *gin.Context) {
@@ -109,15 +135,15 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
} }
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
var input model.OneTimeAccessTokenCreateDto var input dto.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -128,9 +154,9 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token")) user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
if err != nil { if err != nil {
if errors.Is(err, common.ErrTokenInvalidOrExpired) { if errors.Is(err, common.ErrTokenInvalidOrExpired) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error()) utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
@@ -143,21 +169,27 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.SetupInitialAdmin() user, token, err := uc.UserService.SetupInitialAdmin()
if err != nil { if err != nil {
if errors.Is(err, common.ErrSetupAlreadyCompleted) { if errors.Is(err, common.ErrSetupAlreadyCompleted) {
utils.HandlerError(c, http.StatusBadRequest, err.Error()) utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user) c.JSON(http.StatusOK, userDto)
} }
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var updatedUser model.User var input dto.UserCreateDto
if err := c.ShouldBindJSON(&updatedUser); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
@@ -168,15 +200,21 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
userID = c.Param("id") userID = c.Param("id")
} }
user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser) user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser)
if err != nil { if err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) { if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error()) utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
c.JSON(http.StatusOK, user) var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
} }

View File

@@ -3,9 +3,8 @@ package controller
import ( import (
"errors" "errors"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"log"
"net/http" "net/http"
"time" "time"
@@ -40,8 +39,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID) options, err := wc.webAuthnService.BeginRegistration(userID)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
log.Println(err)
return return
} }
@@ -52,24 +50,30 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id") sessionID, err := c.Cookie("session_id")
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
return return
} }
userID := c.GetString("userID") userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, credential) var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDto)
} }
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin() options, err := wc.webAuthnService.BeginLogin()
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -80,13 +84,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id") sessionID, err := c.Cookie("session_id")
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing") utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
return return
} }
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil { if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
@@ -94,32 +98,44 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData) user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
if err != nil { if err != nil {
if errors.Is(err, common.ErrInvalidCredentials) { if errors.Is(err, common.ErrInvalidCredentials) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error()) utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
} else { } else {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
} }
return return
} }
token, err := wc.jwtService.GenerateAccessToken(*user) token, err := wc.jwtService.GenerateAccessToken(user)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return return
} }
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true) c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user) c.JSON(http.StatusOK, userDto)
} }
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID) credentials, err := wc.webAuthnService.ListCredentials(userID)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.JSON(http.StatusOK, credentials) var credentialDtos []dto.WebauthnCredentialDto
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDtos)
} }
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) { func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
@@ -128,7 +144,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
err := wc.webAuthnService.DeleteCredential(userID, credentialID) err := wc.webAuthnService.DeleteCredential(userID, credentialID)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
@@ -139,19 +155,25 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentialID := c.Param("id") credentialID := c.Param("id")
var input model.WebauthnCredentialUpdateDto var input dto.WebauthnCredentialUpdateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error()) utils.ControllerError(c, err)
return return
} }
err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }
c.Status(http.StatusNoContent) var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDto)
} }
func (wc *WebauthnController) logoutHandler(c *gin.Context) { func (wc *WebauthnController) logoutHandler(c *gin.Context) {

View File

@@ -21,7 +21,7 @@ type WellKnownController struct {
func (wkc *WellKnownController) jwksHandler(c *gin.Context) { func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwk, err := wkc.jwtService.GetJWK() jwk, err := wkc.jwtService.GetJWK()
if err != nil { if err != nil {
utils.UnknownHandlerError(c, err) utils.ControllerError(c, err)
return return
} }

View File

@@ -0,0 +1,17 @@
package dto
type PublicAppConfigVariableDto struct {
Key string `json:"key"`
Type string `json:"type"`
Value string `json:"value"`
}
type AppConfigVariableDto struct {
PublicAppConfigVariableDto
IsPublic bool `json:"isPublic"`
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"`
}

View File

@@ -0,0 +1,79 @@
package dto
import (
"errors"
"reflect"
)
// MapStructList maps a list of source structs to a list of destination structs
func MapStructList[S any, D any](source []S, destination *[]D) error {
for _, item := range source {
var destItem D
if err := MapStruct(item, &destItem); err != nil {
return err
}
*destination = append(*destination, destItem)
}
return nil
}
// MapStruct maps a source struct to a destination struct
func MapStruct[S any, D any](source S, destination *D) error {
// Ensure destination is a non-nil pointer
destValue := reflect.ValueOf(destination)
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
return errors.New("destination must be a non-nil pointer to a struct")
}
// Ensure source is a struct
sourceValue := reflect.ValueOf(source)
if sourceValue.Kind() != reflect.Struct {
return errors.New("source must be a struct")
}
return mapStructInternal(sourceValue, destValue.Elem())
}
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Loop through the fields of the destination struct
for i := 0; i < destVal.NumField(); i++ {
destField := destVal.Field(i)
destFieldType := destVal.Type().Field(i)
if destFieldType.Anonymous {
// Recursively handle embedded structs
if err := mapStructInternal(sourceVal, destField); err != nil {
return err
}
continue
}
sourceField := sourceVal.FieldByName(destFieldType.Name)
// If the source field is valid and can be assigned to the destination field
if sourceField.IsValid() && destField.CanSet() {
// Handle direct assignment for simple types
if sourceField.Type() == destField.Type() {
destField.Set(sourceField)
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
// Recursively map nested structs
if err := mapStructInternal(sourceField, destField); err != nil {
return err
}
}
}
}
return nil
}

View File

@@ -0,0 +1,31 @@
package dto
type PublicOidcClientDto struct {
ID string `json:"id"`
Name string `json:"name"`
}
type OidcClientDto struct {
PublicOidcClientDto
HasLogo bool `json:"hasLogo"`
CallbackURLs []string `json:"callbackURLs"`
CreatedBy UserDto `json:"createdBy"`
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
}
type AuthorizeOidcClientDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
}
type OidcIdTokenDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
}

View File

@@ -0,0 +1,30 @@
package dto
import "time"
type UserDto struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email" `
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"`
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=3,max=20"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=3,max=30"`
LastName string `json:"lastName" binding:"required,min=3,max=30"`
IsAdmin bool `json:"isAdmin"`
}
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type LoginUserDto struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}

View File

@@ -0,0 +1,39 @@
package dto
import (
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"log"
"net/url"
"regexp"
)
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
urls := fl.Field().Interface().([]string)
for _, u := range urls {
_, err := url.ParseRequestURI(u)
if err != nil {
return false
}
}
return true
}
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
regex := "^[a-z0-9_]*$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
}
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("username", validateUsername); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
}

View File

@@ -0,0 +1,23 @@
package dto
import (
"github.com/go-webauthn/webauthn/protocol"
"time"
)
type WebauthnCredentialDto struct {
ID string `json:"id"`
Name string `json:"name"`
CredentialID string `json:"credentialID"`
AttestationType string `json:"attestationType"`
Transport []protocol.AuthenticatorTransport `json:"transport"`
BackupEligible bool `json:"backupEligible"`
BackupState bool `json:"backupState"`
CreatedAt time.Time `json:"createdAt"`
}
type WebauthnCredentialUpdateDto struct {
Name string `json:"name" binding:"required,min=1,max=30"`
}

View File

@@ -17,7 +17,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
if err := c.Request.ParseMultipartForm(maxSize); err != nil { if err := c.Request.ParseMultipartForm(maxSize); err != nil {
utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize))) utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
c.Abort() c.Abort()
return return
} }

View File

@@ -9,11 +9,12 @@ import (
) )
type JwtAuthMiddleware struct { type JwtAuthMiddleware struct {
jwtService *service.JwtService jwtService *service.JwtService
ignoreUnauthenticated bool
} }
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware { func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService} return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated}
} }
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc { func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
@@ -24,23 +25,29 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ") authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
if len(authorizationHeaderSplitted) == 2 { if len(authorizationHeaderSplitted) == 2 {
token = authorizationHeaderSplitted[1] token = authorizationHeaderSplitted[1]
} else if m.ignoreUnauthenticated {
c.Next()
return
} else { } else {
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in") utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
c.Abort() c.Abort()
return return
} }
} }
claims, err := m.jwtService.VerifyAccessToken(token) claims, err := m.jwtService.VerifyAccessToken(token)
if err != nil { if err != nil && m.ignoreUnauthenticated {
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in") c.Next()
return
} else if err != nil {
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
c.Abort() c.Abort()
return return
} }
// Check if the user is an admin // Check if the user is an admin
if adminOnly && !claims.IsAdmin { if adminOnly && !claims.IsAdmin {
utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource") utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource")
c.Abort() c.Abort()
return return
} }

View File

@@ -33,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
limiter := getLimiter(ip, limit, burst) limiter := getLimiter(ip, limit, burst)
if !limiter.Allow() { if !limiter.Allow() {
utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.") utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
c.Abort() c.Abort()
return return
} }

View File

@@ -1,11 +1,11 @@
package model package model
type AppConfigVariable struct { type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null" json:"key"` Key string `gorm:"primaryKey;not null"`
Type string `json:"type"` Type string
IsPublic bool `json:"-"` IsPublic bool
IsInternal bool `json:"-"` IsInternal bool
Value string `json:"value"` Value string
} }
type AppConfig struct { type AppConfig struct {
@@ -14,8 +14,3 @@ type AppConfig struct {
LogoImageType AppConfigVariable LogoImageType AppConfigVariable
SessionDuration AppConfigVariable SessionDuration AppConfigVariable
} }
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required"`
SessionDuration string `json:"sessionDuration" binding:"required"`
}

View File

@@ -8,8 +8,8 @@ import (
// Base contains common columns for all tables. // Base contains common columns for all tables.
type Base struct { type Base struct {
ID string `gorm:"primaryKey;not null" json:"id"` ID string `gorm:"primaryKey;not null"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time
} }
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) { func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {

View File

@@ -1,38 +1,22 @@
package model package model
import ( import (
"database/sql/driver"
"encoding/json"
"errors"
"gorm.io/gorm" "gorm.io/gorm"
"time" "time"
) )
type UserAuthorizedOidcClient struct { type UserAuthorizedOidcClient struct {
Scope string Scope string
UserID string `json:"userId" gorm:"primary_key;"` UserID string `gorm:"primary_key;"`
User User User User
ClientID string `json:"clientId" gorm:"primary_key;"` ClientID string `gorm:"primary_key;"`
Client OidcClient Client OidcClient
} }
type OidcClient struct {
Base
Name string `json:"name"`
Secret string `json:"-"`
CallbackURL string `json:"callbackURL"`
ImageType *string `json:"-"`
HasLogo bool `gorm:"-" json:"hasLogo"`
CreatedByID string
CreatedBy User
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil
}
type OidcAuthorizationCode struct { type OidcAuthorizationCode struct {
Base Base
@@ -47,26 +31,38 @@ type OidcAuthorizationCode struct {
ClientID string ClientID string
} }
type OidcClientCreateDto struct { type OidcClient struct {
Name string `json:"name" binding:"required"` Base
CallbackURL string `json:"callbackURL" binding:"required"`
Name string
Secret string
CallbackURLs CallbackURLs
ImageType *string
HasLogo bool `gorm:"-"`
CreatedByID string
CreatedBy User
} }
type AuthorizeNewClientDto struct { func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
ClientID string `json:"clientID" binding:"required"` // Compute HasLogo field
Scope string `json:"scope" binding:"required"` c.HasLogo = c.ImageType != nil && *c.ImageType != ""
Nonce string `json:"nonce"` return nil
} }
type OidcIdTokenDto struct { type CallbackURLs []string
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"` func (s *CallbackURLs) Scan(value interface{}) error {
ClientID string `form:"client_id"` switch v := value.(type) {
ClientSecret string `form:"client_secret"` case []byte:
return json.Unmarshal(v, s)
case string:
return json.Unmarshal([]byte(v), s)
default:
return errors.New("type assertion to []byte or string failed")
}
} }
type AuthorizeRequest struct { func (atl CallbackURLs) Value() (driver.Value, error) {
ClientID string `json:"clientID" binding:"required"` return json.Marshal(atl)
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
} }

View File

@@ -9,13 +9,13 @@ import (
type User struct { type User struct {
Base Base
Username string `json:"username"` Username string
Email string `json:"email" ` Email string
FirstName string `json:"firstName"` FirstName string
LastName string `json:"lastName"` LastName string
IsAdmin bool `json:"isAdmin"` IsAdmin bool
Credentials []WebauthnCredential `json:"-"` Credentials []WebauthnCredential
} }
func (u User) WebAuthnID() []byte { return []byte(u.ID) } func (u User) WebAuthnID() []byte { return []byte(u.ID) }
@@ -59,19 +59,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
type OneTimeAccessToken struct { type OneTimeAccessToken struct {
Base Base
Token string `json:"token"` Token string
ExpiresAt time.Time `json:"expiresAt"` ExpiresAt time.Time
UserID string `json:"userId"` UserID string
User User User User
} }
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type LoginUserDto struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}

View File

@@ -19,11 +19,11 @@ type WebauthnSession struct {
type WebauthnCredential struct { type WebauthnCredential struct {
Base Base
Name string `json:"name"` Name string
CredentialID string `json:"credentialID"` CredentialID string
PublicKey []byte `json:"-"` PublicKey []byte
AttestationType string `json:"attestationType"` AttestationType string
Transport AuthenticatorTransportList `json:"-"` Transport AuthenticatorTransportList
BackupEligible bool `json:"backupEligible"` BackupEligible bool `json:"backupEligible"`
BackupState bool `json:"backupState"` BackupState bool `json:"backupState"`
@@ -32,15 +32,15 @@ type WebauthnCredential struct {
} }
type PublicKeyCredentialCreationOptions struct { type PublicKeyCredentialCreationOptions struct {
Response protocol.PublicKeyCredentialCreationOptions `json:"response"` Response protocol.PublicKeyCredentialCreationOptions
SessionID string `json:"session_id"` SessionID string
Timeout time.Duration `json:"timeout"` Timeout time.Duration
} }
type PublicKeyCredentialRequestOptions struct { type PublicKeyCredentialRequestOptions struct {
Response protocol.PublicKeyCredentialRequestOptions `json:"response"` Response protocol.PublicKeyCredentialRequestOptions
SessionID string `json:"session_id"` SessionID string
Timeout time.Duration `json:"timeout"` Timeout time.Duration
} }
type AuthenticatorTransportList []protocol.AuthenticatorTransport type AuthenticatorTransportList []protocol.AuthenticatorTransport
@@ -58,7 +58,3 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
func (atl AuthenticatorTransportList) Value() (driver.Value, error) { func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
return json.Marshal(atl) return json.Marshal(atl)
} }
type WebauthnCredentialUpdateDto struct {
Name string `json:"name"`
}

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"fmt" "fmt"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm" "gorm.io/gorm"
@@ -54,7 +55,7 @@ var defaultDbConfig = model.AppConfig{
}, },
} }
func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { func (s *AppConfigService) UpdateApplicationConfiguration(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
var savedConfigVariables []model.AppConfigVariable var savedConfigVariables []model.AppConfigVariable
tx := s.db.Begin() tx := s.db.Begin()

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -26,7 +27,7 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
} }
} }
func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) { func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID) s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
@@ -37,7 +38,7 @@ func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (stri
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
} }
func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) { func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
userAuthorizedClient := model.UserAuthorizedOidcClient{ userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID, UserID: userID,
ClientID: req.ClientID, ClientID: req.ClientID,
@@ -101,18 +102,18 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
return idToken, accessToken, nil return idToken, accessToken, nil
} }
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) { func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
return nil, err return model.OidcClient{}, err
} }
return &client, nil return client, nil
} }
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) { func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient var clients []model.OidcClient
query := s.db.Model(&model.OidcClient{}) query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
if searchTerm != "" { if searchTerm != "" {
searchPattern := "%" + searchTerm + "%" searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern) query = query.Where("name LIKE ?", searchPattern)
@@ -126,34 +127,34 @@ func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]
return clients, pagination, nil return clients, pagination, nil
} }
func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) { func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{ client := model.OidcClient{
Name: input.Name, Name: input.Name,
CallbackURL: input.CallbackURL, CallbackURLs: input.CallbackURLs,
CreatedByID: userID, CreatedByID: userID,
} }
if err := s.db.Create(&client).Error; err != nil { if err := s.db.Create(&client).Error; err != nil {
return nil, err return model.OidcClient{}, err
} }
return &client, nil return client, nil
} }
func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) { func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
return nil, err return model.OidcClient{}, err
} }
client.Name = input.Name client.Name = input.Name
client.CallbackURL = input.CallbackURL client.CallbackURLs = input.CallbackURLs
if err := s.db.Save(&client).Error; err != nil { if err := s.db.Save(&client).Error; err != nil {
return nil, err return model.OidcClient{}, err
} }
return &client, nil return client, nil
} }
func (s *OidcService) DeleteClient(clientID string) error { func (s *OidcService) DeleteClient(clientID string) error {

View File

@@ -61,20 +61,20 @@ func (s *TestService) SeedDatabase() error {
Base: model.Base{ Base: model.Base{
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055", ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
}, },
Name: "Nextcloud", Name: "Nextcloud",
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
CallbackURL: "http://nextcloud/auth/callback", CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"},
ImageType: utils.StringPointer("png"), ImageType: utils.StringPointer("png"),
CreatedByID: users[0].ID, CreatedByID: users[0].ID,
}, },
{ {
Base: model.Base{ Base: model.Base{
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018", ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
}, },
Name: "Immich", Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURL: "http://immich/auth/callback", CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
CreatedByID: users[0].ID, CreatedByID: users[0].ID,
}, },
} }
for _, client := range oidcClients { for _, client := range oidcClients {

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"errors" "errors"
"github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model" "github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils" "github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm" "gorm.io/gorm"
@@ -46,17 +47,24 @@ func (s *UserService) DeleteUser(userID string) error {
return s.db.Delete(&user).Error return s.db.Delete(&user).Error
} }
func (s *UserService) CreateUser(user *model.User) error { func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
user := &model.User{
FirstName: input.FirstName,
LastName: input.LastName,
Email: input.Email,
Username: input.Username,
IsAdmin: input.IsAdmin,
}
if err := s.db.Create(user).Error; err != nil { if err := s.db.Create(user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return s.checkDuplicatedFields(*user) return model.User{}, s.checkDuplicatedFields(*user)
} }
return err return model.User{}, err
} }
return nil return *user, nil
} }
func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) { func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool) (model.User, error) {
var user model.User var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
return model.User{}, err return model.User{}, err

View File

@@ -67,10 +67,10 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil }, nil
} }
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) { func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err return model.WebauthnCredential{}, err
} }
session := webauthn.SessionData{ session := webauthn.SessionData{
@@ -81,12 +81,12 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
var user model.User var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err return model.WebauthnCredential{}, err
} }
credential, err := s.webAuthn.FinishRegistration(&user, session, r) credential, err := s.webAuthn.FinishRegistration(&user, session, r)
if err != nil { if err != nil {
return nil, err return model.WebauthnCredential{}, err
} }
credentialToStore := model.WebauthnCredential{ credentialToStore := model.WebauthnCredential{
@@ -100,10 +100,10 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupState: credential.Flags.BackupState, BackupState: credential.Flags.BackupState,
} }
if err := s.db.Create(&credentialToStore).Error; err != nil { if err := s.db.Create(&credentialToStore).Error; err != nil {
return nil, err return model.WebauthnCredential{}, err
} }
return &credentialToStore, nil return credentialToStore, nil
} }
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
@@ -129,10 +129,10 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil }, nil
} }
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) { func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (model.User, error) {
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err return model.User{}, err
} }
session := webauthn.SessionData{ session := webauthn.SessionData{
@@ -149,14 +149,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert
}, session, credentialAssertionData) }, session, credentialAssertionData)
if err != nil { if err != nil {
return nil, common.ErrInvalidCredentials return model.User{}, err
} }
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err return model.User{}, err
} }
return user, nil return *user, nil
} }
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
@@ -180,17 +180,17 @@ func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
return nil return nil
} }
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error { func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
var credential model.WebauthnCredential var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
return err return credential, err
} }
credential.Name = name credential.Name = name
if err := s.db.Save(&credential).Error; err != nil { if err := s.db.Save(&credential).Error; err != nil {
return err return credential, err
} }
return nil return credential, nil
} }

View File

@@ -0,0 +1,75 @@
package utils
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"gorm.io/gorm"
"log"
"net/http"
"strings"
)
import (
"fmt"
)
func ControllerError(c *gin.Context, err error) {
// Check for record not found errors
if errors.Is(err, gorm.ErrRecordNotFound) {
CustomControllerError(c, http.StatusNotFound, "Record not found")
return
}
// Check for validation errors
var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
message := handleValidationError(validationErrors)
CustomControllerError(c, http.StatusBadRequest, message)
return
}
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
}
func handleValidationError(validationErrors validator.ValidationErrors) string {
var errorMessages []string
for _, ve := range validationErrors {
fieldName := ve.Field()
var errorMessage string
switch ve.Tag() {
case "required":
errorMessage = fmt.Sprintf("%s is required", fieldName)
case "email":
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
case "username":
errorMessage = fmt.Sprintf("%s must contain only lowercase letters, numbers, and underscores", fieldName)
case "url":
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
case "min":
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
case "max":
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
case "urlList":
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
default:
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
}
errorMessages = append(errorMessages, errorMessage)
}
// Join all the error messages into a single string
combinedErrors := strings.Join(errorMessages, ", ")
return combinedErrors
}
func CustomControllerError(c *gin.Context, statusCode int, message string) {
// Capitalize the first letter of the message
message = strings.ToUpper(message[:1]) + message[1:]
c.JSON(statusCode, gin.H{"error": message})
}

View File

@@ -1,27 +0,0 @@
package utils
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"log"
"net/http"
"strings"
)
func UnknownHandlerError(c *gin.Context, err error) {
if errors.Is(err, gorm.ErrRecordNotFound) {
HandlerError(c, http.StatusNotFound, "Record not found")
return
} else {
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
}
}
func HandlerError(c *gin.Context, statusCode int, message string) {
// Capitalize the first letter of the message
message = strings.ToUpper(message[:1]) + message[1:]
c.JSON(statusCode, gin.H{"error": message})
}