mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-23 14:29:23 +00:00
feat: add end session endpoint (#232)
This commit is contained in:
@@ -51,7 +51,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
|
||||
}
|
||||
|
||||
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
|
||||
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
|
||||
callbackURL, err := s.getCallbackURL(client.CallbackURLs, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -228,11 +228,12 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
|
||||
|
||||
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.IsPublic || input.PkceEnabled,
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
LogoutCallbackURLs: input.LogoutCallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.IsPublic || input.PkceEnabled,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&client).Error; err != nil {
|
||||
@@ -250,6 +251,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
@@ -460,6 +462,46 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
|
||||
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) {
|
||||
// If no ID token hint is provided, return an error
|
||||
if input.IdTokenHint == "" {
|
||||
return "", &common.TokenInvalidError{}
|
||||
}
|
||||
|
||||
// If the ID token hint is provided, verify the ID token
|
||||
claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint)
|
||||
if err != nil {
|
||||
return "", &common.TokenInvalidError{}
|
||||
}
|
||||
|
||||
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
|
||||
if input.ClientId != "" && claims.Audience[0] != input.ClientId {
|
||||
return "", &common.OidcClientIdNotMatchingError{}
|
||||
}
|
||||
|
||||
clientId := claims.Audience[0]
|
||||
|
||||
// Check if the user has authorized the client before
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil {
|
||||
return "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
// If the client has no logout callback URLs, return an error
|
||||
if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) == 0 {
|
||||
return "", &common.OidcNoCallbackURLError{}
|
||||
}
|
||||
|
||||
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client.LogoutCallbackURLs, input.PostLogoutRedirectUri)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return callbackURL, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
@@ -506,12 +548,12 @@ func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, c
|
||||
return encodedVerifierHash == codeChallenge
|
||||
}
|
||||
|
||||
func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
|
||||
func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (callbackURL string, err error) {
|
||||
if inputCallbackURL == "" {
|
||||
return client.CallbackURLs[0], nil
|
||||
return urls[0], nil
|
||||
}
|
||||
|
||||
for _, callbackPattern := range client.CallbackURLs {
|
||||
for _, callbackPattern := range urls {
|
||||
regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
|
||||
matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user