// edge-connect-mcp/oauth/token_validator.go
//
// 239 lines
// 5.5 KiB
// Go

package oauth
import (
"context"
"crypto/rsa"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// JWTValidator validates RSA-signed JWTs, resolving verification keys from a
// JWKS endpoint and caching them in memory. Per the original note it is meant
// to implement TokenValidator, which is declared outside this block.
type JWTValidator struct {
	// expectedAudience is matched against the token's audience entries,
	// ignoring a trailing slash on either side.
	expectedAudience string
	// expectedIssuer must equal the token's issuer exactly.
	expectedIssuer string
	// jwksURL is the endpoint the key set is downloaded from.
	jwksURL string
	// cacheDuration is how long after lastFetch a cached key is served
	// without refreshing the key set.
	cacheDuration time.Duration

	// mu guards publicKeys and lastFetch.
	mu sync.RWMutex
	// publicKeys maps a key ID (kid) to its RSA public key.
	publicKeys map[string]*rsa.PublicKey
	// lastFetch is when the key set was last refreshed successfully.
	lastFetch time.Time
}
// NewJWTValidator returns a validator that expects the given audience and
// issuer and loads its verification keys from jwksURL. The validator starts
// with an empty key cache and a cache lifetime of 15 minutes.
func NewJWTValidator(audience, issuer, jwksURL string) *JWTValidator {
	const keyCacheLifetime = 15 * time.Minute

	validator := new(JWTValidator)
	validator.expectedAudience = audience
	validator.expectedIssuer = issuer
	validator.jwksURL = jwksURL
	validator.publicKeys = map[string]*rsa.PublicKey{}
	validator.cacheDuration = keyCacheLifetime
	return validator
}
// ValidateToken verifies the signature of tokenString and checks its claims.
//
// The signature is verified while parsing: the key callback accepts only
// tokens whose signing method is *jwt.SigningMethodRSA and resolves the
// verification key through getPublicKey using the "kid" header. After parsing,
// the audience (compared with a trailing slash ignored on both sides), the
// issuer and the expiration time are checked explicitly. The subject,
// issued-at time, space-delimited "scope" claim and optional "client_id"
// claim are then copied into the returned TokenClaims.
//
// A token without an expiration claim is rejected. A token without an
// issued-at claim is accepted and leaves IssuedAt at the zero time.
//
// On any failure the returned claims are nil and the error describes the
// check that failed.
func (v *JWTValidator) ValidateToken(ctx context.Context, tokenString string) (*TokenClaims, error) {
	// jwt.Parse calls the callback to obtain the verification key and then
	// verifies the signature with it.
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only RSA signatures are accepted; anything else is refused before
		// a key is looked up.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}

		kid, ok := token.Header["kid"].(string)
		if !ok {
			return nil, fmt.Errorf("missing or invalid kid in token header")
		}

		publicKey, err := v.getPublicKey(ctx, kid)
		if err != nil {
			return nil, fmt.Errorf("failed to get public key: %w", err)
		}
		return publicKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("token is invalid")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("invalid token claims")
	}

	// Validate audience (RFC 8707 - CRITICAL for security).
	aud, err := claims.GetAudience()
	if err != nil {
		return nil, fmt.Errorf("invalid audience claim: %w", err)
	}
	// A trailing slash is ignored on both the expected and the presented
	// audience.
	expectedAudience := strings.TrimSuffix(v.expectedAudience, "/")
	audienceValid := false
	for _, a := range aud {
		if strings.TrimSuffix(a, "/") == expectedAudience {
			audienceValid = true
			break
		}
	}
	if !audienceValid {
		// The mismatch details travel in the error instead of being printed
		// to stdout, so the caller decides where they are reported.
		return nil, fmt.Errorf("token audience does not match expected audience: expected %s, got %v", expectedAudience, aud)
	}

	// Validate issuer. A malformed claim and a mismatch are reported
	// separately so the underlying parse error is not lost.
	iss, err := claims.GetIssuer()
	if err != nil {
		return nil, fmt.Errorf("invalid issuer claim: %w", err)
	}
	if iss != v.expectedIssuer {
		return nil, fmt.Errorf("invalid issuer: expected %s, got %s", v.expectedIssuer, iss)
	}

	// Validate expiration. GetExpirationTime returns a nil date with a nil
	// error when the claim is absent, so nil must be checked before use.
	exp, err := claims.GetExpirationTime()
	if err != nil {
		return nil, fmt.Errorf("invalid expiration claim: %w", err)
	}
	if exp == nil {
		return nil, fmt.Errorf("token is missing expiration claim")
	}
	if exp.Before(time.Now()) {
		return nil, fmt.Errorf("token is expired")
	}

	// Issued-at is informational here; like expiration it is nil when the
	// claim is absent, in which case IssuedAt stays at the zero time.
	iat, err := claims.GetIssuedAt()
	if err != nil {
		return nil, fmt.Errorf("invalid issued at claim: %w", err)
	}
	var issuedAt time.Time
	if iat != nil {
		issuedAt = iat.Time
	}

	sub, err := claims.GetSubject()
	if err != nil {
		return nil, fmt.Errorf("invalid subject claim: %w", err)
	}

	// Scopes stays a non-nil empty slice when the token carries no scope.
	// Fields tolerates repeated or surrounding whitespace without producing
	// empty scope entries.
	scopes := []string{}
	if scopeClaim, ok := claims["scope"].(string); ok {
		if fields := strings.Fields(scopeClaim); len(fields) > 0 {
			scopes = fields
		}
	}

	tokenClaims := &TokenClaims{
		Subject:   sub,
		Audience:  aud,
		Issuer:    iss,
		ExpiresAt: exp.Time,
		IssuedAt:  issuedAt,
		Scopes:    scopes,
	}
	if clientID, ok := claims["client_id"].(string); ok {
		tokenClaims.ClientID = clientID
	}
	return tokenClaims, nil
}
// GetJWKS refreshes the key cache from the JWKS endpoint and returns the
// cached RSA keys converted back into a JWKS. A refresh failure is returned
// unchanged. The keys come from a map, so their order in the result is not
// fixed.
func (v *JWTValidator) GetJWKS(ctx context.Context) (*JWKS, error) {
	err := v.refreshJWKS(ctx)
	if err != nil {
		return nil, err
	}

	v.mu.RLock()
	defer v.mu.RUnlock()

	// Convert every cached key into its JWK form.
	exported := make([]JWK, 0, len(v.publicKeys))
	for keyID, rsaKey := range v.publicKeys {
		exported = append(exported, RSAPublicKeyToJWK(rsaKey, keyID))
	}
	return &JWKS{Keys: exported}, nil
}
// getPublicKey returns the RSA public key registered under kid.
//
// A cached key is returned directly while the last refresh is younger than
// cacheDuration. Otherwise the key set is refreshed and the cache is consulted
// once more. A refresh error is returned unchanged, and a kid that is still
// unknown after the refresh yields a "not found" error.
func (v *JWTValidator) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
	v.mu.RLock()
	cached, found := v.publicKeys[kid]
	fetchedAt := v.lastFetch
	v.mu.RUnlock()

	// Serve from the cache only while it is still considered fresh.
	cacheFresh := time.Since(fetchedAt) < v.cacheDuration
	if found && cacheFresh {
		return cached, nil
	}

	if refreshErr := v.refreshJWKS(ctx); refreshErr != nil {
		return nil, refreshErr
	}

	// The refresh may have added the requested key.
	v.mu.RLock()
	refreshed, found := v.publicKeys[kid]
	v.mu.RUnlock()

	if !found {
		return nil, fmt.Errorf("public key with kid %s not found", kid)
	}
	return refreshed, nil
}
// refreshJWKS downloads the JWKS document from v.jwksURL and replaces the
// cached key set with the RSA keys it contains.
//
// The cache is replaced rather than merged, so a key that no longer appears
// in the document stops being trusted after the refresh. Entries that are not
// RSA keys, have no key ID, or cannot be converted are skipped. On any error
// (request, non-200 status, decode) the existing cache and lastFetch are left
// untouched.
func (v *JWTValidator) refreshJWKS(ctx context.Context) error {
	// The request goes through http.DefaultClient, which sets no timeout, so
	// bound it here in case the caller's context carries no deadline. An
	// earlier deadline on the caller's context still wins.
	const fetchTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(ctx, fetchTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.jwksURL, nil)
	if err != nil {
		return fmt.Errorf("building JWKS request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("fetching JWKS: %w", err)
	}
	defer func() {
		// Nothing useful can be done with a close error on a body that was
		// only read.
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode)
	}

	var jwks JWKS
	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
		return fmt.Errorf("decoding JWKS response: %w", err)
	}

	// Build the replacement set before taking the lock so readers are only
	// blocked for the swap itself.
	keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
	for _, key := range jwks.Keys {
		if key.KeyType != "RSA" || key.KeyID == "" {
			continue
		}
		publicKey, err := JWKToRSAPublicKey(key)
		if err != nil {
			// Best effort: one malformed entry must not discard the rest
			// of the key set.
			continue
		}
		keys[key.KeyID] = publicKey
	}

	v.mu.Lock()
	v.publicKeys = keys
	v.lastFetch = time.Now()
	v.mu.Unlock()
	return nil
}