From 282ff82b0c7e2414b3528c8ca325758245b8ae61 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Fri, 11 Oct 2024 20:42:31 +0200 Subject: [PATCH] fix: add key id to JWK --- backend/internal/service/jwt_service.go | 51 +++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index 8abff14..e450d14 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -3,6 +3,7 @@ package service import ( "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/pem" @@ -51,6 +52,7 @@ type AccessTokenJWTClaims struct { } type JWK struct { + Kid string `json:"kid"` Kty string `json:"kty"` Use string `json:"use"` Alg string `json:"alg"` @@ -98,7 +100,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { }, IsAdmin: user.IsAdmin, } + + kid, err := s.generateKeyID(s.publicKey) + if err != nil { + return "", errors.New("failed to generate key ID: " + err.Error()) + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) + token.Header["kid"] = kid + return token.SignedString(s.privateKey) } @@ -137,9 +147,17 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID claims["nonce"] = nonce } + kid, err := s.generateKeyID(s.publicKey) + if err != nil { + return "", errors.New("failed to generate key ID: " + err.Error()) + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + return token.SignedString(s.privateKey) } + func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { claim := jwt.RegisteredClaims{ Subject: user.ID, @@ -148,7 +166,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) Audience: jwt.ClaimStrings{clientID}, Issuer: common.EnvConfig.AppURL, } + + kid, err := s.generateKeyID(s.publicKey) + if err != nil { + return "", errors.New("failed to generate key ID: " + err.Error()) + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) + token.Header["kid"] = kid + return token.SignedString(s.privateKey) } @@ -174,7 +200,13 @@ func (s *JwtService) GetJWK() (JWK, error) { return JWK{}, errors.New("public key is not initialized") } + kid, err := s.generateKeyID(s.publicKey) + if err != nil { + return JWK{}, err + } + jwk := JWK{ + Kid: kid, Kty: "RSA", Use: "sig", Alg: "RS256", @@ -185,6 +217,25 @@ func (s *JwtService) GetJWK() (JWK, error) { return jwk, nil } +// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key. +func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) { + pubASN1, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", errors.New("failed to marshal public key: " + err.Error()) + } + + // Compute SHA-256 hash of the public key + hash := sha256.New() + hash.Write(pubASN1) + hashed := hash.Sum(nil) + + // Truncate the hash to the first 8 bytes for a shorter Key ID + shortHash := hashed[:8] + + // Return Base64 encoded truncated hash as Key ID + return base64.RawURLEncoding.EncodeToString(shortHash), nil +} + // generateKeys generates a new RSA key pair and saves them to the specified paths. func (s *JwtService) generateKeys() error { if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {