refactor: use dtos in controllers

This commit is contained in:
Elias Schneider
2024-08-23 17:04:19 +02:00
parent 9f49e5577e
commit ae7aeb0945
29 changed files with 679 additions and 304 deletions

View File

@@ -27,12 +27,6 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r := gin.Default()
r.Use(gin.Logger())
// Add middleware
r.Use(
middleware.NewCorsMiddleware().Add(),
middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60),
)
// Initialize services
webauthnService := service.NewWebAuthnService(db, appConfigService)
jwtService := service.NewJwtService(appConfigService)
@@ -40,8 +34,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
oidcService := service.NewOidcService(db, jwtService)
testService := service.NewTestService(db, appConfigService)
// Add global middleware
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
// Initialize middleware
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService)
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes
@@ -49,7 +48,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService)
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {

View File

@@ -13,6 +13,7 @@ var (
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
ErrFileTypeNotSupported = errors.New("file type not supported")
ErrInvalidCredentials = errors.New("no user found with provided credentials")
)

View File

@@ -5,19 +5,19 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
)
func NewApplicationConfigurationController(
func NewAppConfigController(
group *gin.RouterGroup,
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
appConfigService *service.AppConfigService) {
acc := &ApplicationConfigurationController{
acc := &AppConfigController{
appConfigService: appConfigService,
}
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
@@ -32,86 +32,104 @@ func NewApplicationConfigurationController(
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
}
type ApplicationConfigurationController struct {
type AppConfigController struct {
appConfigService *service.AppConfigService
}
func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) {
func (acc *AppConfigController) listApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(200, configuration)
var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(200, configVariablesDto)
}
func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) {
func (acc *AppConfigController) listAllApplicationConfigurationHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(200, configuration)
var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(200, configVariablesDto)
}
func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) {
var input model.AppConfigUpdateDto
func (acc *AppConfigController) updateApplicationConfigurationHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, savedConfigVariables)
var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, configVariablesDto)
}
func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) {
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.getImage(c, "logo", imageType)
}
func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) {
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico")
}
func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) {
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.getImage(c, "background", imageType)
}
func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) {
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
acc.updateImage(c, "logo", imageType)
}
func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) {
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" {
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
return
}
acc.updateImage(c, "favicon", "ico")
}
func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) {
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.updateImage(c, "background", imageType)
}
func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) {
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
mimeType := utils.GetImageMimeType(imageType)
@@ -119,19 +137,19 @@ func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name str
c.File(imagePath)
}
func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) {
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}

View File

@@ -4,8 +4,8 @@ import (
"errors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
@@ -40,18 +40,18 @@ type OidcController struct {
}
func (oc *OidcController) authorizeHandler(c *gin.Context) {
var parsedBody model.AuthorizeRequest
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
var input dto.AuthorizeOidcClientDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID"))
code, err := oc.oidcService.Authorize(input, c.GetString("userID"))
if err != nil {
if errors.Is(err, common.ErrOidcMissingAuthorization) {
utils.HandlerError(c, http.StatusForbidden, err.Error())
utils.CustomControllerError(c, http.StatusForbidden, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
@@ -60,15 +60,15 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
}
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
var parsedBody model.AuthorizeNewClientDto
if err := c.ShouldBindJSON(&parsedBody); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
var input dto.AuthorizeOidcClientDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -76,35 +76,35 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
}
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
var body model.OidcIdTokenDto
var input dto.OidcIdTokenDto
if err := c.ShouldBind(&body); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
if err := c.ShouldBind(&input); err != nil {
utils.ControllerError(c, err)
return
}
clientID := body.ClientID
clientSecret := body.ClientSecret
clientID := input.ClientID
clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided")
return
}
}
idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret)
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
if err != nil {
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
@@ -116,14 +116,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
if err != nil {
utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
return
}
userID := jwtClaims.Subject
clientId := jwtClaims.Audience[0]
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -134,11 +134,28 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, client)
// Return a different DTO based on the user's role
if c.GetBool("userIsAdmin") {
clientDto := dto.OidcClientDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
} else {
clientDto := dto.PublicOidcClientDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
}
utils.ControllerError(c, err)
}
func (oc *OidcController) listClientsHandler(c *gin.Context) {
@@ -148,36 +165,48 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
var clientsDto []dto.OidcClientDto
if err := dto.MapStructList(clients, &clientsDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": clients,
"data": clientsDto,
"pagination": pagination,
})
}
func (oc *OidcController) createClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, client)
var clientDto dto.OidcClientDto
if err := dto.MapStruct(client, &clientDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, clientDto)
}
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id"))
if err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
utils.ControllerError(c, err)
return
}
@@ -185,25 +214,31 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
}
func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input model.OidcClientCreateDto
var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusNoContent, client)
var clientDto dto.OidcClientDto
if err := dto.MapStruct(client, &clientDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, clientDto)
}
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -213,7 +248,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -224,16 +259,16 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
@@ -244,7 +279,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
)
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
@@ -18,19 +19,19 @@ type TestController struct {
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
if err := tc.TestService.ResetDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
if err := tc.TestService.ResetApplicationImages(); err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
if err := tc.TestService.SeedDatabase(); err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(200, gin.H{"message": "Database reset and seeded"})
c.Status(http.StatusNoContent)
}

