mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-14 07:12:19 +00:00
feat: map allowed groups to OIDC clients (#202)
This commit is contained in:
@@ -176,3 +176,11 @@ func (e *LdapUserGroupUpdateError) Error() string {
|
||||
return "LDAP user groups can't be updated"
|
||||
}
|
||||
func (e *LdapUserGroupUpdateError) HttpStatusCode() int { return http.StatusForbidden }
|
||||
|
||||
type OidcAccessDeniedError struct{}
|
||||
|
||||
func (e *OidcAccessDeniedError) Error() string {
|
||||
return "You're not allowed to access this service"
|
||||
}
|
||||
|
||||
func (e *OidcAccessDeniedError) HttpStatusCode() int { return http.StatusForbidden }
|
||||
|
||||
@@ -14,7 +14,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
|
||||
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
|
||||
|
||||
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
|
||||
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
|
||||
group.POST("/oidc/authorization-required", jwtAuthMiddleware.Add(false), oc.authorizationConfirmationRequiredHandler)
|
||||
|
||||
group.POST("/oidc/token", oc.createTokensHandler)
|
||||
group.GET("/oidc/userinfo", oc.userInfoHandler)
|
||||
|
||||
@@ -24,6 +25,7 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
|
||||
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
|
||||
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
|
||||
|
||||
group.PUT("/oidc/clients/:id/allowed-user-groups", jwtAuthMiddleware.Add(true), oc.updateAllowedUserGroupsHandler)
|
||||
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
|
||||
@@ -57,25 +59,20 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
|
||||
var input dto.AuthorizationRequiredDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope)
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.AuthorizeOidcClientResponseDto{
|
||||
Code: code,
|
||||
CallbackURL: callbackURL,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
c.JSON(http.StatusOK, gin.H{"authorizationRequired": !hasAuthorizedClient})
|
||||
}
|
||||
|
||||
func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
@@ -134,7 +131,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
|
||||
// Return a different DTO based on the user's role
|
||||
if c.GetBool("userIsAdmin") {
|
||||
clientDto := dto.OidcClientDto{}
|
||||
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
|
||||
err = dto.MapStruct(client, &clientDto)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
@@ -191,7 +188,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
var clientDto dto.OidcClientWithAllowedUserGroupsDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
@@ -223,7 +220,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
var clientDto dto.OidcClientWithAllowedUserGroupsDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
@@ -278,3 +275,25 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
|
||||
var input dto.OidcUpdateAllowedUserGroupsDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input)
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var oidcClientDto dto.OidcClientDto
|
||||
if err := dto.MapStruct(oidcClient, &oidcClientDto); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, oidcClientDto)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,14 @@ type OidcClientDto struct {
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
CreatedBy UserDto `json:"createdBy"`
|
||||
}
|
||||
|
||||
type OidcClientWithAllowedUserGroupsDto struct {
|
||||
PublicOidcClientDto
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
@@ -35,6 +42,11 @@ type AuthorizeOidcClientResponseDto struct {
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
}
|
||||
|
||||
type AuthorizationRequiredDto struct {
|
||||
ClientID string `json:"clientID" binding:"required"`
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
}
|
||||
|
||||
type OidcCreateTokensDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code" binding:"required"`
|
||||
@@ -42,3 +54,7 @@ type OidcCreateTokensDto struct {
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
}
|
||||
|
||||
type OidcUpdateAllowedUserGroupsDto struct {
|
||||
UserGroupIDs []string `json:"userGroupIds" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -33,7 +33,3 @@ type UserGroupCreateDto struct {
|
||||
type UserGroupUpdateUsersDto struct {
|
||||
UserIDs []string `json:"userIds" binding:"required"`
|
||||
}
|
||||
|
||||
type AssignUserToGroupDto struct {
|
||||
UserID string `json:"userId" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -44,8 +44,9 @@ type OidcClient struct {
|
||||
IsPublic bool
|
||||
PkceEnabled bool
|
||||
|
||||
CreatedByID string
|
||||
CreatedBy User
|
||||
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
CreatedByID string
|
||||
CreatedBy User
|
||||
}
|
||||
|
||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
|
||||
@@ -38,71 +38,111 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
|
||||
|
||||
if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
|
||||
return "", "", &common.OidcMissingCodeChallengeError{}
|
||||
}
|
||||
|
||||
if userAuthorizedOIDCClient.Scope != input.Scope {
|
||||
return "", "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": userAuthorizedOIDCClient.Client.Name})
|
||||
|
||||
return code, callbackURL, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", input.ClientID).Error; err != nil {
|
||||
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// If the client is not public, the code challenge must be provided
|
||||
if client.IsPublic && input.CodeChallenge == "" {
|
||||
return "", "", &common.OidcMissingCodeChallengeError{}
|
||||
}
|
||||
|
||||
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
|
||||
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: input.ClientID,
|
||||
Scope: input.Scope,
|
||||
// Check if the user group is allowed to authorize the client
|
||||
var user model.User
|
||||
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
err = s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error
|
||||
} else {
|
||||
return "", "", err
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, client) {
|
||||
return "", "", &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
// Check if the user has already authorized the client with the given scope
|
||||
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// If the user has not authorized the client, create a new authorization in the database
|
||||
if !hasAuthorizedClient {
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: input.ClientID,
|
||||
Scope: input.Scope,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
// The client has already been authorized but with a different scope so we need to update the scope
|
||||
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
} else {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the authorization code
|
||||
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
|
||||
// Log the authorization event
|
||||
if hasAuthorizedClient {
|
||||
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
|
||||
} else {
|
||||
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name})
|
||||
|
||||
}
|
||||
|
||||
return code, callbackURL, nil
|
||||
}
|
||||
|
||||
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
|
||||
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) {
|
||||
var userAuthorizedOidcClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
if userAuthorizedOidcClient.Scope != scope {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
|
||||
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
|
||||
if len(client.AllowedUserGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
isAllowedToAuthorize := false
|
||||
for _, userGroup := range client.AllowedUserGroups {
|
||||
for _, userGroupUser := range user.UserGroups {
|
||||
if userGroup.ID == userGroupUser.ID {
|
||||
isAllowedToAuthorize = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isAllowedToAuthorize
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
|
||||
if grantType != "authorization_code" {
|
||||
return "", "", &common.OidcGrantTypeNotSupportedError{}
|
||||
@@ -161,7 +201,7 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, code
|
||||
|
||||
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
@@ -382,6 +422,33 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
|
||||
client, err = s.GetClient(id)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// Fetch the user groups based on UserGroupIDs in input
|
||||
var groups []model.UserGroup
|
||||
if len(input.UserGroupIDs) > 0 {
|
||||
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the current user groups with the new set of user groups
|
||||
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// Save the updated client
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
|
||||
@@ -124,7 +124,10 @@ func (s *TestService) SeedDatabase() error {
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
|
||||
CreatedByID: users[0].ID,
|
||||
CreatedByID: users[1].ID,
|
||||
AllowedUserGroups: []model.UserGroup{
|
||||
userGroups[1],
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, client := range oidcClients {
|
||||
@@ -163,27 +166,31 @@ func (s *TestService) SeedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
|
||||
// To generate a new key pair, run the following command:
|
||||
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
|
||||
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
|
||||
|
||||
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
webauthnCredentials := []model.WebauthnCredential{
|
||||
{
|
||||
Name: "Passkey 1",
|
||||
CredentialID: []byte("test-credential-1"),
|
||||
PublicKey: publicKey1,
|
||||
CredentialID: []byte("test-credential-tim"),
|
||||
PublicKey: publicKeyPasskey1,
|
||||
AttestationType: "none",
|
||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||
UserID: users[0].ID,
|
||||
},
|
||||
{
|
||||
Name: "Passkey 2",
|
||||
CredentialID: []byte("test-credential-2"),
|
||||
PublicKey: publicKey2,
|
||||
CredentialID: []byte("test-credential-craig"),
|
||||
PublicKey: publicKeyPasskey2,
|
||||
AttestationType: "none",
|
||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||
UserID: users[0].ID,
|
||||
UserID: users[1].ID,
|
||||
},
|
||||
}
|
||||
for _, credential := range webauthnCredentials {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE oidc_clients_allowed_user_groups;
|
||||
@@ -0,0 +1,8 @@
|
||||
CREATE TABLE oidc_clients_allowed_user_groups
|
||||
(
|
||||
user_group_id UUID NOT NULL REFERENCES user_groups ON DELETE CASCADE,
|
||||
oidc_client_id UUID NOT NULL REFERENCES oidc_clients ON DELETE CASCADE,
|
||||
PRIMARY KEY (oidc_client_id, user_group_id)
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE oidc_clients_allowed_user_groups;
|
||||
@@ -0,0 +1,8 @@
|
||||
CREATE TABLE oidc_clients_allowed_user_groups
|
||||
(
|
||||
user_group_id TEXT NOT NULL,
|
||||
oidc_client_id TEXT NOT NULL,
|
||||
PRIMARY KEY (oidc_client_id, user_group_id),
|
||||
FOREIGN KEY (oidc_client_id) REFERENCES oidc_clients (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
|
||||
);
|
||||
Reference in New Issue
Block a user