feat: add user info endpoint to support more oidc clients

This commit is contained in:
Elias Schneider
2024-08-19 18:48:18 +02:00
parent 601f6c488a
commit fdc1921f5d
6 changed files with 155 additions and 66 deletions

View File

@@ -10,14 +10,16 @@ import (
"github.com/stonith404/pocket-id/backend/internal/utils"
"net/http"
"strconv"
"strings"
)
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService) {
oc := &OidcController{OidcService: oidcService}
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
group.POST("/oidc/token", oc.createIDTokenHandler)
group.GET("/oidc/userinfo", oc.userInfoHandler)
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
@@ -33,7 +35,8 @@ func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
}
type OidcController struct {
OidcService *service.OidcService
oidcService *service.OidcService
jwtService *service.JwtService
}
func (oc *OidcController) authorizeHandler(c *gin.Context) {
@@ -43,7 +46,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
return
}
code, err := oc.OidcService.Authorize(parsedBody, c.GetString("userID"))
code, err := oc.oidcService.Authorize(parsedBody, c.GetString("userID"))
if err != nil {
if errors.Is(err, common.ErrOidcMissingAuthorization) {
utils.HandlerError(c, http.StatusForbidden, err.Error())
@@ -63,7 +66,7 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
return
}
code, err := oc.OidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
code, err := oc.oidcService.AuthorizeNewClient(parsedBody, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -80,7 +83,20 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
return
}
idToken, err := oc.OidcService.CreateIDToken(body)
clientID := body.ClientID
clientSecret := body.ClientSecret
// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
return
}
}
idToken, accessToken, err := oc.oidcService.CreateTokens(body.Code, body.GrantType, clientID, clientSecret)
if err != nil {
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
@@ -93,12 +109,30 @@ func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
return
}
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"})
}
func (oc *OidcController) userInfoHandler(c *gin.Context) {
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
if err != nil {
utils.HandlerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
return
}
userID := jwtClaims.Subject
clientId := jwtClaims.Audience[0]
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
if err != nil {
utils.UnknownHandlerError(c, err)
return
}
c.JSON(http.StatusOK, claims)
}
func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.OidcService.GetClient(clientId)
client, err := oc.oidcService.GetClient(clientId)
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -112,7 +146,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
searchTerm := c.Query("search")
clients, pagination, err := oc.OidcService.ListClients(searchTerm, page, pageSize)
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -131,7 +165,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
return
}
client, err := oc.OidcService.CreateClient(input, c.GetString("userID"))
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -141,7 +175,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
}
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.OidcService.DeleteClient(c.Param("id"))
err := oc.oidcService.DeleteClient(c.Param("id"))
if err != nil {
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
return
@@ -157,7 +191,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
return
}
client, err := oc.OidcService.UpdateClient(c.Param("id"), input)
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -167,7 +201,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
}
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.OidcService.CreateClientSecret(c.Param("id"))
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -177,7 +211,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
}
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.OidcService.GetClientLogo(c.Param("id"))
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return
@@ -194,7 +228,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
return
}
err = oc.OidcService.UpdateClientLogo(c.Param("id"), file)
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
if err != nil {
if errors.Is(err, common.ErrFileTypeNotSupported) {
utils.HandlerError(c, http.StatusBadRequest, err.Error())
@@ -208,7 +242,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
}
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.OidcService.DeleteClientLogo(c.Param("id"))
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
if err != nil {
utils.UnknownHandlerError(c, err)
return