View File

@@ -4,8 +4,8 @@ import (
"errors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/service"
"github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/time/rate"
@@ -43,12 +43,18 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
var usersDto []dto.UserDto
if err := dto.MapStructList(users, &usersDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
"data": usersDto,
"pagination": pagination,
})
}
@@ -56,25 +62,38 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, user)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
}
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.UserService.GetUser(c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, user)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
}
func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -82,22 +101,29 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
}
func (uc *UserController) createUserHandler(c *gin.Context) {
var user model.User
if err := c.ShouldBindJSON(&user); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
if err := uc.UserService.CreateUser(&user); err != nil {
user, err := uc.UserService.CreateUser(input)
if err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error())
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
c.JSON(http.StatusCreated, user)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusCreated, userDto)
}
func (uc *UserController) updateUserHandler(c *gin.Context) {
@@ -109,15 +135,15 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
}
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
var input model.OneTimeAccessTokenCreateDto
var input dto.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -128,9 +154,9 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
if err != nil {
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
@@ -143,21 +169,27 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.SetupInitialAdmin()
if err != nil {
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
c.JSON(http.StatusOK, userDto)
}
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var updatedUser model.User
if err := c.ShouldBindJSON(&updatedUser); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
@@ -168,15 +200,21 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
userID = c.Param("id")
}
user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser)
user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser)
if err != nil {
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
utils.HandlerError(c, http.StatusConflict, err.Error())
utils.CustomControllerError(c, http.StatusConflict, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
c.JSON(http.StatusOK, user)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, userDto)
}

View File

