diff --git a/auth/jwt.go b/auth/jwt.go index e9b5745f..52fce0c9 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -97,26 +97,37 @@ func invalidAuthResponse(ctx context.Context, w http.ResponseWriter) { } } +func (amw *jwtMiddleware) getTokenFromRequest(r *http.Request) (string, error) { + authorizationHeader := r.Header.Get("authorization") + if authorizationHeader == "" { + cookie, err := r.Cookie("garm_token") + if err != nil { + return "", fmt.Errorf("failed to get cookie: %w", err) + } + return cookie.Value, nil + } + + bearerToken := strings.Split(authorizationHeader, " ") + if len(bearerToken) != 2 { + return "", fmt.Errorf("invalid auth header") + } + return bearerToken[1], nil +} + // Middleware implements the middleware interface func (amw *jwtMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // nolint:golangci-lint,godox // TODO: Log error details when authentication fails ctx := r.Context() - authorizationHeader := r.Header.Get("authorization") - if authorizationHeader == "" { + authToken, err := amw.getTokenFromRequest(r) + if err != nil { + slog.ErrorContext(ctx, "failed to get auth token", "error", err) invalidAuthResponse(ctx, w) return } - - bearerToken := strings.Split(authorizationHeader, " ") - if len(bearerToken) != 2 { - invalidAuthResponse(ctx, w) - return - } - claims := &JWTClaims{} - token, err := jwt.ParseWithClaims(bearerToken[1], claims, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(authToken, claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("invalid signing method") }