feat: custom claims (#53)

This commit is contained in:
Elias Schneider
2024-10-28 18:11:54 +01:00
committed by GitHub
parent 3350398abc
commit c056089c60
43 changed files with 1071 additions and 281 deletions

View File

@@ -18,18 +18,20 @@ import (
)
type OidcService struct {
db *gorm.DB
jwtService *JwtService
appConfigService *AppConfigService
auditLogService *AuditLogService
db *gorm.DB
jwtService *JwtService
appConfigService *AppConfigService
auditLogService *AuditLogService
customClaimService *CustomClaimService
}
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService) *OidcService {
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
return &OidcService{
db: db,
jwtService: jwtService,
appConfigService: appConfigService,
auditLogService: auditLogService,
db: db,
jwtService: jwtService,
appConfigService: appConfigService,
auditLogService: auditLogService,
customClaimService: customClaimService,
}
}
@@ -38,7 +40,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", common.ErrOidcMissingAuthorization
return "", "", &common.OidcMissingAuthorizationError{}
}
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
@@ -93,11 +95,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", common.ErrOidcGrantTypeNotSupported
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
if clientID == "" || clientSecret == "" {
return "", "", common.ErrOidcMissingClientCredentials
return "", "", &common.OidcMissingClientCredentialsError{}
}
var client model.OidcClient
@@ -107,17 +109,17 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", common.ErrOidcClientSecretInvalid
return "", "", &common.OidcClientSecretInvalidError{}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", common.ErrOidcInvalidAuthorizationCode
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", common.ErrOidcInvalidAuthorizationCode
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
@@ -249,7 +251,7 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
return common.ErrFileTypeNotSupported
return &common.FileTypeNotSupportedError{}
}
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
@@ -334,9 +336,20 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
}
if strings.Contains(scope, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID)
if err != nil {
return nil, err
}
for _, customClaim := range customClaims {
claims[customClaim.Key] = customClaim.Value
}
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
@@ -375,5 +388,5 @@ func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackU
return inputCallbackURL, nil
}
return "", common.ErrOidcInvalidCallbackURL
return "", &common.OidcInvalidCallbackURLError{}
}