@@ -3,9 +3,8 @@ package controller
import (
"errors"
"github.com/go-webauthn/webauthn/protocol"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/model"
"log"
"net/http"
"time"
@@ -40,8 +39,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID)
if err != nil {
utils.UnknownHandlerError(c, err)
log.Println(err)
utils.ControllerError(c, err)
return
}
@@ -52,24 +50,30 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
return
}
userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credential)
var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDto)
}
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin()
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -80,13 +84,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie("session_id")
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
return
}
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
@@ -94,32 +98,44 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
if err != nil {
if errors.Is(err, common.ErrInvalidCredentials) {
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
} else {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
}
return
}
token, err := wc.jwtService.GenerateAccessToken(*user)
token, err := wc.jwtService.GenerateAccessToken(user)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
utils.ControllerError(c, err)
return
}
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
c.JSON(http.StatusOK, user)
c.JSON(http.StatusOK, userDto)
}
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentials)
var credentialDtos []dto.WebauthnCredentialDto
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDtos)
}
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
@@ -128,7 +144,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
@@ -139,19 +155,25 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
userID := c.GetString("userID")
credentialID := c.Param("id")
var input model.WebauthnCredentialUpdateDto
var input dto.WebauthnCredentialUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
utils.ControllerError(c, err)
return
}
err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}
c.Status(http.StatusNoContent)
var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
utils.ControllerError(c, err)
return
}
c.JSON(http.StatusOK, credentialDto)
}
func (wc *WebauthnController) logoutHandler(c *gin.Context) {

View File

@@ -21,7 +21,7 @@ type WellKnownController struct {
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwk, err := wkc.jwtService.GetJWK()
if err != nil {
utils.UnknownHandlerError(c, err)
utils.ControllerError(c, err)
return
}

View File

@@ -0,0 +1,17 @@
package dto
type PublicAppConfigVariableDto struct {
Key string `json:"key"`
Type string `json:"type"`
Value string `json:"value"`
}
type AppConfigVariableDto struct {
PublicAppConfigVariableDto
IsPublic bool `json:"isPublic"`
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"`
}

View File

@@ -0,0 +1,79 @@
package dto
import (
"errors"
"reflect"
)
// MapStructList maps a list of source structs to a list of destination structs
func MapStructList[S any, D any](source []S, destination *[]D) error {
for _, item := range source {
var destItem D
if err := MapStruct(item, &destItem); err != nil {
return err
}
*destination = append(*destination, destItem)
}
return nil
}
// MapStruct maps a source struct to a destination struct
func MapStruct[S any, D any](source S, destination *D) error {
// Ensure destination is a non-nil pointer
destValue := reflect.ValueOf(destination)
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
return errors.New("destination must be a non-nil pointer to a struct")
}
// Ensure source is a struct
sourceValue := reflect.ValueOf(source)
if sourceValue.Kind() != reflect.Struct {
return errors.New("source must be a struct")
}
return mapStructInternal(sourceValue, destValue.Elem())
}
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Loop through the fields of the destination struct
for i := 0; i < destVal.NumField(); i++ {
destField := destVal.Field(i)
destFieldType := destVal.Type().Field(i)
if destFieldType.Anonymous {
// Recursively handle embedded structs
if err := mapStructInternal(sourceVal, destField); err != nil {
return err
}
continue
}
sourceField := sourceVal.FieldByName(destFieldType.Name)
// If the source field is valid and can be assigned to the destination field
if sourceField.IsValid() && destField.CanSet() {
// Handle direct assignment for simple types
if sourceField.Type() == destField.Type() {
destField.Set(sourceField)
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
// Recursively map nested structs
if err := mapStructInternal(sourceField, destField); err != nil {
return err
}
}
}
}
return nil
}

View File

@@ -0,0 +1,31 @@
package dto
type PublicOidcClientDto struct {
ID string `json:"id"`
Name string `json:"name"`
}
type OidcClientDto struct {
PublicOidcClientDto
HasLogo bool `json:"hasLogo"`
CallbackURLs []string `json:"callbackURLs"`
CreatedBy UserDto `json:"createdBy"`
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
}
type AuthorizeOidcClientDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
}
type OidcIdTokenDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
}

View File

