mirror of
https://github.com/nikdoof/pocket-id.git
synced 2025-12-14 15:22:18 +00:00
fix: add key id to JWK
This commit is contained in:
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
@@ -51,6 +52,7 @@ type AccessTokenJWTClaims struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type JWK struct {
|
type JWK struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
Use string `json:"use"`
|
Use string `json:"use"`
|
||||||
Alg string `json:"alg"`
|
Alg string `json:"alg"`
|
||||||
@@ -98,7 +100,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
|||||||
},
|
},
|
||||||
IsAdmin: user.IsAdmin,
|
IsAdmin: user.IsAdmin,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kid, err := s.generateKeyID(s.publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||||
|
token.Header["kid"] = kid
|
||||||
|
|
||||||
return token.SignedString(s.privateKey)
|
return token.SignedString(s.privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,9 +147,17 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
|
|||||||
claims["nonce"] = nonce
|
claims["nonce"] = nonce
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kid, err := s.generateKeyID(s.publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = kid
|
||||||
|
|
||||||
return token.SignedString(s.privateKey)
|
return token.SignedString(s.privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||||
claim := jwt.RegisteredClaims{
|
claim := jwt.RegisteredClaims{
|
||||||
Subject: user.ID,
|
Subject: user.ID,
|
||||||
@@ -148,7 +166,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
|||||||
Audience: jwt.ClaimStrings{clientID},
|
Audience: jwt.ClaimStrings{clientID},
|
||||||
Issuer: common.EnvConfig.AppURL,
|
Issuer: common.EnvConfig.AppURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kid, err := s.generateKeyID(s.publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||||
|
token.Header["kid"] = kid
|
||||||
|
|
||||||
return token.SignedString(s.privateKey)
|
return token.SignedString(s.privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,7 +200,13 @@ func (s *JwtService) GetJWK() (JWK, error) {
|
|||||||
return JWK{}, errors.New("public key is not initialized")
|
return JWK{}, errors.New("public key is not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kid, err := s.generateKeyID(s.publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return JWK{}, err
|
||||||
|
}
|
||||||
|
|
||||||
jwk := JWK{
|
jwk := JWK{
|
||||||
|
Kid: kid,
|
||||||
Kty: "RSA",
|
Kty: "RSA",
|
||||||
Use: "sig",
|
Use: "sig",
|
||||||
Alg: "RS256",
|
Alg: "RS256",
|
||||||
@@ -185,6 +217,25 @@ func (s *JwtService) GetJWK() (JWK, error) {
|
|||||||
return jwk, nil
|
return jwk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
|
||||||
|
func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
|
||||||
|
pubASN1, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("failed to marshal public key: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute SHA-256 hash of the public key
|
||||||
|
hash := sha256.New()
|
||||||
|
hash.Write(pubASN1)
|
||||||
|
hashed := hash.Sum(nil)
|
||||||
|
|
||||||
|
// Truncate the hash to the first 8 bytes for a shorter Key ID
|
||||||
|
shortHash := hashed[:8]
|
||||||
|
|
||||||
|
// Return Base64 encoded truncated hash as Key ID
|
||||||
|
return base64.RawURLEncoding.EncodeToString(shortHash), nil
|
||||||
|
}
|
||||||
|
|
||||||
// generateKeys generates a new RSA key pair and saves them to the specified paths.
|
// generateKeys generates a new RSA key pair and saves them to the specified paths.
|
||||||
func (s *JwtService) generateKeys() error {
|
func (s *JwtService) generateKeys() error {
|
||||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user