feat: add support for multiple callback urls

This commit is contained in:
Elias Schneider
2024-08-24 00:49:08 +02:00
parent ae7aeb0945
commit 8166e2ead7
20 changed files with 287 additions and 101 deletions

View File

@@ -6,7 +6,6 @@ var (
ErrUsernameTaken = errors.New("username is already taken")
ErrEmailTaken = errors.New("email is already taken")
ErrSetupAlreadyCompleted = errors.New("setup already completed")
ErrInvalidBody = errors.New("invalid request body")
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
ErrOidcMissingAuthorization = errors.New("missing authorization")
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")

View File

@@ -40,39 +40,55 @@ type OidcController struct {
}
func (oc *OidcController) authorizeHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientDto
var input dto.AuthorizeOidcClientRequestDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
code, err := oc.oidcService.Authorize(input, c.GetString("userID"))
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"))
if err != nil {
if errors.Is(err, common.ErrOidcMissingAuthorization) {
utils.CustomControllerError(c, http.StatusForbidden, err.Error())
} else if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}
c.JSON(http.StatusOK, gin.H{"code": code})
response := dto.AuthorizeOidcClientResponseDto{
Code: code,
CallbackURL: callbackURL,
}
c.JSON(http.StatusOK, response)
}
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientDto
var input dto.AuthorizeOidcClientRequestDto
if err := c.ShouldBindJSON(&input); err != nil {
utils.ControllerError(c, err)
return
}
code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
if err != nil {
utils.ControllerError(c, err)
if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
} else {
utils.ControllerError(c, err)
}
return
}
c.JSON(http.StatusOK, gin.H{"code": code})
response := dto.AuthorizeOidcClientResponseDto{
Code: code,
CallbackURL: callbackURL,
}
c.JSON(http.StatusOK, response)
}
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {

View File

@@ -17,10 +17,16 @@ type OidcClientCreateDto struct {
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
}
type AuthorizeOidcClientDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
Nonce string `json:"nonce"`
type AuthorizeOidcClientRequestDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
}
type AuthorizeOidcClientResponseDto struct {
Code string `json:"code"`
CallbackURL string `json:"callbackURL"`
}
type OidcIdTokenDto struct {

View File

@@ -52,17 +52,14 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
type CallbackURLs []string
func (s *CallbackURLs) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, s)
case string:
return json.Unmarshal([]byte(v), s)
default:
return errors.New("type assertion to []byte or string failed")
func (cu *CallbackURLs) Scan(value interface{}) error {
if v, ok := value.([]byte); ok {
return json.Unmarshal(v, cu)
} else {
return errors.New("type assertion to []byte failed")
}
}
func (atl CallbackURLs) Value() (driver.Value, error) {
return json.Marshal(atl)
func (cu CallbackURLs) Value() (driver.Value, error) {
return json.Marshal(cu)
}

View File

@@ -11,6 +11,7 @@ import (
"gorm.io/gorm"
"mime/multipart"
"os"
"slices"
"strings"
"time"
)
@@ -27,33 +28,50 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
}
}
func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) {
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID)
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Scope != req.Scope {
return "", common.ErrOidcMissingAuthorization
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", common.ErrOidcMissingAuthorization
}
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
return code, callbackURL, err
}
func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) {
func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", input.ClientID).Error; err != nil {
return "", "", err
}
callbackURL, err := getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: req.ClientID,
Scope: req.Scope,
ClientID: input.ClientID,
Scope: input.Scope,
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
err = s.db.Model(&userAuthorizedClient).Update("scope", req.Scope).Error
err = s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error
} else {
return "", err
return "", "", err
}
}
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
return code, callbackURL, err
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
@@ -321,3 +339,14 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}
func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}
if slices.Contains(client.CallbackURLs, inputCallbackURL) {
return inputCallbackURL, nil
}
return "", common.ErrOidcInvalidCallbackURL
}

View File

@@ -58,7 +58,7 @@ func handleValidationError(validationErrors validator.ValidationErrors) string {
default:
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
}
errorMessages = append(errorMessages, errorMessage)
}