garm/util/github/scalesets/message_sessions.go
Gabriel Adrian Samfira c601f88cf7 Add more tests
Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2025-05-03 22:29:41 +00:00

291 lines
8.1 KiB
Go

// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/big"
"net/http"
"net/url"
"strconv"
"sync"
"time"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
garmUtil "github.com/cloudbase/garm/util"
)
const maxCapacityHeader = "X-ScaleSetMaxCapacity"
type MessageSession struct {
ssCli *ScaleSetClient
session *params.RunnerScaleSetSession
ctx context.Context
done chan struct{}
closed bool
lastErr error
mux sync.Mutex
}
func (m *MessageSession) Close() error {
m.mux.Lock()
defer m.mux.Unlock()
if m.closed {
return nil
}
close(m.done)
m.closed = true
return nil
}
func (m *MessageSession) MessageQueueAccessToken() string {
return m.session.MessageQueueAccessToken
}
func (m *MessageSession) LastError() error {
return m.lastErr
}
func (m *MessageSession) loop() {
slog.DebugContext(m.ctx, "starting message session refresh loop", "session_id", m.session.SessionID.String())
timer := time.NewTicker(1 * time.Minute)
defer timer.Stop()
defer m.Close()
if m.closed {
slog.DebugContext(m.ctx, "message session refresh loop closed")
return
}
for {
select {
case <-m.ctx.Done():
slog.DebugContext(m.ctx, "message session refresh loop context done")
return
case <-m.done:
slog.DebugContext(m.ctx, "message session refresh loop done")
return
case <-timer.C:
if err := m.maybeRefreshToken(m.ctx); err != nil {
// We endlessly retry. If it's a transient error, it should eventually
// work, if it's credentials issues, users can update them.
slog.With(slog.Any("error", err)).ErrorContext(m.ctx, "failed to refresh message queue token")
m.lastErr = err
continue
}
m.lastErr = nil
}
}
}
func (m *MessageSession) SessionsRelativeURL() (string, error) {
if m.session == nil {
return "", fmt.Errorf("session is nil")
}
if m.session.RunnerScaleSet == nil {
return "", fmt.Errorf("runner scale set is nil")
}
relativePath := fmt.Sprintf("%s/%d/sessions/%s", scaleSetEndpoint, m.session.RunnerScaleSet.ID, m.session.SessionID.String())
return relativePath, nil
}
func (m *MessageSession) Refresh(ctx context.Context) error {
slog.DebugContext(ctx, "refreshing message session token", "session_id", m.session.SessionID.String())
m.mux.Lock()
defer m.mux.Unlock()
relPath, err := m.SessionsRelativeURL()
if err != nil {
return fmt.Errorf("failed to get session URL: %w", err)
}
req, err := m.ssCli.newActionsRequest(ctx, http.MethodPatch, relPath, nil)
if err != nil {
return fmt.Errorf("failed to create message delete request: %w", err)
}
resp, err := m.ssCli.Do(req)
if err != nil {
return fmt.Errorf("failed to delete message session: %w", err)
}
defer resp.Body.Close()
var refreshedSession params.RunnerScaleSetSession
if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
slog.DebugContext(ctx, "refreshed message session token")
m.session = &refreshedSession
return nil
}
func (m *MessageSession) maybeRefreshToken(ctx context.Context) error {
if m.session == nil {
return fmt.Errorf("session is nil")
}
expiresAt, err := m.session.ExiresAt()
if err != nil {
return fmt.Errorf("failed to get expires at: %w", err)
}
// add some jitter (30 second interval)
randInt, err := rand.Int(rand.Reader, big.NewInt(30))
if err != nil {
return fmt.Errorf("failed to get a random number")
}
expiresIn := time.Duration(randInt.Int64())*time.Second + 10*time.Minute
slog.DebugContext(ctx, "checking if message session token needs refresh", "expires_at", expiresAt)
if m.session.ExpiresIn(expiresIn) {
if err := m.Refresh(ctx); err != nil {
return fmt.Errorf("failed to refresh message queue token: %w", err)
}
}
return nil
}
func (m *MessageSession) GetMessage(ctx context.Context, lastMessageID int64, maxCapacity uint) (params.RunnerScaleSetMessage, error) {
u, err := url.Parse(m.session.MessageQueueURL)
if err != nil {
return params.RunnerScaleSetMessage{}, err
}
if lastMessageID > 0 {
q := u.Query()
q.Set("lastMessageId", strconv.FormatInt(lastMessageID, 10))
u.RawQuery = q.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json; api-version=6.0-preview")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
req.Header.Set(maxCapacityHeader, fmt.Sprintf("%d", maxCapacity))
resp, err := m.ssCli.Do(req)
if err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("request to %s failed: %w", req.URL.String(), err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusAccepted {
slog.DebugContext(ctx, "no messages available in queue")
return params.RunnerScaleSetMessage{}, nil
}
var message params.RunnerScaleSetMessage
if err := json.NewDecoder(resp.Body).Decode(&message); err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to decode response: %w", err)
}
return message, nil
}
func (m *MessageSession) DeleteMessage(ctx context.Context, messageID int64) error {
u, err := url.Parse(m.session.MessageQueueURL)
if err != nil {
return err
}
u.Path = fmt.Sprintf("%s/%d", u.Path, messageID)
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u.String(), nil)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
resp, err := m.ssCli.Do(req)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (s *ScaleSetClient) CreateMessageSession(ctx context.Context, runnerScaleSetID int, owner string) (*MessageSession, error) {
path := fmt.Sprintf("%s/%d/sessions", scaleSetEndpoint, runnerScaleSetID)
newSession := params.RunnerScaleSetSession{
OwnerName: owner,
}
requestData, err := json.Marshal(newSession)
if err != nil {
return nil, fmt.Errorf("failed to marshal session data: %w", err)
}
req, err := s.newActionsRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request to %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var createdSession params.RunnerScaleSetSession
if err := json.NewDecoder(resp.Body).Decode(&createdSession); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
msgSessionCtx := garmUtil.WithSlogContext(
ctx,
slog.Any("session_id", createdSession.SessionID.String()))
sess := &MessageSession{
ssCli: s,
session: &createdSession,
ctx: msgSessionCtx,
done: make(chan struct{}),
closed: false,
}
go sess.loop()
return sess, nil
}
func (s *ScaleSetClient) DeleteMessageSession(ctx context.Context, session *MessageSession) error {
path, err := session.SessionsRelativeURL()
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return fmt.Errorf("failed to create message delete request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return fmt.Errorf("failed to delete message session: %w", err)
}
}
defer resp.Body.Close()
return nil
}