feat: map allowed groups to OIDC clients (#202)

This commit is contained in:
Elias Schneider
2025-02-03 18:41:15 +01:00
committed by GitHub
parent 430421e98b
commit 13b02a072f
30 changed files with 518 additions and 218 deletions

View File

@@ -176,3 +176,11 @@ func (e *LdapUserGroupUpdateError) Error() string {
return "LDAP user groups can't be updated"
}
func (e *LdapUserGroupUpdateError) HttpStatusCode() int { return http.StatusForbidden }
type OidcAccessDeniedError struct{}
func (e *OidcAccessDeniedError) Error() string {
return "You're not allowed to access this service"
}
func (e *OidcAccessDeniedError) HttpStatusCode() int { return http.StatusForbidden }

View File

@@ -14,7 +14,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
group.POST("/oidc/authorization-required", jwtAuthMiddleware.Add(false), oc.authorizationConfirmationRequiredHandler)
group.POST("/oidc/token", oc.createTokensHandler)
group.GET("/oidc/userinfo", oc.userInfoHandler)
@@ -24,6 +25,7 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
group.PUT("/oidc/clients/:id/allowed-user-groups", jwtAuthMiddleware.Add(true), oc.updateAllowedUserGroupsHandler)
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
@@ -57,25 +59,20 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
c.JSON(http.StatusOK, response)
}
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientRequestDto
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
var input dto.AuthorizationRequiredDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}
code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope)
if err != nil {
c.Error(err)
return
}
response := dto.AuthorizeOidcClientResponseDto{
Code: code,
CallbackURL: callbackURL,
}
c.JSON(http.StatusOK, response)
c.JSON(http.StatusOK, gin.H{"authorizationRequired": !hasAuthorizedClient})
}
func (oc *OidcController) createTokensHandler(c *gin.Context) {
@@ -134,7 +131,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
// Return a different DTO based on the user's role
if c.GetBool("userIsAdmin") {
clientDto := dto.OidcClientDto{}
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
@@ -191,7 +188,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
return
}
var clientDto dto.OidcClientDto
var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err)
return
@@ -223,7 +220,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
return
}
var clientDto dto.OidcClientDto
var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err)
return
@@ -278,3 +275,25 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
c.Status(http.StatusNoContent)
}
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
var input dto.OidcUpdateAllowedUserGroupsDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input)
if err != nil {
c.Error(err)
return
}
var oidcClientDto dto.OidcClientDto
if err := dto.MapStruct(oidcClient, &oidcClientDto); err != nil {
c.Error(err)
return
}
c.JSON(http.StatusOK, oidcClientDto)
}

View File

@@ -11,7 +11,14 @@ type OidcClientDto struct {
CallbackURLs []string `json:"callbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
CreatedBy UserDto `json:"createdBy"`
}
type OidcClientWithAllowedUserGroupsDto struct {
PublicOidcClientDto
CallbackURLs []string `json:"callbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
}
type OidcClientCreateDto struct {
@@ -35,6 +42,11 @@ type AuthorizeOidcClientResponseDto struct {
CallbackURL string `json:"callbackURL"`
}
type AuthorizationRequiredDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
}
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
@@ -42,3 +54,7 @@ type OidcCreateTokensDto struct {
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
}
type OidcUpdateAllowedUserGroupsDto struct {
UserGroupIDs []string `json:"userGroupIds" binding:"required"`
}

View File