@@ -0,0 +1,30 @@
package dto
import "time"
type UserDto struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email" `
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"`
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=3,max=20"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=3,max=30"`
LastName string `json:"lastName" binding:"required,min=3,max=30"`
IsAdmin bool `json:"isAdmin"`
}
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type LoginUserDto struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}

View File

@@ -0,0 +1,39 @@
package dto
import (
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"log"
"net/url"
"regexp"
)
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
urls := fl.Field().Interface().([]string)
for _, u := range urls {
_, err := url.ParseRequestURI(u)
if err != nil {
return false
}
}
return true
}
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
regex := "^[a-z0-9_]*$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
}
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("username", validateUsername); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
}

View File

@@ -0,0 +1,23 @@
package dto
import (
"github.com/go-webauthn/webauthn/protocol"
"time"
)
type WebauthnCredentialDto struct {
ID string `json:"id"`
Name string `json:"name"`
CredentialID string `json:"credentialID"`
AttestationType string `json:"attestationType"`
Transport []protocol.AuthenticatorTransport `json:"transport"`
BackupEligible bool `json:"backupEligible"`
BackupState bool `json:"backupState"`
CreatedAt time.Time `json:"createdAt"`
}
type WebauthnCredentialUpdateDto struct {
Name string `json:"name" binding:"required,min=1,max=30"`
}

View File

@@ -17,7 +17,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
c.Abort()
return
}

View File

@@ -9,11 +9,12 @@ import (
)
type JwtAuthMiddleware struct {
jwtService *service.JwtService
jwtService *service.JwtService
ignoreUnauthenticated bool
}
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService}
func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated}
}
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
@@ -24,23 +25,29 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
if len(authorizationHeaderSplitted) == 2 {
token = authorizationHeaderSplitted[1]
} else if m.ignoreUnauthenticated {
c.Next()
return
} else {
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
c.Abort()
return
}
}
claims, err := m.jwtService.VerifyAccessToken(token)
if err != nil {
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
if err != nil && m.ignoreUnauthenticated {
c.Next()
return
} else if err != nil {
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
c.Abort()
return
}
// Check if the user is an admin
if adminOnly && !claims.IsAdmin {
utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource")
utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource")
c.Abort()
return
}

View File

@@ -33,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
limiter := getLimiter(ip, limit, burst)
if !limiter.Allow() {
utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
c.Abort()
return
}

View File

@@ -1,11 +1,11 @@
package model
type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null" json:"key"`
Type string `json:"type"`
IsPublic bool `json:"-"`
IsInternal bool `json:"-"`
Value string `json:"value"`
Key string `gorm:"primaryKey;not null"`
Type string
IsPublic bool
IsInternal bool
Value string
}
type AppConfig struct {
@@ -14,8 +14,3 @@ type AppConfig struct {
LogoImageType AppConfigVariable
SessionDuration AppConfigVariable
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required"`
SessionDuration string `json:"sessionDuration" binding:"required"`
}

View File

@@ -8,8 +8,8 @@ import (
// Base contains common columns for all tables.
type Base struct {
ID string `gorm:"primaryKey;not null" json:"id"`
CreatedAt time.Time `json:"createdAt"`
ID string `gorm:"primaryKey;not null"`
CreatedAt time.Time
}
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {

View File

@@ -1,38 +1,22 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"gorm.io/gorm"
"time"
)
type UserAuthorizedOidcClient struct {
Scope string
UserID string `json:"userId" gorm:"primary_key;"`
UserID string `gorm:"primary_key;"`
User User
ClientID string `json:"clientId" gorm:"primary_key;"`
ClientID string `gorm:"primary_key;"`
Client OidcClient
}
type OidcClient struct {
Base
Name string `json:"name"`
Secret string `json:"-"`
CallbackURL string `json:"callbackURL"`
ImageType *string `json:"-"`
HasLogo bool `gorm:"-" json:"hasLogo"`
CreatedByID string
CreatedBy User
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil
}
type OidcAuthorizationCode struct {
Base
@@ -47,26 +31,38 @@ type OidcAuthorizationCode struct {
ClientID string
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required"`
CallbackURL string `json:"callbackURL" binding:"required"`
type OidcClient struct {
Base
Name string
Secret string
CallbackURLs CallbackURLs
ImageType *string
HasLogo bool `gorm:"-"`
CreatedByID string
CreatedBy User
}
type AuthorizeNewClientDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil
}
type OidcIdTokenDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
type CallbackURLs []string
func (s *CallbackURLs) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, s)
case string:
return json.Unmarshal([]byte(v), s)
default:
return errors.New("type assertion to []byte or string failed")
}
}
type AuthorizeRequest struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
func (atl CallbackURLs) Value() (driver.Value, error) {
return json.Marshal(atl)
}

View File

@@ -9,13 +9,13 @@ import (
type User struct {
Base
Username string `json:"username"`
Email string `json:"email" `
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"`
Username string
Email string
FirstName string
LastName string
IsAdmin bool
Credentials []WebauthnCredential `json:"-"`
Credentials []WebauthnCredential
}
func (u User) WebAuthnID() []byte { return []byte(u.ID) }
@@ -59,19 +59,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
type OneTimeAccessToken struct {
Base
Token string `json:"token"`
ExpiresAt time.Time `json:"expiresAt"`
Token string
ExpiresAt time.Time
UserID string `json:"userId"`
UserID string
User User
}
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type LoginUserDto struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}

