239 lines
5.5 KiB
Go
239 lines
5.5 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// JWTValidator implements TokenValidator using JWT.
//
// It verifies RSA-signed tokens with public keys downloaded from a JWKS
// endpoint and caches those keys by their kid. Construct it with
// NewJWTValidator so that publicKeys and cacheDuration are initialized.
type JWTValidator struct {
	// expectedAudience is compared against the token's aud claim; a single
	// trailing slash on either side is ignored during comparison.
	expectedAudience string

	// expectedIssuer must equal the token's iss claim exactly.
	expectedIssuer string

	// jwksURL is the endpoint the signing keys are fetched from.
	jwksURL string

	// publicKeys maps a JWK kid to its RSA public key. Guarded by mu.
	publicKeys map[string]*rsa.PublicKey

	// lastFetch is when the JWKS was last fetched and decoded successfully.
	// Guarded by mu.
	lastFetch time.Time

	// cacheDuration is how long a cached key is served before the JWKS is
	// fetched again.
	cacheDuration time.Duration

	// mu guards publicKeys and lastFetch.
	mu sync.RWMutex
}
|
|
|
|
// NewJWTValidator creates a new JWT token validator
|
|
func NewJWTValidator(audience, issuer, jwksURL string) *JWTValidator {
|
|
return &JWTValidator{
|
|
expectedAudience: audience,
|
|
expectedIssuer: issuer,
|
|
jwksURL: jwksURL,
|
|
publicKeys: make(map[string]*rsa.PublicKey),
|
|
cacheDuration: 15 * time.Minute,
|
|
}
|
|
}
|
|
|
|
// ValidateToken validates a JWT token
|
|
func (v *JWTValidator) ValidateToken(ctx context.Context, tokenString string) (*TokenClaims, error) {
|
|
// Parse JWT without verification first to get key ID
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
// Verify signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
|
|
// Get key ID from header
|
|
kid, ok := token.Header["kid"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing or invalid kid in token header")
|
|
}
|
|
|
|
// Fetch public key
|
|
publicKey, err := v.getPublicKey(ctx, kid)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
|
|
return publicKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, fmt.Errorf("token is invalid")
|
|
}
|
|
|
|
// Extract claims
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid token claims")
|
|
}
|
|
|
|
// Validate audience (RFC 8707 - CRITICAL for security)
|
|
aud, err := claims.GetAudience()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid audience claim: %w", err)
|
|
}
|
|
|
|
// Normalize expected audience (remove trailing slash)
|
|
expectedAudience := strings.TrimSuffix(v.expectedAudience, "/")
|
|
|
|
audienceValid := false
|
|
for _, a := range aud {
|
|
// Normalize token audience (remove trailing slash)
|
|
normalizedAud := strings.TrimSuffix(a, "/")
|
|
if normalizedAud == expectedAudience {
|
|
audienceValid = true
|
|
break
|
|
}
|
|
}
|
|
if !audienceValid {
|
|
fmt.Printf("DEBUG: Audience validation failed. Expected: %s, Got: %v\n", expectedAudience, aud)
|
|
return nil, fmt.Errorf("token audience does not match expected audience")
|
|
}
|
|
|
|
// Validate issuer
|
|
iss, err := claims.GetIssuer()
|
|
if err != nil || iss != v.expectedIssuer {
|
|
return nil, fmt.Errorf("invalid issuer: expected %s, got %s", v.expectedIssuer, iss)
|
|
}
|
|
|
|
// Validate expiration
|
|
exp, err := claims.GetExpirationTime()
|
|
if err != nil || exp.Before(time.Now()) {
|
|
return nil, fmt.Errorf("token is expired")
|
|
}
|
|
|
|
// Validate issued at
|
|
iat, err := claims.GetIssuedAt()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid issued at claim: %w", err)
|
|
}
|
|
|
|
// Extract subject
|
|
sub, err := claims.GetSubject()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid subject claim: %w", err)
|
|
}
|
|
|
|
// Extract scopes
|
|
scopes := []string{}
|
|
if scopeClaim, ok := claims["scope"].(string); ok {
|
|
if scopeClaim != "" {
|
|
scopes = strings.Split(scopeClaim, " ")
|
|
}
|
|
}
|
|
|
|
// Build token claims
|
|
tokenClaims := &TokenClaims{
|
|
Subject: sub,
|
|
Audience: aud,
|
|
Issuer: iss,
|
|
ExpiresAt: exp.Time,
|
|
IssuedAt: iat.Time,
|
|
Scopes: scopes,
|
|
}
|
|
|
|
if clientID, ok := claims["client_id"].(string); ok {
|
|
tokenClaims.ClientID = clientID
|
|
}
|
|
|
|
return tokenClaims, nil
|
|
}
|
|
|
|
// GetJWKS fetches and returns the JWKS
|
|
func (v *JWTValidator) GetJWKS(ctx context.Context) (*JWKS, error) {
|
|
if err := v.refreshJWKS(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
v.mu.RLock()
|
|
defer v.mu.RUnlock()
|
|
|
|
// Build JWKS from cached keys
|
|
jwks := &JWKS{
|
|
Keys: make([]JWK, 0, len(v.publicKeys)),
|
|
}
|
|
|
|
for kid, pubKey := range v.publicKeys {
|
|
jwk := RSAPublicKeyToJWK(pubKey, kid)
|
|
jwks.Keys = append(jwks.Keys, jwk)
|
|
}
|
|
|
|
return jwks, nil
|
|
}
|
|
|
|
// getPublicKey fetches a public key by ID from JWKS endpoint
|
|
func (v *JWTValidator) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
|
// Check cache
|
|
v.mu.RLock()
|
|
key, exists := v.publicKeys[kid]
|
|
lastFetch := v.lastFetch
|
|
v.mu.RUnlock()
|
|
|
|
if exists && time.Since(lastFetch) < v.cacheDuration {
|
|
return key, nil
|
|
}
|
|
|
|
// Fetch JWKS
|
|
if err := v.refreshJWKS(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check cache again after refresh
|
|
v.mu.RLock()
|
|
key, exists = v.publicKeys[kid]
|
|
v.mu.RUnlock()
|
|
|
|
if exists {
|
|
return key, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("public key with kid %s not found", kid)
|
|
}
|
|
|
|
// refreshJWKS fetches the latest JWKS from the authorization server
|
|
func (v *JWTValidator) refreshJWKS(ctx context.Context) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.jwksURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
var jwks JWKS
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Parse and cache keys
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
for _, key := range jwks.Keys {
|
|
if key.KeyType == "RSA" && key.KeyID != "" {
|
|
publicKey, err := JWKToRSAPublicKey(key)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
v.publicKeys[key.KeyID] = publicKey
|
|
}
|
|
}
|
|
|
|
v.lastFetch = time.Now()
|
|
return nil
|
|
}
|