diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 6229b19..35b4dc1 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -41,7 +41,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { userService := service.NewUserService(db, jwtService, auditLogService, emailService, appConfigService) customClaimService := service.NewCustomClaimService(db) oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) - testService := service.NewTestService(db, appConfigService) + testService := service.NewTestService(db, appConfigService, jwtService) userGroupService := service.NewUserGroupService(db, appConfigService) ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService) diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index bcf6641..4a0153f 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -31,6 +31,13 @@ type TokenInvalidOrExpiredError struct{} func (e *TokenInvalidOrExpiredError) Error() string { return "token is invalid or expired" } func (e *TokenInvalidOrExpiredError) HttpStatusCode() int { return 400 } +type TokenInvalidError struct{} + +func (e *TokenInvalidError) Error() string { + return "Token is invalid" +} +func (e *TokenInvalidError) HttpStatusCode() int { return 400 } + type OidcMissingAuthorizationError struct{} func (e *OidcMissingAuthorizationError) Error() string { return "missing authorization" } @@ -182,9 +189,22 @@ type OidcAccessDeniedError struct{} func (e *OidcAccessDeniedError) Error() string { return "You're not allowed to access this service" } - func (e *OidcAccessDeniedError) HttpStatusCode() int { return http.StatusForbidden } +type OidcClientIdNotMatchingError struct{} + +func (e *OidcClientIdNotMatchingError) Error() string { + return "Client id in request doesn't match client id in token" +} +func (e *OidcClientIdNotMatchingError) HttpStatusCode() int { return http.StatusBadRequest } + +type OidcNoCallbackURLError struct{} + +func (e *OidcNoCallbackURLError) Error() string { + return "No callback URL provided" +} +func (e *OidcNoCallbackURLError) HttpStatusCode() int { return http.StatusBadRequest } + type UiConfigDisabledError struct{} func (e *UiConfigDisabledError) Error() string { diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index c2a1abb..c95d6fa 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -1,7 +1,11 @@ package controller import ( + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/utils/cookie" + "log" "net/http" + "net/url" "strings" "github.com/gin-gonic/gin" @@ -19,6 +23,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt group.POST("/oidc/token", oc.createTokensHandler) group.GET("/oidc/userinfo", oc.userInfoHandler) + group.POST("/oidc/end-session", oc.EndSessionHandler) + group.GET("/oidc/end-session", oc.EndSessionHandler) group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler) group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler) @@ -122,6 +128,44 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { c.JSON(http.StatusOK, claims) } +func (oc *OidcController) EndSessionHandler(c *gin.Context) { + var input dto.OidcLogoutDto + + // Bind query parameters to the struct + if c.Request.Method == http.MethodGet { + if err := c.ShouldBindQuery(&input); err != nil { + c.Error(err) + return + } + } else if c.Request.Method == http.MethodPost { + // Bind form parameters to the struct + if err := c.ShouldBind(&input); err != nil { + c.Error(err) + return + } + } + + callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID")) + if err != nil { + // If the validation fails, the user has to confirm the logout manually and doesn't get redirected + log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err) + c.Redirect(http.StatusFound, common.EnvConfig.AppURL+"/logout") + return + } + + // The validation was successful, so we can log out and redirect the user to the callback URL without confirmation + cookie.AddAccessTokenCookie(c, 0, "") + + logoutCallbackURL, _ := url.Parse(callbackURL) + if input.State != "" { + q := logoutCallbackURL.Query() + q.Set("state", input.State) + logoutCallbackURL.RawQuery = q.Encode() + } + + c.Redirect(http.StatusFound, logoutCallbackURL.String()) +} + func (oc *OidcController) getClientHandler(c *gin.Context) { clientId := c.Param("id") client, err := oc.oidcService.GetClient(clientId) diff --git a/backend/internal/controller/test_controller.go b/backend/internal/controller/test_controller.go index 096034f..f1c6ad3 100644 --- a/backend/internal/controller/test_controller.go +++ b/backend/internal/controller/test_controller.go @@ -38,5 +38,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) { return } + tc.TestService.SetJWTKeys() + c.Status(http.StatusNoContent) } diff --git a/backend/internal/controller/well_known_controller.go b/backend/internal/controller/well_known_controller.go index 379b37f..48aeba7 100644 --- a/backend/internal/controller/well_known_controller.go +++ b/backend/internal/controller/well_known_controller.go @@ -35,6 +35,7 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) { "authorization_endpoint": appUrl + "/authorize", "token_endpoint": appUrl + "/api/oidc/token", "userinfo_endpoint": appUrl + "/api/oidc/userinfo", + "end_session_endpoint": appUrl + "/api/oidc/end-session", "jwks_uri": appUrl + "/.well-known/jwks.json", "scopes_supported": []string{"openid", "profile", "email"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username"}, diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index a904be2..78a1fcb 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -8,24 +8,27 @@ type PublicOidcClientDto struct { type OidcClientDto struct { PublicOidcClientDto - CallbackURLs []string `json:"callbackURLs"` - IsPublic bool `json:"isPublic"` - PkceEnabled bool `json:"pkceEnabled"` + CallbackURLs []string `json:"callbackURLs"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs"` + IsPublic bool `json:"isPublic"` + PkceEnabled bool `json:"pkceEnabled"` } type OidcClientWithAllowedUserGroupsDto struct { PublicOidcClientDto - CallbackURLs []string `json:"callbackURLs"` - IsPublic bool `json:"isPublic"` - PkceEnabled bool `json:"pkceEnabled"` - AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"` + CallbackURLs []string `json:"callbackURLs"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs"` + IsPublic bool `json:"isPublic"` + PkceEnabled bool `json:"pkceEnabled"` + AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"` } type OidcClientCreateDto struct { - Name string `json:"name" binding:"required,max=50"` - CallbackURLs []string `json:"callbackURLs" binding:"required"` - IsPublic bool `json:"isPublic"` - PkceEnabled bool `json:"pkceEnabled"` + Name string `json:"name" binding:"required,max=50"` + CallbackURLs []string `json:"callbackURLs" binding:"required"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs"` + IsPublic bool `json:"isPublic"` + PkceEnabled bool `json:"pkceEnabled"` } type AuthorizeOidcClientRequestDto struct { @@ -58,3 +61,10 @@ type OidcCreateTokensDto struct { type OidcUpdateAllowedUserGroupsDto struct { UserGroupIDs []string `json:"userGroupIds" binding:"required"` } + +type OidcLogoutDto struct { + IdTokenHint string `form:"id_token_hint"` + ClientId string `form:"client_id"` + PostLogoutRedirectUri string `form:"post_logout_redirect_uri"` + State string `form:"state"` +} diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index a4d2efc..714eb7a 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -37,13 +37,14 @@ type OidcAuthorizationCode struct { type OidcClient struct { Base - Name string `sortable:"true"` - Secret string - CallbackURLs CallbackURLs - ImageType *string - HasLogo bool `gorm:"-"` - IsPublic bool - PkceEnabled bool + Name string `sortable:"true"` + Secret string + CallbackURLs UrlList + LogoutCallbackURLs UrlList + ImageType *string + HasLogo bool `gorm:"-"` + IsPublic bool + PkceEnabled bool AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"` CreatedByID string @@ -56,9 +57,9 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { return nil } -type CallbackURLs []string +type UrlList []string -func (cu *CallbackURLs) Scan(value interface{}) error { +func (cu *UrlList) Scan(value interface{}) error { if v, ok := value.([]byte); ok { return json.Unmarshal(v, cu) } else { @@ -66,6 +67,6 @@ func (cu *CallbackURLs) Scan(value interface{}) error { } } -func (cu CallbackURLs) Value() (driver.Value, error) { +func (cu UrlList) Value() (driver.Value, error) { return json.Marshal(cu) } diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index aa340a9..8a832e8 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -8,7 +8,6 @@ import ( "encoding/base64" "encoding/pem" "errors" - "fmt" "log" "math/big" "os" @@ -28,8 +27,8 @@ const ( ) type JwtService struct { - publicKey *rsa.PublicKey - privateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + PrivateKey *rsa.PrivateKey appConfigService *AppConfigService } @@ -72,7 +71,7 @@ func (s *JwtService) loadOrGenerateKeys() error { if err != nil { return errors.New("can't read jwt private key: " + err.Error()) } - s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) + s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) if err != nil { return errors.New("can't parse jwt private key: " + err.Error()) } @@ -81,7 +80,7 @@ func (s *JwtService) loadOrGenerateKeys() error { if err != nil { return errors.New("can't read jwt public key: " + err.Error()) } - s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) + s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) if err != nil { return errors.New("can't parse jwt public key: " + err.Error()) } @@ -101,7 +100,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { IsAdmin: user.IsAdmin, } - kid, err := s.generateKeyID(s.publicKey) + kid, err := s.generateKeyID(s.PublicKey) if err != nil { return "", errors.New("failed to generate key ID: " + err.Error()) } @@ -109,12 +108,12 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) token.Header["kid"] = kid - return token.SignedString(s.privateKey) + return token.SignedString(s.PrivateKey) } func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) { - return s.publicKey, nil + return s.PublicKey, nil }) if err != nil || !token.Valid { return nil, errors.New("couldn't handle this token") @@ -147,7 +146,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID claims["nonce"] = nonce } - kid, err := s.generateKeyID(s.publicKey) + kid, err := s.generateKeyID(s.PublicKey) if err != nil { return "", errors.New("failed to generate key ID: " + err.Error()) } @@ -155,7 +154,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = kid - return token.SignedString(s.privateKey) + return token.SignedString(s.PrivateKey) } func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { @@ -167,7 +166,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) Issuer: common.EnvConfig.AppURL, } - kid, err := s.generateKeyID(s.publicKey) + kid, err := s.generateKeyID(s.PublicKey) if err != nil { return "", errors.New("failed to generate key ID: " + err.Error()) } @@ -175,12 +174,12 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) token.Header["kid"] = kid - return token.SignedString(s.privateKey) + return token.SignedString(s.PrivateKey) } func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - return s.publicKey, nil + return s.PublicKey, nil }) if err != nil || !token.Valid { return nil, errors.New("couldn't handle this token") @@ -194,13 +193,30 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered return claims, nil } +func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return s.PublicKey, nil + }, jwt.WithIssuer(common.EnvConfig.AppURL)) + + if err != nil && !errors.Is(err, jwt.ErrTokenExpired) { + return nil, errors.New("couldn't handle this token") + } + + claims, isValid := token.Claims.(*jwt.RegisteredClaims) + if !isValid { + return nil, errors.New("can't parse claims") + } + + return claims, nil +} + // GetJWK returns the JSON Web Key (JWK) for the public key. func (s *JwtService) GetJWK() (JWK, error) { - if s.publicKey == nil { + if s.PublicKey == nil { return JWK{}, errors.New("public key is not initialized") } - kid, err := s.generateKeyID(s.publicKey) + kid, err := s.generateKeyID(s.PublicKey) if err != nil { return JWK{}, err } @@ -210,8 +226,8 @@ func (s *JwtService) GetJWK() (JWK, error) { Kty: "RSA", Use: "sig", Alg: "RS256", - N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()), - E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()), + N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()), } return jwk, nil @@ -246,14 +262,14 @@ func (s *JwtService) generateKeys() error { if err != nil { return errors.New("failed to generate private key: " + err.Error()) } - s.privateKey = privateKey + s.PrivateKey = privateKey if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil { return err } publicKey := &privateKey.PublicKey - s.publicKey = publicKey + s.PublicKey = publicKey if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil { return err @@ -281,32 +297,3 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er return nil } - -// loadKeys loads RSA keys from the given paths. -func (s *JwtService) loadKeys() error { - if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { - if err := s.generateKeys(); err != nil { - return err - } - } - - privateKeyBytes, err := os.ReadFile(privateKeyPath) - if err != nil { - return fmt.Errorf("can't read jwt private key: %w", err) - } - s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) - if err != nil { - return fmt.Errorf("can't parse jwt private key: %w", err) - } - - publicKeyBytes, err := os.ReadFile(publicKeyPath) - if err != nil { - return fmt.Errorf("can't read jwt public key: %w", err) - } - s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) - if err != nil { - return fmt.Errorf("can't parse jwt public key: %w", err) - } - - return nil -} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index db2b02a..91b3c27 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -51,7 +51,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, } // Get the callback URL of the client. Return an error if the provided callback URL is not allowed - callbackURL, err := s.getCallbackURL(client, input.CallbackURL) + callbackURL, err := s.getCallbackURL(client.CallbackURLs, input.CallbackURL) if err != nil { return "", "", err } @@ -228,11 +228,12 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { client := model.OidcClient{ - Name: input.Name, - CallbackURLs: input.CallbackURLs, - CreatedByID: userID, - IsPublic: input.IsPublic, - PkceEnabled: input.IsPublic || input.PkceEnabled, + Name: input.Name, + CallbackURLs: input.CallbackURLs, + LogoutCallbackURLs: input.LogoutCallbackURLs, + CreatedByID: userID, + IsPublic: input.IsPublic, + PkceEnabled: input.IsPublic || input.PkceEnabled, } if err := s.db.Create(&client).Error; err != nil { @@ -250,6 +251,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt client.Name = input.Name client.CallbackURLs = input.CallbackURLs + client.LogoutCallbackURLs = input.LogoutCallbackURLs client.IsPublic = input.IsPublic client.PkceEnabled = input.IsPublic || input.PkceEnabled @@ -460,6 +462,46 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll return client, nil } +// ValidateEndSession returns the logout callback URL for the client if all the validations pass +func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) { + // If no ID token hint is provided, return an error + if input.IdTokenHint == "" { + return "", &common.TokenInvalidError{} + } + + // If the ID token hint is provided, verify the ID token + claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint) + if err != nil { + return "", &common.TokenInvalidError{} + } + + // If the client ID is provided check if the client ID in the ID token matches the client ID in the request + if input.ClientId != "" && claims.Audience[0] != input.ClientId { + return "", &common.OidcClientIdNotMatchingError{} + } + + clientId := claims.Audience[0] + + // Check if the user has authorized the client before + var userAuthorizedOIDCClient model.UserAuthorizedOidcClient + if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil { + return "", &common.OidcMissingAuthorizationError{} + } + + // If the client has no logout callback URLs, return an error + if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) == 0 { + return "", &common.OidcNoCallbackURLError{} + } + + callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client.LogoutCallbackURLs, input.PostLogoutRedirectUri) + if err != nil { + return "", err + } + + return callbackURL, nil + +} + func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(32) if err != nil { @@ -506,12 +548,12 @@ func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, c return encodedVerifierHash == codeChallenge } -func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) { +func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (callbackURL string, err error) { if inputCallbackURL == "" { - return client.CallbackURLs[0], nil + return urls[0], nil } - for _, callbackPattern := range client.CallbackURLs { + for _, callbackPattern := range urls { regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$" matched, err := regexp.MatchString(regexPattern, inputCallbackURL) if err != nil { diff --git a/backend/internal/service/test_service.go b/backend/internal/service/test_service.go index f6f4816..89d208f 100644 --- a/backend/internal/service/test_service.go +++ b/backend/internal/service/test_service.go @@ -4,6 +4,7 @@ import ( "crypto/ecdsa" "crypto/x509" "encoding/base64" + "encoding/pem" "fmt" "log" "os" @@ -23,11 +24,12 @@ import ( type TestService struct { db *gorm.DB + jwtService *JwtService appConfigService *AppConfigService } -func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService { - return &TestService{db: db, appConfigService: appConfigService} +func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService) *TestService { + return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService} } func (s *TestService) SeedDatabase() error { @@ -112,11 +114,12 @@ func (s *TestService) SeedDatabase() error { Base: model.Base{ ID: "3654a746-35d4-4321-ac61-0bdcff2b4055", }, - Name: "Nextcloud", - Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY - CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"}, - ImageType: utils.StringPointer("png"), - CreatedByID: users[0].ID, + Name: "Nextcloud", + Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY + CallbackURLs: model.UrlList{"http://nextcloud/auth/callback"}, + LogoutCallbackURLs: model.UrlList{"http://nextcloud/auth/logout/callback"}, + ImageType: utils.StringPointer("png"), + CreatedByID: users[0].ID, }, { Base: model.Base{ @@ -124,7 +127,7 @@ func (s *TestService) SeedDatabase() error { }, Name: "Immich", Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x - CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"}, + CallbackURLs: model.UrlList{"http://immich/auth/callback"}, CreatedByID: users[1].ID, AllowedUserGroups: []model.UserGroup{ userGroups[1], @@ -288,6 +291,43 @@ func (s *TestService) ResetAppConfig() error { return s.appConfigService.LoadDbConfigFromDb() } +func (s *TestService) SetJWTKeys() { + privateKeyString := `-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAyaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B +83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC+585UXacoJ0c +hUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl/4EDDTO8HwawTjwkPo +QlRzeByhlvGPVvwgB3Fn93B8QJ/cZhXKxJvjjrC/8Pk76heC/ntEMru71Ix77BoC +3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeO +Zl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJwIDAQABAoIBAQCa8wNZJ08+9y6b +RzSIQcTaBuq1XY0oyYvCuX0ToruDyVNX3lJ48udb9vDIw9XsQans9CTeXXsjldGE +WPN7sapOcUg6ArMyJqc+zuO/YQu0EwYrTE48BOC7WIZvvTFnq9y+4R9HJjd0nTOv +iOlR1W5fAqbH2srgh1mfZ0UIp+9K6ymoinPXVGEXUAuuoMuTEZW/tnA2HT9WEllT +2FyMbmXrFzutAQqk9GRmnQh2OQZLxnQWyShVqJEhYBtm6JUUH1YJbyTVzMLgdBM8 +ukgjTVtRDHaW51ubRSVdGBVT2m1RRtTsYAiZCpM5bwt88aSUS9yDOUiVH+irDg/3 +IHEuL7IxAoGBAP2MpXPXtOwinajUQ9hKLDAtpq4axGvY+aGP5dNEMsuPo5ggOfUP +b4sqr73kaNFO3EbxQOQVoFjehhi4dQxt1/kAala9HZ5N7s26G2+eUWFF8jy7gWSN +qusNqGrG4g8D3WOyqZFb/x/m6SE0Jcg7zvIYbnAOq1Fexeik0Fc/DNzLAoGBAMua +d4XIfu4ydtU5AIaf1ZNXywgLg+LWxK8ELNqH/Y2vLAeIiTrOVp+hw9z+zHPD5cnu +6mix783PCOYNLTylrwtAz3fxSz14lsDFQM3ntzVF/6BniTTkKddctcPyqnTvamah +0hD2dzXBS/0mTBYIIMYTNbs0Yj87FTdJZw/+qa2VAoGBAKbzQkp54W6PCIMPabD0 +fg4nMRZ5F5bv4seIKcunn068QPs9VQxQ4qCfNeLykDYqGA86cgD9YHzD4UZLxv6t +IUWbCWod0m/XXwPlpIUlmO5VEUD+MiAUzFNDxf6xAE7ku5UXImJNUjseX6l2Xd5v +yz9L6QQuFI5aujQKugiIwp5rAoGATtUVGCCkPNgfOLmkYXu7dxxUCV5kB01+xAEK +2OY0n0pG8vfDophH4/D/ZC7nvJ8J9uDhs/3JStexq1lIvaWtG99RNTChIEDzpdn6 +GH9yaVcb/eB4uJjrNm64FhF8PGCCwxA+xMCZMaARKwhMB2/IOMkxUbWboL3gnhJ2 +rDO/QO0CgYEA2Grt6uXHm61ji3xSdkBWNtUnj19vS1+7rFJp5SoYztVQVThf/W52 +BAiXKBdYZDRVoItC/VS2NvAOjeJjhYO/xQ/q3hK7MdtuXfEPpLnyXKkmWo3lrJ26 +wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI= +-----END RSA PRIVATE KEY----- +` + + block, _ := pem.Decode([]byte(privateKeyString)) + privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes) + + s.jwtService.PrivateKey = privateKey + s.jwtService.PublicKey = &privateKey.PublicKey +} + // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key func (s *TestService) getCborPublicKey(base64PublicKey string) ([]byte, error) { decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey) diff --git a/backend/resources/migrations/postgres/20250210152631_post_logout_url.down.sql b/backend/resources/migrations/postgres/20250210152631_post_logout_url.down.sql new file mode 100644 index 0000000..b92c022 --- /dev/null +++ b/backend/resources/migrations/postgres/20250210152631_post_logout_url.down.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients DROP COLUMN logout_callback_urls; \ No newline at end of file diff --git a/backend/resources/migrations/postgres/20250210152631_post_logout_url.up.sql b/backend/resources/migrations/postgres/20250210152631_post_logout_url.up.sql new file mode 100644 index 0000000..34bb4cf --- /dev/null +++ b/backend/resources/migrations/postgres/20250210152631_post_logout_url.up.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients ADD COLUMN logout_callback_urls JSONB; \ No newline at end of file diff --git a/backend/resources/migrations/sqlite/20250210152631_post_logout_url.down.sql b/backend/resources/migrations/sqlite/20250210152631_post_logout_url.down.sql new file mode 100644 index 0000000..b92c022 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250210152631_post_logout_url.down.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients DROP COLUMN logout_callback_urls; \ No newline at end of file diff --git a/backend/resources/migrations/sqlite/20250210152631_post_logout_url.up.sql b/backend/resources/migrations/sqlite/20250210152631_post_logout_url.up.sql new file mode 100644 index 0000000..3c5be97 --- /dev/null +++ b/backend/resources/migrations/sqlite/20250210152631_post_logout_url.up.sql @@ -0,0 +1 @@ +ALTER TABLE oidc_clients ADD COLUMN logout_callback_urls BLOB; \ No newline at end of file diff --git a/frontend/src/hooks.server.ts b/frontend/src/hooks.server.ts index 562a7f1..c927de2 100644 --- a/frontend/src/hooks.server.ts +++ b/frontend/src/hooks.server.ts @@ -12,7 +12,11 @@ process.env.INTERNAL_BACKEND_URL = env.INTERNAL_BACKEND_URL ?? 'http://localhost export const handle: Handle = async ({ event, resolve }) => { const { isSignedIn, isAdmin } = verifyJwt(event.cookies.get(ACCESS_TOKEN_COOKIE_NAME)); - if (event.url.pathname.startsWith('/settings') && !event.url.pathname.startsWith('/login')) { + const isUnauthenticatedOnlyPath = event.url.pathname.startsWith('/login'); + const isPublicPath = ['/authorize', '/health'].includes(event.url.pathname); + const isAdminPath = event.url.pathname.startsWith('/settings/admin'); + + if (!isUnauthenticatedOnlyPath && !isPublicPath) { if (!isSignedIn) { return new Response(null, { status: 302, @@ -21,14 +25,14 @@ export const handle: Handle = async ({ event, resolve }) => { } } - if (event.url.pathname.startsWith('/login') && isSignedIn) { + if (isUnauthenticatedOnlyPath && isSignedIn) { return new Response(null, { status: 302, headers: { location: '/settings' } }); } - if (event.url.pathname.startsWith('/settings/admin') && !isAdmin) { + if (isAdminPath && !isAdmin) { return new Response(null, { status: 302, headers: { location: '/settings' } diff --git a/frontend/src/lib/components/header/header.svelte b/frontend/src/lib/components/header/header.svelte index 7e998c5..d5f65de 100644 --- a/frontend/src/lib/components/header/header.svelte +++ b/frontend/src/lib/components/header/header.svelte @@ -5,10 +5,9 @@ import Logo from '../logo.svelte'; import HeaderAvatar from './header-avatar.svelte'; - let isAuthPage = $derived( - !$page.error && - ($page.url.pathname.startsWith('/authorize') || $page.url.pathname.startsWith('/login')) - ); + const authUrls = ['/authorize', '/login', '/logout']; + let isAuthPage = $derived(!$page.error && authUrls.includes($page.url.pathname)); +
diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index 27bd40f..85da1cb 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -5,6 +5,7 @@ export type OidcClient = { name: string; logoURL: string; callbackURLs: [string, ...string[]]; + logoutCallbackURLs: string[]; hasLogo: boolean; isPublic: boolean; pkceEnabled: boolean; diff --git a/frontend/src/routes/authorize/+page.svelte b/frontend/src/routes/authorize/+page.svelte index 854ed03..1cf681f 100644 --- a/frontend/src/routes/authorize/+page.svelte +++ b/frontend/src/routes/authorize/+page.svelte @@ -8,7 +8,6 @@ import userStore from '$lib/stores/user-store'; import { getWebauthnErrorMessage } from '$lib/utils/error-util'; import { startAuthentication } from '@simplewebauthn/browser'; - import { AxiosError } from 'axios'; import { LucideMail, LucideUser, LucideUsers } from 'lucide-svelte'; import { onMount } from 'svelte'; import { slide } from 'svelte/transition'; @@ -60,11 +59,7 @@ onSuccess(code, callbackURL); }); } catch (e) { - if (e instanceof AxiosError && e.response?.data.error === 'Missing authorization') { - authorizationRequired = true; - } else { - errorMessage = getWebauthnErrorMessage(e); - } + errorMessage = getWebauthnErrorMessage(e); isLoading = false; } } diff --git a/frontend/src/routes/logout/+page.svelte b/frontend/src/routes/logout/+page.svelte new file mode 100644 index 0000000..8d2ba8b --- /dev/null +++ b/frontend/src/routes/logout/+page.svelte @@ -0,0 +1,43 @@ + + + + Logout + + + +
+
+ +
+
+

