feat: add PKCE support

This commit is contained in:
Elias Schneider
2024-11-15 15:00:25 +01:00
parent 760c8e83bb
commit 3613ac261c
15 changed files with 188 additions and 86 deletions

View File

@@ -1,6 +1,8 @@
package service
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
@@ -39,16 +41,20 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", &common.OidcMissingAuthorizationError{}
}
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -64,7 +70,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return "", "", err
}
callbackURL, err := getCallbackURL(client, input.CallbackURL)
if client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
@@ -83,7 +93,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
}
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -93,31 +103,41 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return code, callbackURL, nil
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public, the code verifier must match the code challenge
if client.IsPublic {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
@@ -186,6 +206,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.IsPublic = input.IsPublic
if err := s.db.Save(&client).Error; err != nil {
return model.OidcClient{}, err
@@ -358,19 +379,23 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256"
oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
@@ -380,7 +405,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}
func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool {
if !codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}
// Compute SHA-256 hash of the codeVerifier
h := sha256.New()
h.Write([]byte(codeVerifier))
codeVerifierHash := h.Sum(nil)
// Base64 URL encode the verifier hash
encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash)
return encodedVerifierHash == codeChallenge
}
func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}