feat: add PKCE support

This commit is contained in:
Elias Schneider
2024-11-15 15:00:25 +01:00
parent 760c8e83bb
commit 3613ac261c
15 changed files with 188 additions and 86 deletions

View File

@@ -102,7 +102,7 @@ func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyR
type ClientIdOrSecretNotProvidedError struct{}
func (e *ClientIdOrSecretNotProvidedError) Error() string {
return "Client id and secret not provided"
return "Client id or secret not provided"
}
func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest }
@@ -146,3 +146,17 @@ func (e *AccountEditNotAllowedError) Error() string {
return "You are not allowed to edit your account"
}
func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden }
type OidcInvalidCodeVerifierError struct{}
func (e *OidcInvalidCodeVerifierError) Error() string {
return "Invalid code verifier"
}
func (e *OidcInvalidCodeVerifierError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcMissingCodeChallengeError struct{}
func (e *OidcMissingCodeChallengeError) Error() string {
return "Missing code challenge"
}
func (e *OidcMissingCodeChallengeError) HttpStatusCode() int { return http.StatusBadRequest }

View File

@@ -2,7 +2,6 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
@@ -80,7 +79,10 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
}
func (oc *OidcController) createTokensHandler(c *gin.Context) {
var input dto.OidcIdTokenDto
// Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
var input dto.OidcCreateTokensDto
if err := c.ShouldBind(&input); err != nil {
c.Error(err)
@@ -91,16 +93,11 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
c.Error(&common.ClientIdOrSecretNotProvidedError{})
return
}
if clientID == "" && clientSecret == "" {
clientID, clientSecret, _ = c.Request.BasicAuth()
}
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier)
if err != nil {
c.Error(err)
return

View File

@@ -9,19 +9,23 @@ type PublicOidcClientDto struct {
type OidcClientDto struct {
PublicOidcClientDto
CallbackURLs []string `json:"callbackURLs"`
IsPublic bool `json:"isPublic"`
CreatedBy UserDto `json:"createdBy"`
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
IsPublic bool `json:"isPublic"`
}
type AuthorizeOidcClientRequestDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
type AuthorizeOidcClientResponseDto struct {
@@ -29,9 +33,10 @@ type AuthorizeOidcClientResponseDto struct {
CallbackURL string `json:"callbackURL"`
}
type OidcIdTokenDto struct {
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
}

View File

@@ -1,11 +1,8 @@
package middleware
import (
"github.com/stonith404/pocket-id/backend/internal/common"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
)
type CorsMiddleware struct{}
@@ -15,10 +12,22 @@ func NewCorsMiddleware() *CorsMiddleware {
}
func (m *CorsMiddleware) Add() gin.HandlerFunc {
return cors.New(cors.Config{
AllowOrigins: []string{common.EnvConfig.AppURL},
AllowMethods: []string{"*"},
AllowHeaders: []string{"*"},
MaxAge: 12 * time.Hour,
})
return func(c *gin.Context) {
// Allow all origins for the token endpoint
if c.FullPath() == "/api/oidc/token" {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -20,10 +20,12 @@ type UserAuthorizedOidcClient struct {
type OidcAuthorizationCode struct {
Base
Code string
Scope string
Nonce string
ExpiresAt datatype.DateTime
Code string
Scope string
Nonce string
CodeChallenge *string
CodeChallengeMethodSha256 *bool
ExpiresAt datatype.DateTime
UserID string
User User
@@ -39,6 +41,7 @@ type OidcClient struct {
CallbackURLs CallbackURLs
ImageType *string
HasLogo bool `gorm:"-"`
IsPublic bool
CreatedByID string
CreatedBy User

View File

@@ -1,6 +1,8 @@
package service
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
@@ -39,16 +41,20 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", &common.OidcMissingAuthorizationError{}
}
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -64,7 +70,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return "", "", err
}
callbackURL, err := getCallbackURL(client, input.CallbackURL)
if client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
@@ -83,7 +93,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
}
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -93,31 +103,41 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return code, callbackURL, nil
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public, the code verifier must match the code challenge
if client.IsPublic {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
@@ -186,6 +206,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.IsPublic = input.IsPublic
if err := s.db.Save(&client).Error; err != nil {
return model.OidcClient{}, err
@@ -358,19 +379,23 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256"
oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
@@ -380,7 +405,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}
func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool {
if !codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}
// Compute SHA-256 hash of the codeVerifier
h := sha256.New()
h.Write([]byte(codeVerifier))
codeVerifierHash := h.Sum(nil)
// Base64 URL encode the verifier hash
encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash)
return encodedVerifierHash == codeChallenge
}
func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge;
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge_method_sha256;
ALTER TABLE oidc_clients DROP COLUMN is_public;

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge TEXT;
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge_method_sha256 NUMERIC;
ALTER TABLE oidc_clients ADD COLUMN is_public BOOLEAN DEFAULT FALSE;