diff --git a/README.md b/README.md index 76ac43d..848f184 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Additionally, what makes Pocket ID special is that it only supports [passkey](ht ## Setup > [!WARNING] -> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue. For example PKCE is not yet implemented. +> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue. ### Before you start diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 4aad22f..e9d2a9f 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -102,7 +102,7 @@ func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyR type ClientIdOrSecretNotProvidedError struct{} func (e *ClientIdOrSecretNotProvidedError) Error() string { - return "Client id and secret not provided" + return "Client id or secret not provided" } func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest } @@ -146,3 +146,17 @@ func (e *AccountEditNotAllowedError) Error() string { return "You are not allowed to edit your account" } func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden } + +type OidcInvalidCodeVerifierError struct{} + +func (e *OidcInvalidCodeVerifierError) Error() string { + return "Invalid code verifier" +} +func (e *OidcInvalidCodeVerifierError) HttpStatusCode() int { return http.StatusBadRequest } + +type OidcMissingCodeChallengeError struct{} + +func (e *OidcMissingCodeChallengeError) Error() string { + return "Missing code challenge" +} +func (e *OidcMissingCodeChallengeError) HttpStatusCode() int { return http.StatusBadRequest } diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 49934cb..e3950be 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -2,7 +2,6 @@ package controller import ( "github.com/gin-gonic/gin" - "github.com/stonith404/pocket-id/backend/internal/common" "github.com/stonith404/pocket-id/backend/internal/dto" "github.com/stonith404/pocket-id/backend/internal/middleware" "github.com/stonith404/pocket-id/backend/internal/service" @@ -80,7 +79,10 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { } func (oc *OidcController) createTokensHandler(c *gin.Context) { - var input dto.OidcIdTokenDto + // Disable cors for this endpoint + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + + var input dto.OidcCreateTokensDto if err := c.ShouldBind(&input); err != nil { c.Error(err) @@ -91,16 +93,11 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) { clientSecret := input.ClientSecret // Client id and secret can also be passed over the Authorization header - if clientID == "" || clientSecret == "" { - var ok bool - clientID, clientSecret, ok = c.Request.BasicAuth() - if !ok { - c.Error(&common.ClientIdOrSecretNotProvidedError{}) - return - } + if clientID == "" && clientSecret == "" { + clientID, clientSecret, _ = c.Request.BasicAuth() } - idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret) + idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier) if err != nil { c.Error(err) return diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 00c53bf..e2e8a97 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -9,19 +9,23 @@ type PublicOidcClientDto struct { type OidcClientDto struct { PublicOidcClientDto CallbackURLs []string `json:"callbackURLs"` + IsPublic bool `json:"isPublic"` CreatedBy UserDto `json:"createdBy"` } type OidcClientCreateDto struct { Name string `json:"name" binding:"required,max=50"` CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"` + IsPublic bool `json:"isPublic"` } type AuthorizeOidcClientRequestDto struct { - ClientID string `json:"clientID" binding:"required"` - Scope string `json:"scope" binding:"required"` - CallbackURL string `json:"callbackURL"` - Nonce string `json:"nonce"` + ClientID string `json:"clientID" binding:"required"` + Scope string `json:"scope" binding:"required"` + CallbackURL string `json:"callbackURL"` + Nonce string `json:"nonce"` + CodeChallenge string `json:"codeChallenge"` + CodeChallengeMethod string `json:"codeChallengeMethod"` } type AuthorizeOidcClientResponseDto struct { @@ -29,9 +33,10 @@ type AuthorizeOidcClientResponseDto struct { CallbackURL string `json:"callbackURL"` } -type OidcIdTokenDto struct { +type OidcCreateTokensDto struct { GrantType string `form:"grant_type" binding:"required"` Code string `form:"code" binding:"required"` ClientID string `form:"client_id"` ClientSecret string `form:"client_secret"` + CodeVerifier string `form:"code_verifier"` } diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go index 1cb09ce..91403f3 100644 --- a/backend/internal/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -1,11 +1,8 @@ package middleware import ( - "github.com/stonith404/pocket-id/backend/internal/common" - "time" - - "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/stonith404/pocket-id/backend/internal/common" ) type CorsMiddleware struct{} @@ -15,10 +12,22 @@ func NewCorsMiddleware() *CorsMiddleware { } func (m *CorsMiddleware) Add() gin.HandlerFunc { - return cors.New(cors.Config{ - AllowOrigins: []string{common.EnvConfig.AppURL}, - AllowMethods: []string{"*"}, - AllowHeaders: []string{"*"}, - MaxAge: 12 * time.Hour, - }) + return func(c *gin.Context) { + // Allow all origins for the token endpoint + if c.FullPath() == "/api/oidc/token" { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else { + c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL) + } + + c.Writer.Header().Set("Access-Control-Allow-Headers", "*") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } } diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 7b0dacc..bf0904f 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -20,10 +20,12 @@ type UserAuthorizedOidcClient struct { type OidcAuthorizationCode struct { Base - Code string - Scope string - Nonce string - ExpiresAt datatype.DateTime + Code string + Scope string + Nonce string + CodeChallenge *string + CodeChallengeMethodSha256 *bool + ExpiresAt datatype.DateTime UserID string User User @@ -39,6 +41,7 @@ type OidcClient struct { CallbackURLs CallbackURLs ImageType *string HasLogo bool `gorm:"-"` + IsPublic bool CreatedByID string CreatedBy User diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 9036406..c838fed 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -1,6 +1,8 @@ package service import ( + "crypto/sha256" + "encoding/base64" "errors" "fmt" "github.com/stonith404/pocket-id/backend/internal/common" @@ -39,16 +41,20 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, var userAuthorizedOIDCClient model.UserAuthorizedOidcClient s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID) + if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" { + return "", "", &common.OidcMissingCodeChallengeError{} + } + if userAuthorizedOIDCClient.Scope != input.Scope { return "", "", &common.OidcMissingAuthorizationError{} } - callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL) + callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL) if err != nil { return "", "", err } - code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce) + code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod) if err != nil { return "", "", err } @@ -64,7 +70,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto return "", "", err } - callbackURL, err := getCallbackURL(client, input.CallbackURL) + if client.IsPublic && input.CodeChallenge == "" { + return "", "", &common.OidcMissingCodeChallengeError{} + } + + callbackURL, err := s.getCallbackURL(client, input.CallbackURL) if err != nil { return "", "", err } @@ -83,7 +93,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto } } - code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce) + code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod) if err != nil { return "", "", err } @@ -93,31 +103,41 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto return code, callbackURL, nil } -func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) { +func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) { if grantType != "authorization_code" { return "", "", &common.OidcGrantTypeNotSupportedError{} } - if clientID == "" || clientSecret == "" { - return "", "", &common.OidcMissingClientCredentialsError{} - } - var client model.OidcClient if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { return "", "", err } - err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) - if err != nil { - return "", "", &common.OidcClientSecretInvalidError{} + // Verify the client secret if the client is not public + if !client.IsPublic { + if clientID == "" || clientSecret == "" { + return "", "", &common.OidcMissingClientCredentialsError{} + } + + err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) + if err != nil { + return "", "", &common.OidcClientSecretInvalidError{} + } } var authorizationCodeMetaData model.OidcAuthorizationCode - err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error + err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error if err != nil { return "", "", &common.OidcInvalidAuthorizationCodeError{} } + // If the client is public, the code verifier must match the code challenge + if client.IsPublic { + if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { + return "", "", &common.OidcInvalidCodeVerifierError{} + } + } + if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { return "", "", &common.OidcInvalidAuthorizationCodeError{} } @@ -186,6 +206,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt client.Name = input.Name client.CallbackURLs = input.CallbackURLs + client.IsPublic = input.IsPublic if err := s.db.Save(&client).Error; err != nil { return model.OidcClient{}, err @@ -358,19 +379,23 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma return claims, nil } -func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) { +func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(32) if err != nil { return "", err } + codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256" + oidcAuthorizationCode := model.OidcAuthorizationCode{ - ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)), - Code: randomString, - ClientID: clientID, - UserID: userID, - Scope: scope, - Nonce: nonce, + ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)), + Code: randomString, + ClientID: clientID, + UserID: userID, + Scope: scope, + Nonce: nonce, + CodeChallenge: &codeChallenge, + CodeChallengeMethodSha256: &codeChallengeMethodSha256, } if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil { @@ -380,7 +405,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc return randomString, nil } -func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) { +func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool { + if !codeChallengeMethodSha256 { + return codeVerifier == codeChallenge + } + + // Compute SHA-256 hash of the codeVerifier + h := sha256.New() + h.Write([]byte(codeVerifier)) + codeVerifierHash := h.Sum(nil) + + // Base64 URL encode the verifier hash + encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash) + + return encodedVerifierHash == codeChallenge +} + +func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) { if inputCallbackURL == "" { return client.CallbackURLs[0], nil } diff --git a/backend/migrations/20241115131129_pkce.down.sql b/backend/migrations/20241115131129_pkce.down.sql new file mode 100644 index 0000000..1ec7b0c --- /dev/null +++ b/backend/migrations/20241115131129_pkce.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge; +ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge_method_sha256; +ALTER TABLE oidc_clients DROP COLUMN is_public; \ No newline at end of file diff --git a/backend/migrations/20241115131129_pkce.up.sql b/backend/migrations/20241115131129_pkce.up.sql new file mode 100644 index 0000000..db4b58a --- /dev/null +++ b/backend/migrations/20241115131129_pkce.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge TEXT; +ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge_method_sha256 NUMERIC; +ALTER TABLE oidc_clients ADD COLUMN is_public BOOLEAN DEFAULT FALSE; \ No newline at end of file diff --git a/frontend/src/lib/services/oidc-service.ts b/frontend/src/lib/services/oidc-service.ts index 28d64f8..6bdce0c 100644 --- a/frontend/src/lib/services/oidc-service.ts +++ b/frontend/src/lib/services/oidc-service.ts @@ -3,23 +3,27 @@ import type { Paginated, PaginationRequest } from '$lib/types/pagination.type'; import APIService from './api-service'; class OidcService extends APIService { - async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string) { + async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) { const res = await this.api.post('/oidc/authorize', { scope, nonce, callbackURL, - clientId + clientId, + codeChallenge, + codeChallengeMethod }); return res.data as AuthorizeResponse; } - async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string) { + async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) { const res = await this.api.post('/oidc/authorize/new-client', { scope, nonce, callbackURL, - clientId + clientId, + codeChallenge, + codeChallengeMethod }); return res.data as AuthorizeResponse; diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index 459e973..8278acc 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -4,6 +4,7 @@ export type OidcClient = { logoURL: string; callbackURLs: [string, ...string[]]; hasLogo: boolean; + isPublic: boolean; }; export type OidcClientCreate = Omit; diff --git a/frontend/src/routes/authorize/+page.server.ts b/frontend/src/routes/authorize/+page.server.ts index ace06d5..ed81215 100644 --- a/frontend/src/routes/authorize/+page.server.ts +++ b/frontend/src/routes/authorize/+page.server.ts @@ -12,6 +12,8 @@ export const load: PageServerLoad = async ({ url, cookies }) => { nonce: url.searchParams.get('nonce') || undefined, state: url.searchParams.get('state')!, callbackURL: url.searchParams.get('redirect_uri')!, - client + client, + codeChallenge: url.searchParams.get('code_challenge')!, + codeChallengeMethod: url.searchParams.get('code_challenge_method')! }; }; diff --git a/frontend/src/routes/authorize/+page.svelte b/frontend/src/routes/authorize/+page.svelte index 4e5667f..06171c9 100644 --- a/frontend/src/routes/authorize/+page.svelte +++ b/frontend/src/routes/authorize/+page.svelte @@ -24,7 +24,7 @@ let authorizationRequired = false; export let data: PageData; - let { scope, nonce, client, state, callbackURL } = data; + let { scope, nonce, client, state, callbackURL, codeChallenge, codeChallengeMethod } = data; async function authorize() { isLoading = true; @@ -37,7 +37,7 @@ } await oidService - .authorize(client!.id, scope, callbackURL, nonce) + .authorize(client!.id, scope, callbackURL, nonce, codeChallenge, codeChallengeMethod) .then(async ({ code, callbackURL }) => { onSuccess(code, callbackURL); }); @@ -55,7 +55,7 @@ isLoading = true; try { await oidService - .authorizeNewClient(client!.id, scope, callbackURL, nonce) + .authorizeNewClient(client!.id, scope, callbackURL, nonce, codeChallenge, codeChallengeMethod) .then(async ({ code, callbackURL }) => { onSuccess(code, callbackURL); }); diff --git a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte index 1085fe0..6be93a0 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/[id]/+page.svelte @@ -26,7 +26,8 @@ 'OIDC Discovery URL': `https://${$page.url.hostname}/.well-known/openid-configuration`, 'Token URL': `https://${$page.url.hostname}/api/oidc/token`, 'Userinfo URL': `https://${$page.url.hostname}/api/oidc/userinfo`, - 'Certificate URL': `https://${$page.url.hostname}/.well-known/jwks.json` + 'Certificate URL': `https://${$page.url.hostname}/.well-known/jwks.json`, + PKCE: client.isPublic ? 'Enabled' : 'Disabled' }; async function updateClient(updatedClient: OidcClientCreateWithLogo) { @@ -34,6 +35,8 @@ const dataPromise = oidcService.updateClient(client.id, updatedClient); const imagePromise = oidcService.updateClientLogo(client, updatedClient.logo); + client.isPublic = updatedClient.isPublic; + await Promise.all([dataPromise, imagePromise]) .then(() => { toast.success('OIDC client updated successfully'); @@ -93,27 +96,29 @@ {client.id} -
- - {#if $clientSecretStore} - - - {$clientSecretStore} - - - {:else} - •••••••••••••••••••••••••••••••• - - {/if} -
+ {#if !client.isPublic} +
+ + {#if $clientSecretStore} + + + {$clientSecretStore} + + + {:else} + •••••••••••••••••••••••••••••••• + + {/if} +
+ {/if} {#if showAllDetails}
{#each Object.entries(setupDetails) as [key, value]} diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte index 286661d..f5dd115 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte @@ -2,6 +2,7 @@ import FileInput from '$lib/components/file-input.svelte'; import FormInput from '$lib/components/form-input.svelte'; import { Button } from '$lib/components/ui/button'; + import { Checkbox } from '$lib/components/ui/checkbox'; import Label from '$lib/components/ui/label/label.svelte'; import type { OidcClient, @@ -28,12 +29,14 @@ const client: OidcClientCreate = { name: existingClient?.name || '', - callbackURLs: existingClient?.callbackURLs || [""] + callbackURLs: existingClient?.callbackURLs || [''], + isPublic: existingClient?.isPublic || false }; const formSchema = z.object({ name: z.string().min(2).max(50), - callbackURLs: z.array(z.string().url()).nonempty() + callbackURLs: z.array(z.string().url()).nonempty(), + isPublic: z.boolean() }); type FormSchema = typeof formSchema; @@ -71,15 +74,27 @@
-
+
+
+ +
+ +

+ Public clients do not have a client secret and use PKCE instead. Enable this if your + client is a SPA or mobile app. +

+
+
-
+
{#if logoDataURL}