diff --git a/backend/internal/common/errors.go b/backend/internal/common/errors.go index 966a21d..3dc657a 100644 --- a/backend/internal/common/errors.go +++ b/backend/internal/common/errors.go @@ -6,7 +6,6 @@ var ( ErrUsernameTaken = errors.New("username is already taken") ErrEmailTaken = errors.New("email is already taken") ErrSetupAlreadyCompleted = errors.New("setup already completed") - ErrInvalidBody = errors.New("invalid request body") ErrTokenInvalidOrExpired = errors.New("token is invalid or expired") ErrOidcMissingAuthorization = errors.New("missing authorization") ErrOidcGrantTypeNotSupported = errors.New("grant type not supported") diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index f2eaf3b..89bdaf2 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -40,39 +40,55 @@ type OidcController struct { } func (oc *OidcController) authorizeHandler(c *gin.Context) { - var input dto.AuthorizeOidcClientDto + var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { utils.ControllerError(c, err) return } - code, err := oc.oidcService.Authorize(input, c.GetString("userID")) + code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID")) if err != nil { if errors.Is(err, common.ErrOidcMissingAuthorization) { utils.CustomControllerError(c, http.StatusForbidden, err.Error()) + } else if errors.Is(err, common.ErrOidcInvalidCallbackURL) { + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) } else { utils.ControllerError(c, err) } return } - c.JSON(http.StatusOK, gin.H{"code": code}) + response := dto.AuthorizeOidcClientResponseDto{ + Code: code, + CallbackURL: callbackURL, + } + + c.JSON(http.StatusOK, response) } func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) { - var input dto.AuthorizeOidcClientDto + var input dto.AuthorizeOidcClientRequestDto if err := c.ShouldBindJSON(&input); err != nil { utils.ControllerError(c, err) return } - code, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID")) + code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID")) if err != nil { - utils.ControllerError(c, err) + if errors.Is(err, common.ErrOidcInvalidCallbackURL) { + utils.CustomControllerError(c, http.StatusBadRequest, err.Error()) + } else { + utils.ControllerError(c, err) + } return } - c.JSON(http.StatusOK, gin.H{"code": code}) + response := dto.AuthorizeOidcClientResponseDto{ + Code: code, + CallbackURL: callbackURL, + } + + c.JSON(http.StatusOK, response) } func (oc *OidcController) createIDTokenHandler(c *gin.Context) { diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 3f01f62..4729c77 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -17,10 +17,16 @@ type OidcClientCreateDto struct { CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"` } -type AuthorizeOidcClientDto struct { - ClientID string `json:"clientID" binding:"required"` - Scope string `json:"scope" binding:"required"` - Nonce string `json:"nonce"` +type AuthorizeOidcClientRequestDto struct { + ClientID string `json:"clientID" binding:"required"` + Scope string `json:"scope" binding:"required"` + CallbackURL string `json:"callbackURL"` + Nonce string `json:"nonce"` +} + +type AuthorizeOidcClientResponseDto struct { + Code string `json:"code"` + CallbackURL string `json:"callbackURL"` } type OidcIdTokenDto struct { diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 8af9756..4d914a8 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -52,17 +52,14 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { type CallbackURLs []string -func (s *CallbackURLs) Scan(value interface{}) error { - switch v := value.(type) { - case []byte: - return json.Unmarshal(v, s) - case string: - return json.Unmarshal([]byte(v), s) - default: - return errors.New("type assertion to []byte or string failed") +func (cu *CallbackURLs) Scan(value interface{}) error { + if v, ok := value.([]byte); ok { + return json.Unmarshal(v, cu) + } else { + return errors.New("type assertion to []byte failed") } } -func (atl CallbackURLs) Value() (driver.Value, error) { - return json.Marshal(atl) +func (cu CallbackURLs) Value() (driver.Value, error) { + return json.Marshal(cu) } diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index f34677d..76605b1 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -11,6 +11,7 @@ import ( "gorm.io/gorm" "mime/multipart" "os" + "slices" "strings" "time" ) @@ -27,33 +28,50 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService { } } -func (s *OidcService) Authorize(req dto.AuthorizeOidcClientDto, userID string) (string, error) { +func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) { var userAuthorizedOIDCClient model.UserAuthorizedOidcClient - s.db.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", req.ClientID, userID) + s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID) - if userAuthorizedOIDCClient.Scope != req.Scope { - return "", common.ErrOidcMissingAuthorization + if userAuthorizedOIDCClient.Scope != input.Scope { + return "", "", common.ErrOidcMissingAuthorization } - return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) + callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL) + if err != nil { + return "", "", err + } + + code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce) + return code, callbackURL, err } -func (s *OidcService) AuthorizeNewClient(req dto.AuthorizeOidcClientDto, userID string) (string, error) { +func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", input.ClientID).Error; err != nil { + return "", "", err + } + + callbackURL, err := getCallbackURL(client, input.CallbackURL) + if err != nil { + return "", "", err + } + userAuthorizedClient := model.UserAuthorizedOidcClient{ UserID: userID, - ClientID: req.ClientID, - Scope: req.Scope, + ClientID: input.ClientID, + Scope: input.Scope, } if err := s.db.Create(&userAuthorizedClient).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - err = s.db.Model(&userAuthorizedClient).Update("scope", req.Scope).Error + err = s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error } else { - return "", err + return "", "", err } } - return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce) + code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce) + return code, callbackURL, err } func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) { @@ -321,3 +339,14 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc return randomString, nil } + +func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) { + if inputCallbackURL == "" { + return client.CallbackURLs[0], nil + } + if slices.Contains(client.CallbackURLs, inputCallbackURL) { + return inputCallbackURL, nil + } + + return "", common.ErrOidcInvalidCallbackURL +} diff --git a/backend/internal/utils/controller_error_util.go b/backend/internal/utils/controller_error_util.go index db9eb3d..dea05d9 100644 --- a/backend/internal/utils/controller_error_util.go +++ b/backend/internal/utils/controller_error_util.go @@ -58,7 +58,7 @@ func handleValidationError(validationErrors validator.ValidationErrors) string { default: errorMessage = fmt.Sprintf("%s is invalid", fieldName) } - + errorMessages = append(errorMessages, errorMessage) } diff --git a/backend/migrations/20240731203656_init.up.sql b/backend/migrations/20240731203656_init.up.sql index 8d4baaa..54a1501 100644 --- a/backend/migrations/20240731203656_init.up.sql +++ b/backend/migrations/20240731203656_init.up.sql @@ -57,7 +57,7 @@ CREATE TABLE webauthn_credentials credential_id TEXT NOT NULL UNIQUE, public_key BLOB NOT NULL, attestation_type TEXT NOT NULL, - transport TEXT NOT NULL, + transport BLOB NOT NULL, user_id TEXT REFERENCES users ); diff --git a/backend/migrations/20240820205521_multiple_callback_urls.down.sql b/backend/migrations/20240820205521_multiple_callback_urls.down.sql new file mode 100644 index 0000000..48417bf --- /dev/null +++ b/backend/migrations/20240820205521_multiple_callback_urls.down.sql @@ -0,0 +1,25 @@ +create table oidc_clients +( + id TEXT not null + primary key, + created_at DATETIME, + name TEXT, + secret TEXT, + callback_url TEXT, + image_type TEXT, + created_by_id TEXT + references users +); + +insert into oidc_clients(id, created_at, name, secret, callback_url, image_type, created_by_id) +select + id, + created_at, + name, + secret, + json_extract(callback_urls, '$[0]'), + image_type, + created_by_id +from oidc_clients_dg_tmp; + +drop table oidc_clients_dg_tmp; \ No newline at end of file diff --git a/backend/migrations/20240820205521_multiple_callback_urls.up.sql b/backend/migrations/20240820205521_multiple_callback_urls.up.sql new file mode 100644 index 0000000..b50689e --- /dev/null +++ b/backend/migrations/20240820205521_multiple_callback_urls.up.sql @@ -0,0 +1,27 @@ +create table oidc_clients_dg_tmp +( + id TEXT not null + primary key, + created_at DATETIME, + name TEXT, + secret TEXT, + callback_urls BLOB, + image_type TEXT, + created_by_id TEXT + references users +); + +insert into oidc_clients_dg_tmp(id, created_at, name, secret, callback_urls, image_type, created_by_id) +select id, + created_at, + name, + secret, + CAST(json_group_array(json_quote(callback_url)) AS BLOB), + image_type, + created_by_id +from oidc_clients; + +drop table oidc_clients; + +alter table oidc_clients_dg_tmp + rename to oidc_clients; \ No newline at end of file diff --git a/frontend/src/lib/components/form-input.svelte b/frontend/src/lib/components/form-input.svelte index 5424e01..fc2f636 100644 --- a/frontend/src/lib/components/form-input.svelte +++ b/frontend/src/lib/components/form-input.svelte @@ -2,15 +2,17 @@ import { Label } from '$lib/components/ui/label'; import type { FormInput } from '$lib/utils/form-util'; import type { Snippet } from 'svelte'; + import type { HTMLAttributes } from 'svelte/elements'; import { Input } from './ui/input'; let { input = $bindable(), label, description, - children - }: { - input: FormInput; + children, + ...restProps + }: HTMLAttributes & { + input?: FormInput; label: string; description?: string; children?: Snippet; @@ -19,19 +21,19 @@ const id = label.toLowerCase().replace(/ /g, '-'); -
+
{#if description} -

{description}

+

{description}

{/if}
{#if children} {@render children()} - {:else} + {:else if input} {/if} - {#if input.error} -

{input.error}

+ {#if input?.error} +

{input.error}

{/if}
diff --git a/frontend/src/lib/services/oidc-service.ts b/frontend/src/lib/services/oidc-service.ts index 4d4ff26..cfe03f2 100644 --- a/frontend/src/lib/services/oidc-service.ts +++ b/frontend/src/lib/services/oidc-service.ts @@ -1,26 +1,28 @@ -import type { OidcClient, OidcClientCreate } from '$lib/types/oidc.type'; +import type { AuthorizeResponse, OidcClient, OidcClientCreate } from '$lib/types/oidc.type'; import type { Paginated, PaginationRequest } from '$lib/types/pagination.type'; import APIService from './api-service'; class OidcService extends APIService { - async authorize(clientId: string, scope: string, nonce?: string) { + async authorize(clientId: string, scope: string, callbackURL : string, nonce?: string) { const res = await this.api.post('/oidc/authorize', { scope, nonce, + callbackURL, clientId }); - return res.data.code as string; + return res.data as AuthorizeResponse; } - async authorizeNewClient(clientId: string, scope: string, nonce?: string) { + async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string) { const res = await this.api.post('/oidc/authorize/new-client', { scope, nonce, + callbackURL, clientId }); - return res.data.code as string; + return res.data as AuthorizeResponse; } async listClients(search?: string, pagination?: PaginationRequest) { diff --git a/frontend/src/lib/types/oidc.type.ts b/frontend/src/lib/types/oidc.type.ts index 30b6cca..459e973 100644 --- a/frontend/src/lib/types/oidc.type.ts +++ b/frontend/src/lib/types/oidc.type.ts @@ -2,7 +2,7 @@ export type OidcClient = { id: string; name: string; logoURL: string; - callbackURL: string; + callbackURLs: [string, ...string[]]; hasLogo: boolean; }; @@ -11,3 +11,8 @@ export type OidcClientCreate = Omit; export type OidcClientCreateWithLogo = OidcClientCreate & { logo: File | null; }; + +export type AuthorizeResponse = { + code: string; + callbackURL: string; +}; diff --git a/frontend/src/routes/authorize/+page.server.ts b/frontend/src/routes/authorize/+page.server.ts index 5783528..ace06d5 100644 --- a/frontend/src/routes/authorize/+page.server.ts +++ b/frontend/src/routes/authorize/+page.server.ts @@ -11,6 +11,7 @@ export const load: PageServerLoad = async ({ url, cookies }) => { scope: url.searchParams.get('scope')!, nonce: url.searchParams.get('nonce') || undefined, state: url.searchParams.get('state')!, + callbackURL: url.searchParams.get('redirect_uri')!, client }; }; diff --git a/frontend/src/routes/authorize/+page.svelte b/frontend/src/routes/authorize/+page.svelte index 76f7d50..57213fd 100644 --- a/frontend/src/routes/authorize/+page.svelte +++ b/frontend/src/routes/authorize/+page.svelte @@ -24,7 +24,7 @@ let authorizationRequired = false; export let data: PageData; - let { scope, nonce, client, state } = data; + let { scope, nonce, client, state, callbackURL } = data; async function authorize() { isLoading = true; @@ -36,9 +36,11 @@ await webauthnService.finishLogin(authResponse); } - await oidService.authorize(client!.id, scope, nonce).then(async (code) => { - onSuccess(code); - }); + await oidService + .authorize(client!.id, scope, callbackURL, nonce) + .then(async ({ code, callbackURL }) => { + onSuccess(code, callbackURL); + }); } catch (e) { if (e instanceof AxiosError && e.response?.status === 403) { authorizationRequired = true; @@ -52,19 +54,21 @@ async function authorizeNewClient() { isLoading = true; try { - await oidService.authorizeNewClient(client!.id, scope, nonce).then(async (code) => { - onSuccess(code); - }); + await oidService + .authorizeNewClient(client!.id, scope, callbackURL, nonce) + .then(async ({ code, callbackURL }) => { + onSuccess(code, callbackURL); + }); } catch (e) { errorMessage = getWebauthnErrorMessage(e); isLoading = false; } } - function onSuccess(code: string) { + function onSuccess(code: string, callbackURL: string) { success = true; setTimeout(() => { - window.location.href = `${client!.callbackURL}?code=${code}&state=${state}`; + window.location.href = `${callbackURL}?code=${code}&state=${state}`; }, 1000); } diff --git a/frontend/src/routes/settings/admin/application-configuration/application-configuration-form.svelte b/frontend/src/routes/settings/admin/application-configuration/application-configuration-form.svelte index 2791b07..81ab900 100644 --- a/frontend/src/routes/settings/admin/application-configuration/application-configuration-form.svelte +++ b/frontend/src/routes/settings/admin/application-configuration/application-configuration-form.svelte @@ -47,7 +47,6 @@
- + import FormInput from '$lib/components/form-input.svelte'; + import { Button } from '$lib/components/ui/button'; + import { Input } from '$lib/components/ui/input'; + import { LucideMinus, LucidePlus } from 'lucide-svelte'; + import type { Snippet } from 'svelte'; + import type { HTMLAttributes } from 'svelte/elements'; + + let { + callbackURLs = $bindable(), + error = $bindable(null), + ...restProps + }: HTMLAttributes & { + callbackURLs: string[]; + error?: string | null; + children?: Snippet; + } = $props(); + + const limit = 5; + + +
+ +
+ {#each callbackURLs as _, i} +
+ + {#if callbackURLs.length > 1} + + {/if} +
+ {/each} +
+
+ {#if error} +

{error}

+ {/if} + {#if callbackURLs.length < limit} + + {/if} +
diff --git a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte index 65e929e..286661d 100644 --- a/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte +++ b/frontend/src/routes/settings/admin/oidc-clients/oidc-client-form.svelte @@ -10,6 +10,7 @@ } from '$lib/types/oidc.type'; import { createForm } from '$lib/utils/form-util'; import { z } from 'zod'; + import OidcCallbackUrlInput from './oidc-callback-url-input.svelte'; let { callback, @@ -27,12 +28,12 @@ const client: OidcClientCreate = { name: existingClient?.name || '', - callbackURL: existingClient?.callbackURL || '' + callbackURLs: existingClient?.callbackURLs || [""] }; const formSchema = z.object({ name: z.string().min(2).max(50), - callbackURL: z.string().url() + callbackURLs: z.array(z.string().url()).nonempty() }); type FormSchema = typeof formSchema; @@ -70,32 +71,40 @@ -
- - -
- -
- {#if logoDataURL} -
- {`${$inputs.name.value} -
- {/if} -
- - {#if logoDataURL} - - {/if} +
+ + +
+
+ +
+ {#if logoDataURL} +
+ {`${$inputs.name.value}
+ {/if} +
+ + {#if logoDataURL} + + {/if}
diff --git a/frontend/src/routes/settings/admin/users/user-form.svelte b/frontend/src/routes/settings/admin/users/user-form.svelte index 3dce07c..be3fa19 100644 --- a/frontend/src/routes/settings/admin/users/user-form.svelte +++ b/frontend/src/routes/settings/admin/users/user-form.svelte @@ -26,9 +26,13 @@ }; const formSchema = z.object({ - firstName: z.string().min(2).max(50), - lastName: z.string().min(2).max(50), - username: z.string().min(2).max(50), + firstName: z.string().min(2).max(30), + lastName: z.string().min(2).max(30), + username: z + .string() + .min(2) + .max(30) + .regex(/^[a-z0-9_]+$/, 'Only lowercase letters, numbers, and underscores are allowed'), email: z.string().email(), isAdmin: z.boolean() }); @@ -66,10 +70,10 @@
-
diff --git a/frontend/tests/data.ts b/frontend/tests/data.ts index 2437d48..ba78262 100644 --- a/frontend/tests/data.ts +++ b/frontend/tests/data.ts @@ -23,17 +23,18 @@ export const users = { export const oidcClients = { nextcloud: { - id: "3654a746-35d4-4321-ac61-0bdcff2b4055", + id: '3654a746-35d4-4321-ac61-0bdcff2b4055', name: 'Nextcloud', callbackUrl: 'http://nextcloud/auth/callback' }, - immich: { - id: "606c7782-f2b1-49e5-8ea9-26eb1b06d018", - name: 'Immich', - callbackUrl: 'http://immich/auth/callback' - }, + immich: { + id: '606c7782-f2b1-49e5-8ea9-26eb1b06d018', + name: 'Immich', + callbackUrl: 'http://immich/auth/callback' + }, pingvinShare: { name: 'Pingvin Share', - callbackUrl: 'http://pingvin.share/auth/callback' + callbackUrl: 'http://pingvin.share/auth/callback', + secondCallbackUrl: 'http://pingvin.share/auth/callback2' } }; diff --git a/frontend/tests/oidc-client-settings.spec.ts b/frontend/tests/oidc-client-settings.spec.ts index a08053e..2e35c21 100644 --- a/frontend/tests/oidc-client-settings.spec.ts +++ b/frontend/tests/oidc-client-settings.spec.ts @@ -10,7 +10,11 @@ test('Create OIDC client', async ({ page }) => { await page.getByRole('button', { name: 'Add OIDC Client' }).click(); await page.getByLabel('Name').fill(oidcClient.name); - await page.getByLabel('Callback URL').fill(oidcClient.callbackUrl); + + await page.getByTestId('callback-url-1').fill(oidcClient.callbackUrl); + await page.getByRole('button', { name: 'Add another' }).click(); + await page.getByTestId('callback-url-2').fill(oidcClient.secondCallbackUrl!); + await page.getByLabel('logo').setInputFiles('tests/assets/pingvin-share-logo.png'); await page.getByRole('button', { name: 'Save' }).click(); @@ -20,7 +24,8 @@ test('Create OIDC client', async ({ page }) => { expect(clientId?.length).toBe(36); expect((await page.getByTestId('client-secret').textContent())?.length).toBe(32); await expect(page.getByLabel('Name')).toHaveValue(oidcClient.name); - await expect(page.getByLabel('Callback URL')).toHaveValue(oidcClient.callbackUrl); + await expect(page.getByTestId('callback-url-1')).toHaveValue(oidcClient.callbackUrl); + await expect(page.getByTestId('callback-url-2')).toHaveValue(oidcClient.secondCallbackUrl!); await expect(page.getByRole('img', { name: `${oidcClient.name} logo` })).toBeVisible(); await page.request .get(`/api/oidc/clients/${clientId}/logo`) @@ -32,7 +37,7 @@ test('Edit OIDC client', async ({ page }) => { await page.goto(`/settings/admin/oidc-clients/${oidcClient.id}`); await page.getByLabel('Name').fill('Nextcloud updated'); - await page.getByLabel('Callback URL').fill('http://nextcloud-updated/auth/callback'); + await page.getByTestId('callback-url-1').fill('http://nextcloud-updated/auth/callback'); await page.getByLabel('logo').setInputFiles('tests/assets/nextcloud-logo.png'); await page.getByRole('button', { name: 'Save' }).click();