View File

@@ -19,11 +19,11 @@ type WebauthnSession struct {
type WebauthnCredential struct {
Base
Name string `json:"name"`
CredentialID string `json:"credentialID"`
PublicKey []byte `json:"-"`
AttestationType string `json:"attestationType"`
Transport AuthenticatorTransportList `json:"-"`
Name string
CredentialID string
PublicKey []byte
AttestationType string
Transport AuthenticatorTransportList
BackupEligible bool `json:"backupEligible"`
BackupState bool `json:"backupState"`
@@ -32,15 +32,15 @@ type WebauthnCredential struct {
}
type PublicKeyCredentialCreationOptions struct {
Response protocol.PublicKeyCredentialCreationOptions `json:"response"`
SessionID string `json:"session_id"`
Timeout time.Duration `json:"timeout"`
Response protocol.PublicKeyCredentialCreationOptions
SessionID string
Timeout time.Duration
}
type PublicKeyCredentialRequestOptions struct {
Response protocol.PublicKeyCredentialRequestOptions `json:"response"`
SessionID string `json:"session_id"`
Timeout time.Duration `json:"timeout"`
Response protocol.PublicKeyCredentialRequestOptions
SessionID string
Timeout time.Duration
}
type AuthenticatorTransportList []protocol.AuthenticatorTransport
@@ -58,7 +58,3 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
return json.Marshal(atl)
}
type WebauthnCredentialUpdateDto struct {
Name string `json:"name"`
}

View File

@@ -3,6 +3,7 @@ package service
import (
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
@@ -54,7 +55,7 @@ var defaultDbConfig = model.AppConfig{
},
}
func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
func (s *AppConfigService) UpdateApplicationConfiguration(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
var savedConfigVariables []model.AppConfigVariable
tx := s.db.Begin()

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"golang.org/x/crypto/bcrypt"
@@ -26,7 +27,7 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
}
}
func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) {
func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
@@ -37,7 +38,7 @@ func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (stri
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
}
func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) {
func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: req.ClientID,
@@ -101,18 +102,18 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
return idToken, accessToken, nil
}
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return nil, err
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
return model.OidcClient{}, err
}
return &client, nil
return client, nil
}
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient
query := s.db.Model(&model.OidcClient{})
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern)
@@ -126,34 +127,34 @@ func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]
return clients, pagination, nil
}
func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) {
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{
Name: input.Name,
CallbackURL: input.CallbackURL,
CreatedByID: userID,
Name: input.Name,
CallbackURLs: input.CallbackURLs,
CreatedByID: userID,
}
if err := s.db.Create(&client).Error; err != nil {
return nil, err
return model.OidcClient{}, err
}
return &client, nil
return client, nil
}
func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) {
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return nil, err
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
return model.OidcClient{}, err
}
client.Name = input.Name
client.CallbackURL = input.CallbackURL
client.CallbackURLs = input.CallbackURLs
if err := s.db.Save(&client).Error; err != nil {
return nil, err
return model.OidcClient{}, err
}
return &client, nil
return client, nil
}
func (s *OidcService) DeleteClient(clientID string) error {

View File

@@ -61,20 +61,20 @@ func (s *TestService) SeedDatabase() error {
Base: model.Base{
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
},
Name: "Nextcloud",
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
CallbackURL: "http://nextcloud/auth/callback",
ImageType: utils.StringPointer("png"),
CreatedByID: users[0].ID,
Name: "Nextcloud",
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"},
ImageType: utils.StringPointer("png"),
CreatedByID: users[0].ID,
},
{
Base: model.Base{
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
},
Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURL: "http://immich/auth/callback",
CreatedByID: users[0].ID,
Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
CreatedByID: users[0].ID,
},
}
for _, client := range oidcClients {

View File

@@ -3,6 +3,7 @@ package service
import (
"errors"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/utils"
"gorm.io/gorm"
@@ -46,17 +47,24 @@ func (s *UserService) DeleteUser(userID string) error {
return s.db.Delete(&user).Error
}
func (s *UserService) CreateUser(user *model.User) error {
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
user := &model.User{
FirstName: input.FirstName,
LastName: input.LastName,
Email: input.Email,
Username: input.Username,
IsAdmin: input.IsAdmin,
}
if err := s.db.Create(user).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return s.checkDuplicatedFields(*user)
return model.User{}, s.checkDuplicatedFields(*user)
}
return err
return model.User{}, err
}
return nil
return *user, nil
}
func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) {
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool) (model.User, error) {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
return model.User{}, err

View File

@@ -67,10 +67,10 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil
}
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) {
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err
return model.WebauthnCredential{}, err
}
session := webauthn.SessionData{
@@ -81,12 +81,12 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err
return model.WebauthnCredential{}, err
}
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
if err != nil {
return nil, err
return model.WebauthnCredential{}, err
}
credentialToStore := model.WebauthnCredential{
@@ -100,10 +100,10 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupState: credential.Flags.BackupState,
}
if err := s.db.Create(&credentialToStore).Error; err != nil {
return nil, err
return model.WebauthnCredential{}, err
}
return &credentialToStore, nil
return credentialToStore, nil
}
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
@@ -129,10 +129,10 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil
}
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) {
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (model.User, error) {
var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
return nil, err
return model.User{}, err
}
session := webauthn.SessionData{
@@ -149,14 +149,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert
}, session, credentialAssertionData)
if err != nil {
return nil, common.ErrInvalidCredentials
return model.User{}, err
}
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
return nil, err
return model.User{}, err
}
return user, nil
return *user, nil
}
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
@@ -180,17 +180,17 @@ func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
return nil
}
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error {
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
return err
return credential, err
}
credential.Name = name
if err := s.db.Save(&credential).Error; err != nil {
return err
return credential, err
}
return nil
return credential, nil
}