@@ -33,7 +33,3 @@ type UserGroupCreateDto struct {
type UserGroupUpdateUsersDto struct {
UserIDs []string `json:"userIds" binding:"required"`
}
type AssignUserToGroupDto struct {
UserID string `json:"userId" binding:"required"`
}

View File

@@ -44,8 +44,9 @@ type OidcClient struct {
IsPublic bool
PkceEnabled bool
CreatedByID string
CreatedBy User
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
CreatedByID string
CreatedBy User
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {

View File

@@ -38,71 +38,111 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
}
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", &common.OidcMissingAuthorizationError{}
}
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": userAuthorizedOIDCClient.Client.Name})
return code, callbackURL, nil
}
func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", input.ClientID).Error; err != nil {
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil {
return "", "", err
}
// If the client is not public, the code challenge must be provided
if client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: input.ClientID,
Scope: input.Scope,
// Check if the user group is allowed to authorize the client
var user model.User
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
return "", "", err
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
err = s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error
} else {
return "", "", err
if !s.IsUserGroupAllowedToAuthorize(user, client) {
return "", "", &common.OidcAccessDeniedError{}
}
// Check if the user has already authorized the client with the given scope
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope)
if err != nil {
return "", "", err
}
// If the user has not authorized the client, create a new authorization in the database
if !hasAuthorizedClient {
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: input.ClientID,
Scope: input.Scope,
}
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
// The client has already been authorized but with a different scope so we need to update the scope
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err
}
} else {
return "", "", err
}
}
}
// Create the authorization code
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
// Log the authorization event
if hasAuthorizedClient {
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
} else {
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
}
return code, callbackURL, nil
}
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) {
var userAuthorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}
if userAuthorizedOidcClient.Scope != scope {
return false, nil
}
return true, nil
}
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
if len(client.AllowedUserGroups) == 0 {
return true
}
isAllowedToAuthorize := false
for _, userGroup := range client.AllowedUserGroups {
for _, userGroupUser := range user.UserGroups {
if userGroup.ID == userGroupUser.ID {
isAllowedToAuthorize = true
break
}
}
}
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
@@ -161,7 +201,7 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, code
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
var client model.OidcClient
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
return model.OidcClient{}, err
}
return client, nil
@@ -382,6 +422,33 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
client, err = s.GetClient(id)
if err != nil {
return model.OidcClient{}, err
}
// Fetch the user groups based on UserGroupIDs in input
var groups []model.UserGroup
if len(input.UserGroupIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil {
return model.OidcClient{}, err
}
}
// Replace the current user groups with the new set of user groups
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil {
return model.OidcClient{}, err
}
// Save the updated client
if err := s.db.Save(&client).Error; err != nil {
return model.OidcClient{}, err
}
return client, nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {

View File

@@ -124,7 +124,10 @@ func (s *TestService) SeedDatabase() error {
Name: "Immich",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
CreatedByID: users[0].ID,
CreatedByID: users[1].ID,
AllowedUserGroups: []model.UserGroup{
userGroups[1],
},
},
}
for _, client := range oidcClients {
@@ -163,27 +166,31 @@ func (s *TestService) SeedDatabase() error {
return err
}
publicKey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
// To generate a new key pair, run the following command:
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
if err != nil {
return err
}
webauthnCredentials := []model.WebauthnCredential{
{
Name: "Passkey 1",
CredentialID: []byte("test-credential-1"),
PublicKey: publicKey1,
CredentialID: []byte("test-credential-tim"),
PublicKey: publicKeyPasskey1,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
UserID: users[0].ID,
},
{
Name: "Passkey 2",
CredentialID: []byte("test-credential-2"),
PublicKey: publicKey2,
CredentialID: []byte("test-credential-craig"),
PublicKey: publicKeyPasskey2,
AttestationType: "none",
Transport: model.AuthenticatorTransportList{protocol.Internal},
UserID: users[0].ID,
UserID: users[1].ID,
},
}
for _, credential := range webauthnCredentials {

View File

@@ -0,0 +1 @@
DROP TABLE oidc_clients_allowed_user_groups;

View File

@@ -0,0 +1,8 @@
CREATE TABLE oidc_clients_allowed_user_groups
(
user_group_id UUID NOT NULL REFERENCES user_groups ON DELETE CASCADE,
oidc_client_id UUID NOT NULL REFERENCES oidc_clients ON DELETE CASCADE,
PRIMARY KEY (oidc_client_id, user_group_id)
);

View File

@@ -0,0 +1 @@
DROP TABLE oidc_clients_allowed_user_groups;

View File

@@ -0,0 +1,8 @@
CREATE TABLE oidc_clients_allowed_user_groups
(
user_group_id TEXT NOT NULL,
oidc_client_id TEXT NOT NULL,
PRIMARY KEY (oidc_client_id, user_group_id),
FOREIGN KEY (oidc_client_id) REFERENCES oidc_clients (id) ON DELETE CASCADE,
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);