mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-24 06:49:22 +00:00
feat: add end session endpoint (#232)
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"os"
|
||||
@@ -28,8 +27,8 @@ const (
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
publicKey *rsa.PublicKey
|
||||
privateKey *rsa.PrivateKey
|
||||
PublicKey *rsa.PublicKey
|
||||
PrivateKey *rsa.PrivateKey
|
||||
appConfigService *AppConfigService
|
||||
}
|
||||
|
||||
@@ -72,7 +71,7 @@ func (s *JwtService) loadOrGenerateKeys() error {
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt private key: " + err.Error())
|
||||
}
|
||||
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt private key: " + err.Error())
|
||||
}
|
||||
@@ -81,7 +80,7 @@ func (s *JwtService) loadOrGenerateKeys() error {
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt public key: " + err.Error())
|
||||
}
|
||||
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt public key: " + err.Error())
|
||||
}
|
||||
@@ -101,7 +100,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.publicKey)
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
@@ -109,12 +108,12 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = kid
|
||||
|
||||
return token.SignedString(s.privateKey)
|
||||
return token.SignedString(s.PrivateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
return s.PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
@@ -147,7 +146,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.publicKey)
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = kid
|
||||
|
||||
return token.SignedString(s.privateKey)
|
||||
return token.SignedString(s.PrivateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
@@ -167,7 +166,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.publicKey)
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
@@ -175,12 +174,12 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = kid
|
||||
|
||||
return token.SignedString(s.privateKey)
|
||||
return token.SignedString(s.PrivateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
return s.PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
@@ -194,13 +193,30 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.PublicKey, nil
|
||||
}, jwt.WithIssuer(common.EnvConfig.AppURL))
|
||||
|
||||
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func (s *JwtService) GetJWK() (JWK, error) {
|
||||
if s.publicKey == nil {
|
||||
if s.PublicKey == nil {
|
||||
return JWK{}, errors.New("public key is not initialized")
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.publicKey)
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return JWK{}, err
|
||||
}
|
||||
@@ -210,8 +226,8 @@ func (s *JwtService) GetJWK() (JWK, error) {
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
|
||||
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()),
|
||||
}
|
||||
|
||||
return jwk, nil
|
||||
@@ -246,14 +262,14 @@ func (s *JwtService) generateKeys() error {
|
||||
if err != nil {
|
||||
return errors.New("failed to generate private key: " + err.Error())
|
||||
}
|
||||
s.privateKey = privateKey
|
||||
s.PrivateKey = privateKey
|
||||
|
||||
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
s.publicKey = publicKey
|
||||
s.PublicKey = publicKey
|
||||
|
||||
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
|
||||
return err
|
||||
@@ -281,32 +297,3 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadKeys loads RSA keys from the given paths.
|
||||
func (s *JwtService) loadKeys() error {
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
if err := s.generateKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't read jwt private key: %w", err)
|
||||
}
|
||||
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't parse jwt private key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't read jwt public key: %w", err)
|
||||
}
|
||||
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't parse jwt public key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
}
|
||||
|
||||
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
|
||||
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
|
||||
callbackURL, err := s.getCallbackURL(client.CallbackURLs, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -228,11 +228,12 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
|
||||
|
||||
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.IsPublic || input.PkceEnabled,
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
LogoutCallbackURLs: input.LogoutCallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.IsPublic || input.PkceEnabled,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&client).Error; err != nil {
|
||||
@@ -250,6 +251,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
@@ -460,6 +462,46 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
|
||||
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) {
|
||||
// If no ID token hint is provided, return an error
|
||||
if input.IdTokenHint == "" {
|
||||
return "", &common.TokenInvalidError{}
|
||||
}
|
||||
|
||||
// If the ID token hint is provided, verify the ID token
|
||||
claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint)
|
||||
if err != nil {
|
||||
return "", &common.TokenInvalidError{}
|
||||
}
|
||||
|
||||
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
|
||||
if input.ClientId != "" && claims.Audience[0] != input.ClientId {
|
||||
return "", &common.OidcClientIdNotMatchingError{}
|
||||
}
|
||||
|
||||
clientId := claims.Audience[0]
|
||||
|
||||
// Check if the user has authorized the client before
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil {
|
||||
return "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
// If the client has no logout callback URLs, return an error
|
||||
if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) == 0 {
|
||||
return "", &common.OidcNoCallbackURLError{}
|
||||
}
|
||||
|
||||
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client.LogoutCallbackURLs, input.PostLogoutRedirectUri)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return callbackURL, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
@@ -506,12 +548,12 @@ func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, c
|
||||
return encodedVerifierHash == codeChallenge
|
||||
}
|
||||
|
||||
func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
|
||||
func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (callbackURL string, err error) {
|
||||
if inputCallbackURL == "" {
|
||||
return client.CallbackURLs[0], nil
|
||||
return urls[0], nil
|
||||
}
|
||||
|
||||
for _, callbackPattern := range client.CallbackURLs {
|
||||
for _, callbackPattern := range urls {
|
||||
regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
|
||||
matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -23,11 +24,12 @@ import (
|
||||
|
||||
type TestService struct {
|
||||
db *gorm.DB
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
}
|
||||
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService {
|
||||
return &TestService{db: db, appConfigService: appConfigService}
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService) *TestService {
|
||||
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
|
||||
}
|
||||
|
||||
func (s *TestService) SeedDatabase() error {
|
||||
@@ -112,11 +114,12 @@ func (s *TestService) SeedDatabase() error {
|
||||
Base: model.Base{
|
||||
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||
},
|
||||
Name: "Nextcloud",
|
||||
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
||||
CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"},
|
||||
ImageType: utils.StringPointer("png"),
|
||||
CreatedByID: users[0].ID,
|
||||
Name: "Nextcloud",
|
||||
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
||||
CallbackURLs: model.UrlList{"http://nextcloud/auth/callback"},
|
||||
LogoutCallbackURLs: model.UrlList{"http://nextcloud/auth/logout/callback"},
|
||||
ImageType: utils.StringPointer("png"),
|
||||
CreatedByID: users[0].ID,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
@@ -124,7 +127,7 @@ func (s *TestService) SeedDatabase() error {
|
||||
},
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
|
||||
CallbackURLs: model.UrlList{"http://immich/auth/callback"},
|
||||
CreatedByID: users[1].ID,
|
||||
AllowedUserGroups: []model.UserGroup{
|
||||
userGroups[1],
|
||||
@@ -288,6 +291,43 @@ func (s *TestService) ResetAppConfig() error {
|
||||
return s.appConfigService.LoadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *TestService) SetJWTKeys() {
|
||||
privateKeyString := `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpQIBAAKCAQEAyaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B
|
||||
83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC+585UXacoJ0c
|
||||
hUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl/4EDDTO8HwawTjwkPo
|
||||
QlRzeByhlvGPVvwgB3Fn93B8QJ/cZhXKxJvjjrC/8Pk76heC/ntEMru71Ix77BoC
|
||||
3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeO
|
||||
Zl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJwIDAQABAoIBAQCa8wNZJ08+9y6b
|
||||
RzSIQcTaBuq1XY0oyYvCuX0ToruDyVNX3lJ48udb9vDIw9XsQans9CTeXXsjldGE
|
||||
WPN7sapOcUg6ArMyJqc+zuO/YQu0EwYrTE48BOC7WIZvvTFnq9y+4R9HJjd0nTOv
|
||||
iOlR1W5fAqbH2srgh1mfZ0UIp+9K6ymoinPXVGEXUAuuoMuTEZW/tnA2HT9WEllT
|
||||
2FyMbmXrFzutAQqk9GRmnQh2OQZLxnQWyShVqJEhYBtm6JUUH1YJbyTVzMLgdBM8
|
||||
ukgjTVtRDHaW51ubRSVdGBVT2m1RRtTsYAiZCpM5bwt88aSUS9yDOUiVH+irDg/3
|
||||
IHEuL7IxAoGBAP2MpXPXtOwinajUQ9hKLDAtpq4axGvY+aGP5dNEMsuPo5ggOfUP
|
||||
b4sqr73kaNFO3EbxQOQVoFjehhi4dQxt1/kAala9HZ5N7s26G2+eUWFF8jy7gWSN
|
||||
qusNqGrG4g8D3WOyqZFb/x/m6SE0Jcg7zvIYbnAOq1Fexeik0Fc/DNzLAoGBAMua
|
||||
d4XIfu4ydtU5AIaf1ZNXywgLg+LWxK8ELNqH/Y2vLAeIiTrOVp+hw9z+zHPD5cnu
|
||||
6mix783PCOYNLTylrwtAz3fxSz14lsDFQM3ntzVF/6BniTTkKddctcPyqnTvamah
|
||||
0hD2dzXBS/0mTBYIIMYTNbs0Yj87FTdJZw/+qa2VAoGBAKbzQkp54W6PCIMPabD0
|
||||
fg4nMRZ5F5bv4seIKcunn068QPs9VQxQ4qCfNeLykDYqGA86cgD9YHzD4UZLxv6t
|
||||
IUWbCWod0m/XXwPlpIUlmO5VEUD+MiAUzFNDxf6xAE7ku5UXImJNUjseX6l2Xd5v
|
||||
yz9L6QQuFI5aujQKugiIwp5rAoGATtUVGCCkPNgfOLmkYXu7dxxUCV5kB01+xAEK
|
||||
2OY0n0pG8vfDophH4/D/ZC7nvJ8J9uDhs/3JStexq1lIvaWtG99RNTChIEDzpdn6
|
||||
GH9yaVcb/eB4uJjrNm64FhF8PGCCwxA+xMCZMaARKwhMB2/IOMkxUbWboL3gnhJ2
|
||||
rDO/QO0CgYEA2Grt6uXHm61ji3xSdkBWNtUnj19vS1+7rFJp5SoYztVQVThf/W52
|
||||
BAiXKBdYZDRVoItC/VS2NvAOjeJjhYO/xQ/q3hK7MdtuXfEPpLnyXKkmWo3lrJ26
|
||||
wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
|
||||
block, _ := pem.Decode([]byte(privateKeyString))
|
||||
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
s.jwtService.PrivateKey = privateKey
|
||||
s.jwtService.PublicKey = &privateKey.PublicKey
|
||||
}
|
||||
|
||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||
func (s *TestService) getCborPublicKey(base64PublicKey string) ([]byte, error) {
|
||||
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
|
||||
|
||||
Reference in New Issue
Block a user