View File

@@ -0,0 +1,75 @@
package utils
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"gorm.io/gorm"
"log"
"net/http"
"strings"
)
import (
"fmt"
)
func ControllerError(c *gin.Context, err error) {
// Check for record not found errors
if errors.Is(err, gorm.ErrRecordNotFound) {
CustomControllerError(c, http.StatusNotFound, "Record not found")
return
}
// Check for validation errors
var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
message := handleValidationError(validationErrors)
CustomControllerError(c, http.StatusBadRequest, message)
return
}
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
}
func handleValidationError(validationErrors validator.ValidationErrors) string {
var errorMessages []string
for _, ve := range validationErrors {
fieldName := ve.Field()
var errorMessage string
switch ve.Tag() {
case "required":
errorMessage = fmt.Sprintf("%s is required", fieldName)
case "email":
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
case "username":
errorMessage = fmt.Sprintf("%s must contain only lowercase letters, numbers, and underscores", fieldName)
case "url":
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
case "min":
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
case "max":
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
case "urlList":
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
default:
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
}
errorMessages = append(errorMessages, errorMessage)
}
// Join all the error messages into a single string
combinedErrors := strings.Join(errorMessages, ", ")
return combinedErrors
}
func CustomControllerError(c *gin.Context, statusCode int, message string) {
// Capitalize the first letter of the message
message = strings.ToUpper(message[:1]) + message[1:]
c.JSON(statusCode, gin.H{"error": message})
}

View File

@@ -1,27 +0,0 @@
package utils
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"log"
"net/http"
"strings"
)
func UnknownHandlerError(c *gin.Context, err error) {
if errors.Is(err, gorm.ErrRecordNotFound) {
HandlerError(c, http.StatusNotFound, "Record not found")
return
} else {
log.Println(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
}
}
func HandlerError(c *gin.Context, statusCode int, message string) {
// Capitalize the first letter of the message
message = strings.ToUpper(message[:1]) + message[1:]
c.JSON(statusCode, gin.H{"error": message})
}