Sign out

+ +

+ Do you want to sign out of Pocket ID with the account {$userStore?.username}? +

+
+ + +
+
diff --git a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte index 7ac854e..6756392 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte @@ -1,7 +1,6 @@
- +
{#each callbackURLs as _, i}
- {#if callbackURLs.length > 1} + {#if callbackURLs.length > 1 || allowEmpty} {/if}
diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte index a26685d..4f1f45d 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte @@ -10,7 +10,7 @@ OidcClientCreateWithLogo } from '$lib/types/oidc.type'; import { createForm } from '$lib/utils/form-util'; - import { set, z } from 'zod'; + import { z } from 'zod'; import OidcCallbackUrlInput from './oidc-callback-url-input.svelte'; let { @@ -30,6 +30,7 @@ const client: OidcClientCreate = { name: existingClient?.name || '', callbackURLs: existingClient?.callbackURLs || [''], + logoutCallbackURLs: existingClient?.logoutCallbackURLs || [], isPublic: existingClient?.isPublic || false, pkceEnabled: existingClient?.isPublic == true || existingClient?.pkceEnabled || false }; @@ -37,6 +38,7 @@ const formSchema = z.object({ name: z.string().min(2).max(50), callbackURLs: z.array(z.string()).nonempty(), + logoutCallbackURLs: z.array(z.string()), isPublic: z.boolean(), pkceEnabled: z.boolean() }); @@ -78,11 +80,20 @@
+
+ Logo
{#if logoDataURL} -
+
{ await page.goto(`/settings/admin/oidc-clients/${oidcClient.id}`); await page.getByLabel('Name').fill('Nextcloud updated'); - await page.getByTestId('callback-url-1').fill('http://nextcloud-updated/auth/callback'); + await page.getByTestId('callback-url-1').first().fill('http://nextcloud-updated/auth/callback'); await page.getByLabel('logo').setInputFiles('tests/assets/nextcloud-logo.png'); await page.getByRole('button', { name: 'Save' }).click(); diff --git a/frontend/tests/oidc.spec.ts b/frontend/tests/oidc.spec.ts index 19e3c23..30a7b20 100644 --- a/frontend/tests/oidc.spec.ts +++ b/frontend/tests/oidc.spec.ts @@ -89,10 +89,11 @@ test('Authorize new client fails with user group not allowed', async ({ page }) await page.getByRole('button', { name: 'Sign in' }).click(); - await expect(page.getByRole('paragraph').first()).toHaveText("You're not allowed to access this service."); + await expect(page.getByRole('paragraph').first()).toHaveText( + "You're not allowed to access this service." + ); }); - function createUrlParams(oidcClient: { id: string; callbackUrl: string }) { return new URLSearchParams({ client_id: oidcClient.id, @@ -103,3 +104,33 @@ function createUrlParams(oidcClient: { id: string; callbackUrl: string }) { nonce: 'P1gN3PtpKHJgKUVcLpLjm' }); } + +test('End session without id token hint shows confirmation page', async ({ page }) => { + await page.goto('/api/oidc/end-session'); + + await expect(page).toHaveURL('/logout'); + await page.getByRole('button', { name: 'Sign out' }).click(); + + await expect(page).toHaveURL('/login'); +}); + +test('End session with id token hint redirects to callback URL', async ({ page }) => { + const client = oidcClients.nextcloud; + const idToken = + 'eyJhbGciOiJSUzI1NiIsImtpZCI6Ijh1SER3M002cmY4IiwidHlwIjoiSldUIn0.eyJhdWQiOiIzNjU0YTc0Ni0zNWQ0LTQzMjEtYWM2MS0wYmRjZmYyYjQwNTUiLCJlbWFpbCI6InRpbS5jb29rQHRlc3QuY29tIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImV4cCI6MTY5MDAwMDAwMSwiZmFtaWx5X25hbWUiOiJUaW0iLCJnaXZlbl9uYW1lIjoiQ29vayIsImlhdCI6MTY5MDAwMDAwMCwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdCIsIm5hbWUiOiJUaW0gQ29vayIsIm5vbmNlIjoib1cxQTFPNzhHUTE1RDczT3NIRXg3V1FLajdacXZITFp1XzM3bWRYSXFBUSIsInN1YiI6IjRiODlkYzItNjJmYi00NmJmLTlmNWYtYzM0ZjRlYWZlOTNlIn0.ruYCyjA2BNjROpmLGPNHrhgUNLnpJMEuncvjDYVuv1dAZwvOPfG-Rn-OseAgJDJbV7wJ0qf6ZmBkGWiifwc_B9h--fgd4Vby9fefj0MiHbSDgQyaU5UmpvJU8OlvM-TueD6ICJL0NeT3DwoW5xpIWaHtt3JqJIdP__Q-lTONL2Zokq50kWm0IO-bIw2QrQviSfHNpv8A5rk1RTzpXCPXYNB-eJbm3oBqYQWzerD9HaNrSvrKA7mKG8Te1mI9aMirPpG9FvcAU-I3lY8ky1hJZDu42jHpVEUdWPAmUZPZafoX8iYtlPfkoklDnHj_cdg4aZBGN5bfjM6xf1Oe_rLDWg'; + + let redirectedCorrectly = false; + await page + .goto( + `/api/oidc/end-session?id_token_hint=${idToken}&post_logout_redirect_uri=${client.logoutCallbackUrl}` + ) + .catch((e) => { + if (e.message.includes('net::ERR_NAME_NOT_RESOLVED')) { + redirectedCorrectly = true; + } else { + throw e; + } + }); + + expect(redirectedCorrectly).toBeTruthy(); +});