mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-14 15:22:18 +00:00
refactor: use dtos in controllers
This commit is contained in:
@@ -27,12 +27,6 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
|||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(gin.Logger())
|
r.Use(gin.Logger())
|
||||||
|
|
||||||
// Add middleware
|
|
||||||
r.Use(
|
|
||||||
middleware.NewCorsMiddleware().Add(),
|
|
||||||
middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Initialize services
|
// Initialize services
|
||||||
webauthnService := service.NewWebAuthnService(db, appConfigService)
|
webauthnService := service.NewWebAuthnService(db, appConfigService)
|
||||||
jwtService := service.NewJwtService(appConfigService)
|
jwtService := service.NewJwtService(appConfigService)
|
||||||
@@ -40,8 +34,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
|||||||
oidcService := service.NewOidcService(db, jwtService)
|
oidcService := service.NewOidcService(db, jwtService)
|
||||||
testService := service.NewTestService(db, appConfigService)
|
testService := service.NewTestService(db, appConfigService)
|
||||||
|
|
||||||
|
// Add global middleware
|
||||||
|
r.Use(middleware.NewCorsMiddleware().Add())
|
||||||
|
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
|
||||||
|
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
|
||||||
|
|
||||||
// Initialize middleware
|
// Initialize middleware
|
||||||
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService)
|
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
|
||||||
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
|
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
|
||||||
|
|
||||||
// Set up API routes
|
// Set up API routes
|
||||||
@@ -49,7 +48,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
|||||||
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
|
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
|
||||||
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
|
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
|
||||||
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
||||||
controller.NewApplicationConfigurationController(apiGroup, jwtAuthMiddleware, appConfigService)
|
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
|
||||||
|
|
||||||
// Add test controller in non-production environments
|
// Add test controller in non-production environments
|
||||||
if common.EnvConfig.AppEnv != "production" {
|
if common.EnvConfig.AppEnv != "production" {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ var (
|
|||||||
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
||||||
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
||||||
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
||||||
|
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
|
||||||
ErrFileTypeNotSupported = errors.New("file type not supported")
|
ErrFileTypeNotSupported = errors.New("file type not supported")
|
||||||
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,19 +5,19 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
|
||||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewApplicationConfigurationController(
|
func NewAppConfigController(
|
||||||
group *gin.RouterGroup,
|
group *gin.RouterGroup,
|
||||||
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
|
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
|
||||||
appConfigService *service.AppConfigService) {
|
appConfigService *service.AppConfigService) {
|
||||||
|
|
||||||
acc := &ApplicationConfigurationController{
|
acc := &AppConfigController{
|
||||||
appConfigService: appConfigService,
|
appConfigService: appConfigService,
|
||||||
}
|
}
|
||||||
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
|
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
|
||||||
@@ -32,86 +32,104 @@ func NewApplicationConfigurationController(
|
|||||||
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
|
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApplicationConfigurationController struct {
|
type AppConfigController struct {
|
||||||
appConfigService *service.AppConfigService
|
appConfigService *service.AppConfigService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) listApplicationConfigurationHandler(c *gin.Context) {
|
func (acc *AppConfigController) listApplicationConfigurationHandler(c *gin.Context) {
|
||||||
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
|
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, configuration)
|
var configVariablesDto []dto.PublicAppConfigVariableDto
|
||||||
|
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, configVariablesDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) listAllApplicationConfigurationHandler(c *gin.Context) {
|
func (acc *AppConfigController) listAllApplicationConfigurationHandler(c *gin.Context) {
|
||||||
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
|
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, configuration)
|
var configVariablesDto []dto.AppConfigVariableDto
|
||||||
|
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, configVariablesDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) updateApplicationConfigurationHandler(c *gin.Context) {
|
func (acc *AppConfigController) updateApplicationConfigurationHandler(c *gin.Context) {
|
||||||
var input model.AppConfigUpdateDto
|
var input dto.AppConfigUpdateDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
|
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, savedConfigVariables)
|
var configVariablesDto []dto.AppConfigVariableDto
|
||||||
|
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, configVariablesDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) getLogoHandler(c *gin.Context) {
|
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
||||||
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||||
acc.getImage(c, "logo", imageType)
|
acc.getImage(c, "logo", imageType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) getFaviconHandler(c *gin.Context) {
|
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
|
||||||
acc.getImage(c, "favicon", "ico")
|
acc.getImage(c, "favicon", "ico")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) getBackgroundImageHandler(c *gin.Context) {
|
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
||||||
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||||
acc.getImage(c, "background", imageType)
|
acc.getImage(c, "background", imageType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) updateLogoHandler(c *gin.Context) {
|
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||||
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||||
acc.updateImage(c, "logo", imageType)
|
acc.updateImage(c, "logo", imageType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) updateFaviconHandler(c *gin.Context) {
|
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fileType := utils.GetFileExtension(file.Filename)
|
fileType := utils.GetFileExtension(file.Filename)
|
||||||
if fileType != "ico" {
|
if fileType != "ico" {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
|
utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc.updateImage(c, "favicon", "ico")
|
acc.updateImage(c, "favicon", "ico")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) updateBackgroundImageHandler(c *gin.Context) {
|
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
|
||||||
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||||
acc.updateImage(c, "background", imageType)
|
acc.updateImage(c, "background", imageType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name string, imageType string) {
|
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
|
||||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
||||||
mimeType := utils.GetImageMimeType(imageType)
|
mimeType := utils.GetImageMimeType(imageType)
|
||||||
|
|
||||||
@@ -119,19 +137,19 @@ func (acc *ApplicationConfigurationController) getImage(c *gin.Context, name str
|
|||||||
c.File(imagePath)
|
c.File(imagePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acc *ApplicationConfigurationController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
|
||||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -40,18 +40,18 @@ type OidcController struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||||
var parsedBody model.AuthorizeRequest
|
var input dto.AuthorizeOidcClientDto
|
||||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID"))
|
code, err := oc.oidcService.Authorize(input, c.GetString("userID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
||||||
utils.HandlerError(c, http.StatusForbidden, err.Error())
|
utils.CustomControllerError(c, http.StatusForbidden, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -60,15 +60,15 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||||
var parsedBody model.AuthorizeNewClientDto
|
var input dto.AuthorizeOidcClientDto
|
||||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
|
code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,35 +76,35 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
||||||
var body model.OidcIdTokenDto
|
var input dto.OidcIdTokenDto
|
||||||
|
|
||||||
if err := c.ShouldBind(&body); err != nil {
|
if err := c.ShouldBind(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID := body.ClientID
|
clientID := input.ClientID
|
||||||
clientSecret := body.ClientSecret
|
clientSecret := input.ClientSecret
|
||||||
|
|
||||||
// Client id and secret can also be passed over the Authorization header
|
// Client id and secret can also be passed over the Authorization header
|
||||||
if clientID == "" || clientSecret == "" {
|
if clientID == "" || clientSecret == "" {
|
||||||
var ok bool
|
var ok bool
|
||||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
clientID, clientSecret, ok = c.Request.BasicAuth()
|
||||||
if !ok {
|
if !ok {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret)
|
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
||||||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
||||||
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
|
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
|
||||||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
|
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -116,14 +116,14 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
|||||||
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
|
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
|
||||||
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
|
utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := jwtClaims.Subject
|
userID := jwtClaims.Subject
|
||||||
clientId := jwtClaims.Audience[0]
|
clientId := jwtClaims.Audience[0]
|
||||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,11 +134,28 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
|||||||
clientId := c.Param("id")
|
clientId := c.Param("id")
|
||||||
client, err := oc.oidcService.GetClient(clientId)
|
client, err := oc.oidcService.GetClient(clientId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, client)
|
// Return a different DTO based on the user's role
|
||||||
|
if c.GetBool("userIsAdmin") {
|
||||||
|
clientDto := dto.OidcClientDto{}
|
||||||
|
err = dto.MapStruct(client, &clientDto)
|
||||||
|
if err == nil {
|
||||||
|
c.JSON(http.StatusOK, clientDto)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
clientDto := dto.PublicOidcClientDto{}
|
||||||
|
err = dto.MapStruct(client, &clientDto)
|
||||||
|
if err == nil {
|
||||||
|
c.JSON(http.StatusOK, clientDto)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||||
@@ -148,36 +165,48 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
|||||||
|
|
||||||
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
|
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientsDto []dto.OidcClientDto
|
||||||
|
if err := dto.MapStructList(clients, &clientsDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": clients,
|
"data": clientsDto,
|
||||||
"pagination": pagination,
|
"pagination": pagination,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||||
var input model.OidcClientCreateDto
|
var input dto.OidcClientCreateDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, client)
|
var clientDto dto.OidcClientDto
|
||||||
|
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, clientDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||||
err := oc.oidcService.DeleteClient(c.Param("id"))
|
err := oc.oidcService.DeleteClient(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,25 +214,31 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||||
var input model.OidcClientCreateDto
|
var input dto.OidcClientCreateDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusNoContent, client)
|
var clientDto dto.OidcClientDto
|
||||||
|
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, clientDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||||
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +248,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
|||||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,16 +259,16 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
|||||||
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -244,7 +279,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
|||||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||||
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
|
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
|
||||||
@@ -18,19 +19,19 @@ type TestController struct {
|
|||||||
|
|
||||||
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||||
if err := tc.TestService.ResetDatabase(); err != nil {
|
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tc.TestService.SeedDatabase(); err != nil {
|
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{"message": "Database reset and seeded"})
|
c.Status(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
|
||||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
@@ -43,12 +43,18 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
|||||||
|
|
||||||
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
|
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var usersDto []dto.UserDto
|
||||||
|
if err := dto.MapStructList(users, &usersDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": users,
|
"data": usersDto,
|
||||||
"pagination": pagination,
|
"pagination": pagination,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -56,25 +62,38 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
|
|||||||
func (uc *UserController) getUserHandler(c *gin.Context) {
|
func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||||
user, err := uc.UserService.GetUser(c.Param("id"))
|
user, err := uc.UserService.GetUser(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, user)
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, userDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||||
user, err := uc.UserService.GetUser(c.GetString("userID"))
|
user, err := uc.UserService.GetUser(c.GetString("userID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, user)
|
|
||||||
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, userDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||||
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
|
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,22 +101,29 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) createUserHandler(c *gin.Context) {
|
func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||||
var user model.User
|
var input dto.UserCreateDto
|
||||||
if err := c.ShouldBindJSON(&user); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := uc.UserService.CreateUser(&user); err != nil {
|
user, err := uc.UserService.CreateUser(input)
|
||||||
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||||
utils.HandlerError(c, http.StatusConflict, err.Error())
|
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, user)
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, userDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) updateUserHandler(c *gin.Context) {
|
func (uc *UserController) updateUserHandler(c *gin.Context) {
|
||||||
@@ -109,15 +135,15 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||||
var input model.OneTimeAccessTokenCreateDto
|
var input dto.OneTimeAccessTokenCreateDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,9 +154,9 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
|||||||
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
|
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
|
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
|
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -143,21 +169,27 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
|||||||
user, token, err := uc.UserService.SetupInitialAdmin()
|
user, token, err := uc.UserService.SetupInitialAdmin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
|
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||||
c.JSON(http.StatusOK, user)
|
c.JSON(http.StatusOK, userDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||||
var updatedUser model.User
|
var input dto.UserCreateDto
|
||||||
if err := c.ShouldBindJSON(&updatedUser); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,15 +200,21 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
|||||||
userID = c.Param("id")
|
userID = c.Param("id")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := uc.UserService.UpdateUser(userID, updatedUser, updateOwnUser)
|
user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||||
utils.HandlerError(c, http.StatusConflict, err.Error())
|
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, user)
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, userDto)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -40,8 +39,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
|||||||
userID := c.GetString("userID")
|
userID := c.GetString("userID")
|
||||||
options, err := wc.webAuthnService.BeginRegistration(userID)
|
options, err := wc.webAuthnService.BeginRegistration(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
log.Println(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,24 +50,30 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
|||||||
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||||
sessionID, err := c.Cookie("session_id")
|
sessionID, err := c.Cookie("session_id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := c.GetString("userID")
|
userID := c.GetString("userID")
|
||||||
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, credential)
|
var credentialDto dto.WebauthnCredentialDto
|
||||||
|
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, credentialDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||||
options, err := wc.webAuthnService.BeginLogin()
|
options, err := wc.webAuthnService.BeginLogin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,13 +84,13 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
|||||||
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||||
sessionID, err := c.Cookie("session_id")
|
sessionID, err := c.Cookie("session_id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,32 +98,44 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
|||||||
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
|
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, common.ErrInvalidCredentials) {
|
if errors.Is(err, common.ErrInvalidCredentials) {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, err.Error())
|
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := wc.jwtService.GenerateAccessToken(*user)
|
token, err := wc.jwtService.GenerateAccessToken(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var userDto dto.UserDto
|
||||||
|
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||||
c.JSON(http.StatusOK, user)
|
c.JSON(http.StatusOK, userDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
||||||
userID := c.GetString("userID")
|
userID := c.GetString("userID")
|
||||||
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, credentials)
|
var credentialDtos []dto.WebauthnCredentialDto
|
||||||
|
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, credentialDtos)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
||||||
@@ -128,7 +144,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,19 +155,25 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
|
|||||||
userID := c.GetString("userID")
|
userID := c.GetString("userID")
|
||||||
credentialID := c.Param("id")
|
credentialID := c.Param("id")
|
||||||
|
|
||||||
var input model.WebauthnCredentialUpdateDto
|
var input dto.WebauthnCredentialUpdateDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
utils.HandlerError(c, http.StatusBadRequest, common.ErrInvalidBody.Error())
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
var credentialDto dto.WebauthnCredentialDto
|
||||||
|
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||||
|
utils.ControllerError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, credentialDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *WebauthnController) logoutHandler(c *gin.Context) {
|
func (wc *WebauthnController) logoutHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type WellKnownController struct {
|
|||||||
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
||||||
jwk, err := wkc.jwtService.GetJWK()
|
jwk, err := wkc.jwtService.GetJWK()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.UnknownHandlerError(c, err)
|
utils.ControllerError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
17
backend/internal/dto/app_config_dto.go
Normal file
17
backend/internal/dto/app_config_dto.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type PublicAppConfigVariableDto struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppConfigVariableDto struct {
|
||||||
|
PublicAppConfigVariableDto
|
||||||
|
IsPublic bool `json:"isPublic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppConfigUpdateDto struct {
|
||||||
|
AppName string `json:"appName" binding:"required,min=1,max=30"`
|
||||||
|
SessionDuration string `json:"sessionDuration" binding:"required"`
|
||||||
|
}
|
||||||
79
backend/internal/dto/dto_mapper.go
Normal file
79
backend/internal/dto/dto_mapper.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MapStructList maps a list of source structs to a list of destination structs
|
||||||
|
func MapStructList[S any, D any](source []S, destination *[]D) error {
|
||||||
|
for _, item := range source {
|
||||||
|
var destItem D
|
||||||
|
if err := MapStruct(item, &destItem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*destination = append(*destination, destItem)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapStruct maps a source struct to a destination struct
|
||||||
|
func MapStruct[S any, D any](source S, destination *D) error {
|
||||||
|
// Ensure destination is a non-nil pointer
|
||||||
|
destValue := reflect.ValueOf(destination)
|
||||||
|
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
|
||||||
|
return errors.New("destination must be a non-nil pointer to a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure source is a struct
|
||||||
|
sourceValue := reflect.ValueOf(source)
|
||||||
|
if sourceValue.Kind() != reflect.Struct {
|
||||||
|
return errors.New("source must be a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapStructInternal(sourceValue, destValue.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
|
||||||
|
// Loop through the fields of the destination struct
|
||||||
|
for i := 0; i < destVal.NumField(); i++ {
|
||||||
|
destField := destVal.Field(i)
|
||||||
|
destFieldType := destVal.Type().Field(i)
|
||||||
|
|
||||||
|
if destFieldType.Anonymous {
|
||||||
|
// Recursively handle embedded structs
|
||||||
|
if err := mapStructInternal(sourceVal, destField); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceField := sourceVal.FieldByName(destFieldType.Name)
|
||||||
|
|
||||||
|
// If the source field is valid and can be assigned to the destination field
|
||||||
|
if sourceField.IsValid() && destField.CanSet() {
|
||||||
|
// Handle direct assignment for simple types
|
||||||
|
if sourceField.Type() == destField.Type() {
|
||||||
|
destField.Set(sourceField)
|
||||||
|
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
|
||||||
|
// Handle slices
|
||||||
|
if sourceField.Type().Elem() == destField.Type().Elem() {
|
||||||
|
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
|
||||||
|
|
||||||
|
for j := 0; j < sourceField.Len(); j++ {
|
||||||
|
newSlice.Index(j).Set(sourceField.Index(j))
|
||||||
|
}
|
||||||
|
|
||||||
|
destField.Set(newSlice)
|
||||||
|
}
|
||||||
|
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
|
||||||
|
// Recursively map nested structs
|
||||||
|
if err := mapStructInternal(sourceField, destField); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
31
backend/internal/dto/oidc_dto.go
Normal file
31
backend/internal/dto/oidc_dto.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type PublicOidcClientDto struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcClientDto struct {
|
||||||
|
PublicOidcClientDto
|
||||||
|
HasLogo bool `json:"hasLogo"`
|
||||||
|
CallbackURLs []string `json:"callbackURLs"`
|
||||||
|
CreatedBy UserDto `json:"createdBy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcClientCreateDto struct {
|
||||||
|
Name string `json:"name" binding:"required,max=50"`
|
||||||
|
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthorizeOidcClientDto struct {
|
||||||
|
ClientID string `json:"clientID" binding:"required"`
|
||||||
|
Scope string `json:"scope" binding:"required"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcIdTokenDto struct {
|
||||||
|
GrantType string `form:"grant_type" binding:"required"`
|
||||||
|
Code string `form:"code" binding:"required"`
|
||||||
|
ClientID string `form:"client_id"`
|
||||||
|
ClientSecret string `form:"client_secret"`
|
||||||
|
}
|
||||||
30
backend/internal/dto/user_dto.go
Normal file
30
backend/internal/dto/user_dto.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type UserDto struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email" `
|
||||||
|
FirstName string `json:"firstName"`
|
||||||
|
LastName string `json:"lastName"`
|
||||||
|
IsAdmin bool `json:"isAdmin"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserCreateDto struct {
|
||||||
|
Username string `json:"username" binding:"required,username,min=3,max=20"`
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
FirstName string `json:"firstName" binding:"required,min=3,max=30"`
|
||||||
|
LastName string `json:"lastName" binding:"required,min=3,max=30"`
|
||||||
|
IsAdmin bool `json:"isAdmin"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OneTimeAccessTokenCreateDto struct {
|
||||||
|
UserID string `json:"userId" binding:"required"`
|
||||||
|
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginUserDto struct {
|
||||||
|
Username string `json:"username" binding:"required"`
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
39
backend/internal/dto/validations.go
Normal file
39
backend/internal/dto/validations.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"log"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
|
||||||
|
urls := fl.Field().Interface().([]string)
|
||||||
|
for _, u := range urls {
|
||||||
|
_, err := url.ParseRequestURI(u)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||||
|
regex := "^[a-z0-9_]*$"
|
||||||
|
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||||
|
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
|
||||||
|
log.Fatalf("Failed to register custom validation: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||||
|
if err := v.RegisterValidation("username", validateUsername); err != nil {
|
||||||
|
log.Fatalf("Failed to register custom validation: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
23
backend/internal/dto/webauthn_dto.go
Normal file
23
backend/internal/dto/webauthn_dto.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebauthnCredentialDto struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
CredentialID string `json:"credentialID"`
|
||||||
|
AttestationType string `json:"attestationType"`
|
||||||
|
Transport []protocol.AuthenticatorTransport `json:"transport"`
|
||||||
|
|
||||||
|
BackupEligible bool `json:"backupEligible"`
|
||||||
|
BackupState bool `json:"backupState"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebauthnCredentialUpdateDto struct {
|
||||||
|
Name string `json:"name" binding:"required,min=1,max=30"`
|
||||||
|
}
|
||||||
@@ -17,7 +17,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
||||||
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
||||||
utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type JwtAuthMiddleware struct {
|
type JwtAuthMiddleware struct {
|
||||||
jwtService *service.JwtService
|
jwtService *service.JwtService
|
||||||
|
ignoreUnauthenticated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
|
func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware {
|
||||||
return &JwtAuthMiddleware{jwtService: jwtService}
|
return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
||||||
@@ -24,23 +25,29 @@ func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
|||||||
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
|
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
|
||||||
if len(authorizationHeaderSplitted) == 2 {
|
if len(authorizationHeaderSplitted) == 2 {
|
||||||
token = authorizationHeaderSplitted[1]
|
token = authorizationHeaderSplitted[1]
|
||||||
|
} else if m.ignoreUnauthenticated {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
} else {
|
} else {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := m.jwtService.VerifyAccessToken(token)
|
claims, err := m.jwtService.VerifyAccessToken(token)
|
||||||
if err != nil {
|
if err != nil && m.ignoreUnauthenticated {
|
||||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
c.Next()
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the user is an admin
|
// Check if the user is an admin
|
||||||
if adminOnly && !claims.IsAdmin {
|
if adminOnly && !claims.IsAdmin {
|
||||||
utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
|
|||||||
|
|
||||||
limiter := getLimiter(ip, limit, burst)
|
limiter := getLimiter(ip, limit, burst)
|
||||||
if !limiter.Allow() {
|
if !limiter.Allow() {
|
||||||
utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
type AppConfigVariable struct {
|
type AppConfigVariable struct {
|
||||||
Key string `gorm:"primaryKey;not null" json:"key"`
|
Key string `gorm:"primaryKey;not null"`
|
||||||
Type string `json:"type"`
|
Type string
|
||||||
IsPublic bool `json:"-"`
|
IsPublic bool
|
||||||
IsInternal bool `json:"-"`
|
IsInternal bool
|
||||||
Value string `json:"value"`
|
Value string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
@@ -14,8 +14,3 @@ type AppConfig struct {
|
|||||||
LogoImageType AppConfigVariable
|
LogoImageType AppConfigVariable
|
||||||
SessionDuration AppConfigVariable
|
SessionDuration AppConfigVariable
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppConfigUpdateDto struct {
|
|
||||||
AppName string `json:"appName" binding:"required"`
|
|
||||||
SessionDuration string `json:"sessionDuration" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
|
|
||||||
// Base contains common columns for all tables.
|
// Base contains common columns for all tables.
|
||||||
type Base struct {
|
type Base struct {
|
||||||
ID string `gorm:"primaryKey;not null" json:"id"`
|
ID string `gorm:"primaryKey;not null"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
|
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
|
||||||
|
|||||||
@@ -1,38 +1,22 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserAuthorizedOidcClient struct {
|
type UserAuthorizedOidcClient struct {
|
||||||
Scope string
|
Scope string
|
||||||
UserID string `json:"userId" gorm:"primary_key;"`
|
UserID string `gorm:"primary_key;"`
|
||||||
User User
|
User User
|
||||||
|
|
||||||
ClientID string `json:"clientId" gorm:"primary_key;"`
|
ClientID string `gorm:"primary_key;"`
|
||||||
Client OidcClient
|
Client OidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcClient struct {
|
|
||||||
Base
|
|
||||||
|
|
||||||
Name string `json:"name"`
|
|
||||||
Secret string `json:"-"`
|
|
||||||
CallbackURL string `json:"callbackURL"`
|
|
||||||
ImageType *string `json:"-"`
|
|
||||||
HasLogo bool `gorm:"-" json:"hasLogo"`
|
|
||||||
|
|
||||||
CreatedByID string
|
|
||||||
CreatedBy User
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
|
||||||
// Compute HasLogo field
|
|
||||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type OidcAuthorizationCode struct {
|
type OidcAuthorizationCode struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
@@ -47,26 +31,38 @@ type OidcAuthorizationCode struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcClientCreateDto struct {
|
type OidcClient struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Base
|
||||||
CallbackURL string `json:"callbackURL" binding:"required"`
|
|
||||||
|
Name string
|
||||||
|
Secret string
|
||||||
|
CallbackURLs CallbackURLs
|
||||||
|
ImageType *string
|
||||||
|
HasLogo bool `gorm:"-"`
|
||||||
|
|
||||||
|
CreatedByID string
|
||||||
|
CreatedBy User
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeNewClientDto struct {
|
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||||
ClientID string `json:"clientID" binding:"required"`
|
// Compute HasLogo field
|
||||||
Scope string `json:"scope" binding:"required"`
|
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
||||||
Nonce string `json:"nonce"`
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcIdTokenDto struct {
|
type CallbackURLs []string
|
||||||
GrantType string `form:"grant_type" binding:"required"`
|
|
||||||
Code string `form:"code" binding:"required"`
|
func (s *CallbackURLs) Scan(value interface{}) error {
|
||||||
ClientID string `form:"client_id"`
|
switch v := value.(type) {
|
||||||
ClientSecret string `form:"client_secret"`
|
case []byte:
|
||||||
|
return json.Unmarshal(v, s)
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(v), s)
|
||||||
|
default:
|
||||||
|
return errors.New("type assertion to []byte or string failed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
func (atl CallbackURLs) Value() (driver.Value, error) {
|
||||||
ClientID string `json:"clientID" binding:"required"`
|
return json.Marshal(atl)
|
||||||
Scope string `json:"scope" binding:"required"`
|
|
||||||
Nonce string `json:"nonce"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import (
|
|||||||
type User struct {
|
type User struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
Username string `json:"username"`
|
Username string
|
||||||
Email string `json:"email" `
|
Email string
|
||||||
FirstName string `json:"firstName"`
|
FirstName string
|
||||||
LastName string `json:"lastName"`
|
LastName string
|
||||||
IsAdmin bool `json:"isAdmin"`
|
IsAdmin bool
|
||||||
|
|
||||||
Credentials []WebauthnCredential `json:"-"`
|
Credentials []WebauthnCredential
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u User) WebAuthnID() []byte { return []byte(u.ID) }
|
func (u User) WebAuthnID() []byte { return []byte(u.ID) }
|
||||||
@@ -59,19 +59,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
|
|||||||
|
|
||||||
type OneTimeAccessToken struct {
|
type OneTimeAccessToken struct {
|
||||||
Base
|
Base
|
||||||
Token string `json:"token"`
|
Token string
|
||||||
ExpiresAt time.Time `json:"expiresAt"`
|
ExpiresAt time.Time
|
||||||
|
|
||||||
UserID string `json:"userId"`
|
UserID string
|
||||||
User User
|
User User
|
||||||
}
|
}
|
||||||
|
|
||||||
type OneTimeAccessTokenCreateDto struct {
|
|
||||||
UserID string `json:"userId" binding:"required"`
|
|
||||||
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoginUserDto struct {
|
|
||||||
Username string `json:"username" binding:"required"`
|
|
||||||
Password string `json:"password" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ type WebauthnSession struct {
|
|||||||
type WebauthnCredential struct {
|
type WebauthnCredential struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
Name string `json:"name"`
|
Name string
|
||||||
CredentialID string `json:"credentialID"`
|
CredentialID string
|
||||||
PublicKey []byte `json:"-"`
|
PublicKey []byte
|
||||||
AttestationType string `json:"attestationType"`
|
AttestationType string
|
||||||
Transport AuthenticatorTransportList `json:"-"`
|
Transport AuthenticatorTransportList
|
||||||
|
|
||||||
BackupEligible bool `json:"backupEligible"`
|
BackupEligible bool `json:"backupEligible"`
|
||||||
BackupState bool `json:"backupState"`
|
BackupState bool `json:"backupState"`
|
||||||
@@ -32,15 +32,15 @@ type WebauthnCredential struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PublicKeyCredentialCreationOptions struct {
|
type PublicKeyCredentialCreationOptions struct {
|
||||||
Response protocol.PublicKeyCredentialCreationOptions `json:"response"`
|
Response protocol.PublicKeyCredentialCreationOptions
|
||||||
SessionID string `json:"session_id"`
|
SessionID string
|
||||||
Timeout time.Duration `json:"timeout"`
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type PublicKeyCredentialRequestOptions struct {
|
type PublicKeyCredentialRequestOptions struct {
|
||||||
Response protocol.PublicKeyCredentialRequestOptions `json:"response"`
|
Response protocol.PublicKeyCredentialRequestOptions
|
||||||
SessionID string `json:"session_id"`
|
SessionID string
|
||||||
Timeout time.Duration `json:"timeout"`
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthenticatorTransportList []protocol.AuthenticatorTransport
|
type AuthenticatorTransportList []protocol.AuthenticatorTransport
|
||||||
@@ -58,7 +58,3 @@ func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
|
|||||||
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
|
func (atl AuthenticatorTransportList) Value() (driver.Value, error) {
|
||||||
return json.Marshal(atl)
|
return json.Marshal(atl)
|
||||||
}
|
}
|
||||||
|
|
||||||
type WebauthnCredentialUpdateDto struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -54,7 +55,7 @@ var defaultDbConfig = model.AppConfig{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppConfigService) UpdateApplicationConfiguration(input model.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
func (s *AppConfigService) UpdateApplicationConfiguration(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
||||||
var savedConfigVariables []model.AppConfigVariable
|
var savedConfigVariables []model.AppConfigVariable
|
||||||
|
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -26,7 +27,7 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (string, error) {
|
func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
|
||||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||||
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
|
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ func (s *OidcService) Authorize(req model.AuthorizeRequest, userID string) (stri
|
|||||||
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
|
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID string) (string, error) {
|
func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
|
||||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
@@ -101,18 +102,18 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
|
|||||||
return idToken, accessToken, nil
|
return idToken, accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
|
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
return nil, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
return &client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
|
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||||
var clients []model.OidcClient
|
var clients []model.OidcClient
|
||||||
|
|
||||||
query := s.db.Model(&model.OidcClient{})
|
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
|
||||||
if searchTerm != "" {
|
if searchTerm != "" {
|
||||||
searchPattern := "%" + searchTerm + "%"
|
searchPattern := "%" + searchTerm + "%"
|
||||||
query = query.Where("name LIKE ?", searchPattern)
|
query = query.Where("name LIKE ?", searchPattern)
|
||||||
@@ -126,34 +127,34 @@ func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]
|
|||||||
return clients, pagination, nil
|
return clients, pagination, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateClient(input model.OidcClientCreateDto, userID string) (*model.OidcClient, error) {
|
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||||
client := model.OidcClient{
|
client := model.OidcClient{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
CallbackURL: input.CallbackURL,
|
CallbackURLs: input.CallbackURLs,
|
||||||
CreatedByID: userID,
|
CreatedByID: userID,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.db.Create(&client).Error; err != nil {
|
if err := s.db.Create(&client).Error; err != nil {
|
||||||
return nil, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) UpdateClient(clientID string, input model.OidcClientCreateDto) (*model.OidcClient, error) {
|
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
return nil, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client.Name = input.Name
|
client.Name = input.Name
|
||||||
client.CallbackURL = input.CallbackURL
|
client.CallbackURLs = input.CallbackURLs
|
||||||
|
|
||||||
if err := s.db.Save(&client).Error; err != nil {
|
if err := s.db.Save(&client).Error; err != nil {
|
||||||
return nil, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) DeleteClient(clientID string) error {
|
func (s *OidcService) DeleteClient(clientID string) error {
|
||||||
|
|||||||
@@ -61,20 +61,20 @@ func (s *TestService) SeedDatabase() error {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||||
},
|
},
|
||||||
Name: "Nextcloud",
|
Name: "Nextcloud",
|
||||||
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
||||||
CallbackURL: "http://nextcloud/auth/callback",
|
CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"},
|
||||||
ImageType: utils.StringPointer("png"),
|
ImageType: utils.StringPointer("png"),
|
||||||
CreatedByID: users[0].ID,
|
CreatedByID: users[0].ID,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
|
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
|
||||||
},
|
},
|
||||||
Name: "Immich",
|
Name: "Immich",
|
||||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||||
CallbackURL: "http://immich/auth/callback",
|
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
|
||||||
CreatedByID: users[0].ID,
|
CreatedByID: users[0].ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, client := range oidcClients {
|
for _, client := range oidcClients {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||||
|
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -46,17 +47,24 @@ func (s *UserService) DeleteUser(userID string) error {
|
|||||||
return s.db.Delete(&user).Error
|
return s.db.Delete(&user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) CreateUser(user *model.User) error {
|
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
|
||||||
|
user := &model.User{
|
||||||
|
FirstName: input.FirstName,
|
||||||
|
LastName: input.LastName,
|
||||||
|
Email: input.Email,
|
||||||
|
Username: input.Username,
|
||||||
|
IsAdmin: input.IsAdmin,
|
||||||
|
}
|
||||||
if err := s.db.Create(user).Error; err != nil {
|
if err := s.db.Create(user).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
return s.checkDuplicatedFields(*user)
|
return model.User{}, s.checkDuplicatedFields(*user)
|
||||||
}
|
}
|
||||||
return err
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
return nil
|
return *user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) UpdateUser(userID string, updatedUser model.User, updateOwnUser bool) (model.User, error) {
|
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool) (model.User, error) {
|
||||||
var user model.User
|
var user model.User
|
||||||
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
return model.User{}, err
|
return model.User{}, err
|
||||||
|
|||||||
@@ -67,10 +67,10 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (*model.WebauthnCredential, error) {
|
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||||
var storedSession model.WebauthnSession
|
var storedSession model.WebauthnSession
|
||||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||||
return nil, err
|
return model.WebauthnCredential{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
session := webauthn.SessionData{
|
session := webauthn.SessionData{
|
||||||
@@ -81,12 +81,12 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
|||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||||
return nil, err
|
return model.WebauthnCredential{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
|
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return model.WebauthnCredential{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
credentialToStore := model.WebauthnCredential{
|
credentialToStore := model.WebauthnCredential{
|
||||||
@@ -100,10 +100,10 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
|||||||
BackupState: credential.Flags.BackupState,
|
BackupState: credential.Flags.BackupState,
|
||||||
}
|
}
|
||||||
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
||||||
return nil, err
|
return model.WebauthnCredential{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &credentialToStore, nil
|
return credentialToStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
||||||
@@ -129,10 +129,10 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (*model.User, error) {
|
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (model.User, error) {
|
||||||
var storedSession model.WebauthnSession
|
var storedSession model.WebauthnSession
|
||||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||||
return nil, err
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
session := webauthn.SessionData{
|
session := webauthn.SessionData{
|
||||||
@@ -149,14 +149,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssert
|
|||||||
}, session, credentialAssertionData)
|
}, session, credentialAssertionData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, common.ErrInvalidCredentials
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||||
return nil, err
|
return model.User{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return *user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
||||||
@@ -180,17 +180,17 @@ func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) error {
|
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||||
var credential model.WebauthnCredential
|
var credential model.WebauthnCredential
|
||||||
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
||||||
return err
|
return credential, err
|
||||||
}
|
}
|
||||||
|
|
||||||
credential.Name = name
|
credential.Name = name
|
||||||
|
|
||||||
if err := s.db.Save(&credential).Error; err != nil {
|
if err := s.db.Save(&credential).Error; err != nil {
|
||||||
return err
|
return credential, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return credential, nil
|
||||||
}
|
}
|
||||||
|
|||||||
75
backend/internal/utils/controller_error_util.go
Normal file
75
backend/internal/utils/controller_error_util.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ControllerError(c *gin.Context, err error) {
|
||||||
|
// Check for record not found errors
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
CustomControllerError(c, http.StatusNotFound, "Record not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for validation errors
|
||||||
|
var validationErrors validator.ValidationErrors
|
||||||
|
if errors.As(err, &validationErrors) {
|
||||||
|
message := handleValidationError(validationErrors)
|
||||||
|
CustomControllerError(c, http.StatusBadRequest, message)
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleValidationError(validationErrors validator.ValidationErrors) string {
|
||||||
|
var errorMessages []string
|
||||||
|
|
||||||
|
for _, ve := range validationErrors {
|
||||||
|
fieldName := ve.Field()
|
||||||
|
var errorMessage string
|
||||||
|
switch ve.Tag() {
|
||||||
|
case "required":
|
||||||
|
errorMessage = fmt.Sprintf("%s is required", fieldName)
|
||||||
|
case "email":
|
||||||
|
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
|
||||||
|
case "username":
|
||||||
|
errorMessage = fmt.Sprintf("%s must contain only lowercase letters, numbers, and underscores", fieldName)
|
||||||
|
case "url":
|
||||||
|
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
|
||||||
|
case "min":
|
||||||
|
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
|
||||||
|
case "max":
|
||||||
|
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
|
||||||
|
case "urlList":
|
||||||
|
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
|
||||||
|
default:
|
||||||
|
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
errorMessages = append(errorMessages, errorMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all the error messages into a single string
|
||||||
|
combinedErrors := strings.Join(errorMessages, ", ")
|
||||||
|
|
||||||
|
return combinedErrors
|
||||||
|
}
|
||||||
|
|
||||||
|
func CustomControllerError(c *gin.Context, statusCode int, message string) {
|
||||||
|
// Capitalize the first letter of the message
|
||||||
|
message = strings.ToUpper(message[:1]) + message[1:]
|
||||||
|
c.JSON(statusCode, gin.H{"error": message})
|
||||||
|
}
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func UnknownHandlerError(c *gin.Context, err error) {
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
HandlerError(c, http.StatusNotFound, "Record not found")
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
log.Println(err)
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func HandlerError(c *gin.Context, statusCode int, message string) {
|
|
||||||
// Capitalize the first letter of the message
|
|
||||||
message = strings.ToUpper(message[:1]) + message[1:]
|
|
||||||
c.JSON(statusCode, gin.H{"error": message})
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user