feat: add user info endpoint to support more oidc clients

This commit is contained in:
Elias Schneider
2024-08-19 18:48:18 +02:00
parent 601f6c488a
commit fdc1921f5d
6 changed files with 155 additions and 66 deletions

View File

@@ -18,7 +18,6 @@ import (
"path/filepath"
"slices"
"strconv"
"strings"
"time"
)
@@ -54,7 +53,6 @@ type AccessTokenJWTClaims struct {
type JWK struct {
Kty string `json:"kty"`
Use string `json:"use"`
Kid string `json:"kid"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
@@ -89,37 +87,6 @@ func (s *JwtService) loadOrGenerateKeys() error {
return nil
}
func (s *JwtService) GenerateIDToken(user model.User, clientID string, scope string, nonce string) (string, error) {
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"email": user.Email,
"preferred_username": user.Username,
}
claims := jwt.MapClaims{
"sub": user.ID,
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
}
if nonce != "" {
claims["nonce"] = nonce
}
if strings.Contains(scope, "profile") {
for k, v := range profileClaims {
claims[k] = v
}
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
claim := AccessTokenJWTClaims{
@@ -154,6 +121,53 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaim
return claims, nil
}
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
claims := jwt.MapClaims{
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
"iss": common.EnvConfig.AppURL,
}
for k, v := range userClaims {
claims[k] = v
}
if nonce != "" {
claims["nonce"] = nonce
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.publicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
if s.publicKey == nil {
@@ -163,7 +177,6 @@ func (s *JwtService) GetJWK() (JWK, error) {
jwk := JWK{
Kty: "RSA",
Use: "sig",
Kid: "1",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),

View File

@@ -10,6 +10,7 @@ import (
"gorm.io/gorm"
"mime/multipart"
"os"
"strings"
"time"
)
@@ -54,46 +55,50 @@ func (s *OidcService) AuthorizeNewClient(req model.AuthorizeNewClientDto, userID
return s.createAuthorizationCode(req.ClientID, userID, req.Scope, req.Nonce)
}
func (s *OidcService) CreateIDToken(req model.OidcIdTokenDto) (string, error) {
if req.GrantType != "authorization_code" {
return "", common.ErrOidcGrantTypeNotSupported
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", common.ErrOidcGrantTypeNotSupported
}
clientID := req.ClientID
clientSecret := req.ClientSecret
if clientID == "" || clientSecret == "" {
return "", common.ErrOidcMissingClientCredentials
return "", "", common.ErrOidcMissingClientCredentials
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", err
return "", "", err
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", common.ErrOidcClientSecretInvalid
return "", "", common.ErrOidcClientSecretInvalid
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", req.Code).Error
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", common.ErrOidcInvalidAuthorizationCode
return "", "", common.ErrOidcInvalidAuthorizationCode
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
return "", common.ErrOidcInvalidAuthorizationCode
return "", "", common.ErrOidcInvalidAuthorizationCode
}
idToken, err := s.jwtService.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
if err != nil {
return "", err
return "", "", err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
s.db.Delete(&authorizationCodeMetaData)
return idToken, nil
return idToken, accessToken, nil
}
func (s *OidcService) GetClient(clientID string) (*model.OidcClient, error) {
@@ -259,6 +264,41 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
return nil
}
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
return nil, err
}
user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope
claims := map[string]interface{}{
"sub": user.ID,
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
}
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"preferred_username": user.Username,
}
if strings.Contains(scope, "profile") {
for k, v := range profileClaims {
claims[k] = v
}
}
if strings.Contains(scope, "email") {
claims["email"] = user.Email
}
return claims, nil
}
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {