feat: add PKCE support

This commit is contained in:
Elias Schneider
2024-11-15 15:00:25 +01:00
parent 760c8e83bb
commit 3613ac261c
15 changed files with 188 additions and 86 deletions

View File

@@ -11,7 +11,7 @@ Additionally, what makes Pocket ID special is that it only supports [passkey](ht
## Setup
> [!WARNING]
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue. For example PKCE is not yet implemented.
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue.
### Before you start

View File

@@ -102,7 +102,7 @@ func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyR
type ClientIdOrSecretNotProvidedError struct{}
func (e *ClientIdOrSecretNotProvidedError) Error() string {
return "Client id and secret not provided"
return "Client id or secret not provided"
}
func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest }
@@ -146,3 +146,17 @@ func (e *AccountEditNotAllowedError) Error() string {
return "You are not allowed to edit your account"
}
func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden }
type OidcInvalidCodeVerifierError struct{}
func (e *OidcInvalidCodeVerifierError) Error() string {
return "Invalid code verifier"
}
func (e *OidcInvalidCodeVerifierError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcMissingCodeChallengeError struct{}
func (e *OidcMissingCodeChallengeError) Error() string {
return "Missing code challenge"
}
func (e *OidcMissingCodeChallengeError) HttpStatusCode() int { return http.StatusBadRequest }

View File

@@ -2,7 +2,6 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
@@ -80,7 +79,10 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
}
func (oc *OidcController) createTokensHandler(c *gin.Context) {
var input dto.OidcIdTokenDto
// Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
var input dto.OidcCreateTokensDto
if err := c.ShouldBind(&input); err != nil {
c.Error(err)
@@ -91,16 +93,11 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
c.Error(&common.ClientIdOrSecretNotProvidedError{})
return
}
if clientID == "" && clientSecret == "" {
clientID, clientSecret, _ = c.Request.BasicAuth()
}
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier)
if err != nil {
c.Error(err)
return

View File

@@ -9,12 +9,14 @@ type PublicOidcClientDto struct {
type OidcClientDto struct {
PublicOidcClientDto
CallbackURLs []string `json:"callbackURLs"`
IsPublic bool `json:"isPublic"`
CreatedBy UserDto `json:"createdBy"`
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
IsPublic bool `json:"isPublic"`
}
type AuthorizeOidcClientRequestDto struct {
@@ -22,6 +24,8 @@ type AuthorizeOidcClientRequestDto struct {
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
type AuthorizeOidcClientResponseDto struct {
@@ -29,9 +33,10 @@ type AuthorizeOidcClientResponseDto struct {
CallbackURL string `json:"callbackURL"`
}
type OidcIdTokenDto struct {
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
}

View File

@@ -1,11 +1,8 @@
package middleware
import (
"github.com/stonith404/pocket-id/backend/internal/common"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
)
type CorsMiddleware struct{}
@@ -15,10 +12,22 @@ func NewCorsMiddleware() *CorsMiddleware {
}
func (m *CorsMiddleware) Add() gin.HandlerFunc {
return cors.New(cors.Config{
AllowOrigins: []string{common.EnvConfig.AppURL},
AllowMethods: []string{"*"},
AllowHeaders: []string{"*"},
MaxAge: 12 * time.Hour,
})
return func(c *gin.Context) {
// Allow all origins for the token endpoint
if c.FullPath() == "/api/oidc/token" {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -23,6 +23,8 @@ type OidcAuthorizationCode struct {
Code string
Scope string
Nonce string
CodeChallenge *string
CodeChallengeMethodSha256 *bool
ExpiresAt datatype.DateTime
UserID string
@@ -39,6 +41,7 @@ type OidcClient struct {
CallbackURLs CallbackURLs
ImageType *string
HasLogo bool `gorm:"-"`
IsPublic bool
CreatedByID string
CreatedBy User

View File

@@ -1,6 +1,8 @@
package service
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
@@ -39,16 +41,20 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", &common.OidcMissingAuthorizationError{}
}
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -64,7 +70,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return "", "", err
}
callbackURL, err := getCallbackURL(client, input.CallbackURL)
if client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
@@ -83,7 +93,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
}
}
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
@@ -93,31 +103,41 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return code, callbackURL, nil
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public, the code verifier must match the code challenge
if client.IsPublic {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
@@ -186,6 +206,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.IsPublic = input.IsPublic
if err := s.db.Save(&client).Error; err != nil {
return model.OidcClient{}, err
@@ -358,12 +379,14 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}
codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256"
oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
@@ -371,6 +394,8 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
UserID: userID,
Scope: scope,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
@@ -380,7 +405,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}
func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool {
if !codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}
// Compute SHA-256 hash of the codeVerifier
h := sha256.New()
h.Write([]byte(codeVerifier))
codeVerifierHash := h.Sum(nil)
// Base64 URL encode the verifier hash
encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash)
return encodedVerifierHash == codeChallenge
}
func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge;
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge_method_sha256;
ALTER TABLE oidc_clients DROP COLUMN is_public;

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge TEXT;
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge_method_sha256 NUMERIC;
ALTER TABLE oidc_clients ADD COLUMN is_public BOOLEAN DEFAULT FALSE;

View File

@@ -3,23 +3,27 @@ import type { Paginated, PaginationRequest } from '$lib/types/pagination.type';
import APIService from './api-service';
class OidcService extends APIService {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string) {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) {
const res = await this.api.post('/oidc/authorize', {
scope,
nonce,
callbackURL,
clientId
clientId,
codeChallenge,
codeChallengeMethod
});
return res.data as AuthorizeResponse;
}
async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string) {
async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) {
const res = await this.api.post('/oidc/authorize/new-client', {
scope,
nonce,
callbackURL,
clientId
clientId,
codeChallenge,
codeChallengeMethod
});
return res.data as AuthorizeResponse;

View File

@@ -4,6 +4,7 @@ export type OidcClient = {
logoURL: string;
callbackURLs: [string, ...string[]];
hasLogo: boolean;
isPublic: boolean;
};
export type OidcClientCreate = Omit<OidcClient, 'id' | 'logoURL' | 'hasLogo'>;

View File

@@ -12,6 +12,8 @@ export const load: PageServerLoad = async ({ url, cookies }) => {
nonce: url.searchParams.get('nonce') || undefined,
state: url.searchParams.get('state')!,
callbackURL: url.searchParams.get('redirect_uri')!,
client
client,
codeChallenge: url.searchParams.get('code_challenge')!,
codeChallengeMethod: url.searchParams.get('code_challenge_method')!
};
};

View File

@@ -24,7 +24,7 @@
let authorizationRequired = false;
export let data: PageData;
let { scope, nonce, client, state, callbackURL } = data;
let { scope, nonce, client, state, callbackURL, codeChallenge, codeChallengeMethod } = data;
async function authorize() {
isLoading = true;
@@ -37,7 +37,7 @@
}
await oidService
.authorize(client!.id, scope, callbackURL, nonce)
.authorize(client!.id, scope, callbackURL, nonce, codeChallenge, codeChallengeMethod)
.then(async ({ code, callbackURL }) => {
onSuccess(code, callbackURL);
});
@@ -55,7 +55,7 @@
isLoading = true;
try {
await oidService
.authorizeNewClient(client!.id, scope, callbackURL, nonce)
.authorizeNewClient(client!.id, scope, callbackURL, nonce, codeChallenge, codeChallengeMethod)
.then(async ({ code, callbackURL }) => {
onSuccess(code, callbackURL);
});

View File

@@ -26,7 +26,8 @@
'OIDC Discovery URL': `https://${$page.url.hostname}/.well-known/openid-configuration`,
'Token URL': `https://${$page.url.hostname}/api/oidc/token`,
'Userinfo URL': `https://${$page.url.hostname}/api/oidc/userinfo`,
'Certificate URL': `https://${$page.url.hostname}/.well-known/jwks.json`
'Certificate URL': `https://${$page.url.hostname}/.well-known/jwks.json`,
PKCE: client.isPublic ? 'Enabled' : 'Disabled'
};
async function updateClient(updatedClient: OidcClientCreateWithLogo) {
@@ -34,6 +35,8 @@
const dataPromise = oidcService.updateClient(client.id, updatedClient);
const imagePromise = oidcService.updateClientLogo(client, updatedClient.logo);
client.isPublic = updatedClient.isPublic;
await Promise.all([dataPromise, imagePromise])
.then(() => {
toast.success('OIDC client updated successfully');
@@ -93,6 +96,7 @@
<span class="text-muted-foreground text-sm" data-testid="client-id"> {client.id}</span>
</CopyToClipboard>
</div>
{#if !client.isPublic}
<div class="mb-2 mt-1 flex items-center">
<Label class="w-44">Client secret</Label>
{#if $clientSecretStore}
@@ -114,6 +118,7 @@
>
{/if}
</div>
{/if}
{#if showAllDetails}
<div transition:slide>
{#each Object.entries(setupDetails) as [key, value]}

View File

@@ -2,6 +2,7 @@
import FileInput from '$lib/components/file-input.svelte';
import FormInput from '$lib/components/form-input.svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import type {
OidcClient,
@@ -28,12 +29,14 @@
const client: OidcClientCreate = {
name: existingClient?.name || '',
callbackURLs: existingClient?.callbackURLs || [""]
callbackURLs: existingClient?.callbackURLs || [''],
isPublic: existingClient?.isPublic || false
};
const formSchema = z.object({
name: z.string().min(2).max(50),
callbackURLs: z.array(z.string().url()).nonempty()
callbackURLs: z.array(z.string().url()).nonempty(),
isPublic: z.boolean()
});
type FormSchema = typeof formSchema;
@@ -71,15 +74,27 @@
</script>
<form onsubmit={onSubmit}>
<div class="flex flex-col gap-3 sm:flex-row">
<div class="grid grid-cols-2 gap-3 sm:flex-row">
<FormInput label="Name" class="w-full" bind:input={$inputs.name} />
<OidcCallbackUrlInput
class="w-full"
bind:callbackURLs={$inputs.callbackURLs.value}
bind:error={$inputs.callbackURLs.error}
/>
<div class="items-top flex space-x-2">
<Checkbox id="admin-privileges" bind:checked={$inputs.isPublic.value} />
<div class="grid gap-1.5 leading-none">
<Label for="admin-privileges" class="mb-0 text-sm font-medium leading-none">
Public Client
</Label>
<p class="text-muted-foreground text-[0.8rem]">
Public clients do not have a client secret and use PKCE instead. Enable this if your
client is a SPA or mobile app.
</p>
</div>
<div class="mt-3">
</div>
</div>
<div class="mt-8">
<Label for="logo">Logo</Label>
<div class="mt-2 flex items-end gap-3">
{#if logoDataURL}