291 lines
8.1 KiB
Go
291 lines
8.1 KiB
Go
// Copyright 2024 Cloudbase Solutions SRL
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
// not use this file except in compliance with the License. You may obtain
|
|
// a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
// License for the specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
package scalesets
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"math/big"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
|
"github.com/cloudbase/garm/params"
|
|
garmUtil "github.com/cloudbase/garm/util"
|
|
)
|
|
|
|
const maxCapacityHeader = "X-ScaleSetMaxCapacity"
|
|
|
|
type MessageSession struct {
|
|
ssCli *ScaleSetClient
|
|
session *params.RunnerScaleSetSession
|
|
ctx context.Context
|
|
|
|
done chan struct{}
|
|
closed bool
|
|
lastErr error
|
|
|
|
mux sync.Mutex
|
|
}
|
|
|
|
func (m *MessageSession) Close() error {
|
|
m.mux.Lock()
|
|
defer m.mux.Unlock()
|
|
if m.closed {
|
|
return nil
|
|
}
|
|
close(m.done)
|
|
m.closed = true
|
|
return nil
|
|
}
|
|
|
|
func (m *MessageSession) MessageQueueAccessToken() string {
|
|
return m.session.MessageQueueAccessToken
|
|
}
|
|
|
|
func (m *MessageSession) LastError() error {
|
|
return m.lastErr
|
|
}
|
|
|
|
func (m *MessageSession) loop() {
|
|
slog.DebugContext(m.ctx, "starting message session refresh loop", "session_id", m.session.SessionID.String())
|
|
timer := time.NewTicker(1 * time.Minute)
|
|
defer timer.Stop()
|
|
defer m.Close()
|
|
|
|
if m.closed {
|
|
slog.DebugContext(m.ctx, "message session refresh loop closed")
|
|
return
|
|
}
|
|
for {
|
|
select {
|
|
case <-m.ctx.Done():
|
|
slog.DebugContext(m.ctx, "message session refresh loop context done")
|
|
return
|
|
case <-m.done:
|
|
slog.DebugContext(m.ctx, "message session refresh loop done")
|
|
return
|
|
case <-timer.C:
|
|
if err := m.maybeRefreshToken(m.ctx); err != nil {
|
|
// We endlessly retry. If it's a transient error, it should eventually
|
|
// work, if it's credentials issues, users can update them.
|
|
slog.With(slog.Any("error", err)).ErrorContext(m.ctx, "failed to refresh message queue token")
|
|
m.lastErr = err
|
|
continue
|
|
}
|
|
m.lastErr = nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *MessageSession) SessionsRelativeURL() (string, error) {
|
|
if m.session == nil {
|
|
return "", fmt.Errorf("session is nil")
|
|
}
|
|
if m.session.RunnerScaleSet == nil {
|
|
return "", fmt.Errorf("runner scale set is nil")
|
|
}
|
|
relativePath := fmt.Sprintf("%s/%d/sessions/%s", scaleSetEndpoint, m.session.RunnerScaleSet.ID, m.session.SessionID.String())
|
|
return relativePath, nil
|
|
}
|
|
|
|
func (m *MessageSession) Refresh(ctx context.Context) error {
|
|
slog.DebugContext(ctx, "refreshing message session token", "session_id", m.session.SessionID.String())
|
|
m.mux.Lock()
|
|
defer m.mux.Unlock()
|
|
|
|
relPath, err := m.SessionsRelativeURL()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session URL: %w", err)
|
|
}
|
|
req, err := m.ssCli.newActionsRequest(ctx, http.MethodPatch, relPath, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create message delete request: %w", err)
|
|
}
|
|
resp, err := m.ssCli.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete message session: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var refreshedSession params.RunnerScaleSetSession
|
|
if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil {
|
|
return fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
slog.DebugContext(ctx, "refreshed message session token")
|
|
m.session = &refreshedSession
|
|
return nil
|
|
}
|
|
|
|
func (m *MessageSession) maybeRefreshToken(ctx context.Context) error {
|
|
if m.session == nil {
|
|
return fmt.Errorf("session is nil")
|
|
}
|
|
|
|
expiresAt, err := m.session.ExiresAt()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get expires at: %w", err)
|
|
}
|
|
// add some jitter (30 second interval)
|
|
randInt, err := rand.Int(rand.Reader, big.NewInt(30))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get a random number")
|
|
}
|
|
expiresIn := time.Duration(randInt.Int64())*time.Second + 10*time.Minute
|
|
slog.DebugContext(ctx, "checking if message session token needs refresh", "expires_at", expiresAt)
|
|
if m.session.ExpiresIn(expiresIn) {
|
|
if err := m.Refresh(ctx); err != nil {
|
|
return fmt.Errorf("failed to refresh message queue token: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MessageSession) GetMessage(ctx context.Context, lastMessageID int64, maxCapacity uint) (params.RunnerScaleSetMessage, error) {
|
|
u, err := url.Parse(m.session.MessageQueueURL)
|
|
if err != nil {
|
|
return params.RunnerScaleSetMessage{}, err
|
|
}
|
|
|
|
if lastMessageID > 0 {
|
|
q := u.Query()
|
|
q.Set("lastMessageId", strconv.FormatInt(lastMessageID, 10))
|
|
u.RawQuery = q.Encode()
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
if err != nil {
|
|
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Accept", "application/json; api-version=6.0-preview")
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
|
|
req.Header.Set(maxCapacityHeader, fmt.Sprintf("%d", maxCapacity))
|
|
|
|
resp, err := m.ssCli.Do(req)
|
|
if err != nil {
|
|
return params.RunnerScaleSetMessage{}, fmt.Errorf("request to %s failed: %w", req.URL.String(), err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusAccepted {
|
|
slog.DebugContext(ctx, "no messages available in queue")
|
|
return params.RunnerScaleSetMessage{}, nil
|
|
}
|
|
|
|
var message params.RunnerScaleSetMessage
|
|
if err := json.NewDecoder(resp.Body).Decode(&message); err != nil {
|
|
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
return message, nil
|
|
}
|
|
|
|
func (m *MessageSession) DeleteMessage(ctx context.Context, messageID int64) error {
|
|
u, err := url.Parse(m.session.MessageQueueURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
u.Path = fmt.Sprintf("%s/%d", u.Path, messageID)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u.String(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
|
|
|
|
resp, err := m.ssCli.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp.Body.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ScaleSetClient) CreateMessageSession(ctx context.Context, runnerScaleSetID int, owner string) (*MessageSession, error) {
|
|
path := fmt.Sprintf("%s/%d/sessions", scaleSetEndpoint, runnerScaleSetID)
|
|
|
|
newSession := params.RunnerScaleSetSession{
|
|
OwnerName: owner,
|
|
}
|
|
|
|
requestData, err := json.Marshal(newSession)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal session data: %w", err)
|
|
}
|
|
|
|
req, err := s.newActionsRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := s.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute request to %s: %w", req.URL.String(), err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var createdSession params.RunnerScaleSetSession
|
|
if err := json.NewDecoder(resp.Body).Decode(&createdSession); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
msgSessionCtx := garmUtil.WithSlogContext(
|
|
ctx,
|
|
slog.Any("session_id", createdSession.SessionID.String()))
|
|
sess := &MessageSession{
|
|
ssCli: s,
|
|
session: &createdSession,
|
|
ctx: msgSessionCtx,
|
|
done: make(chan struct{}),
|
|
closed: false,
|
|
}
|
|
go sess.loop()
|
|
|
|
return sess, nil
|
|
}
|
|
|
|
func (s *ScaleSetClient) DeleteMessageSession(ctx context.Context, session *MessageSession) error {
|
|
path, err := session.SessionsRelativeURL()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete session: %w", err)
|
|
}
|
|
|
|
req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create message delete request: %w", err)
|
|
}
|
|
|
|
resp, err := s.Do(req)
|
|
if err != nil {
|
|
if !errors.Is(err, runnerErrors.ErrNotFound) {
|
|
return fmt.Errorf("failed to delete message session: %w", err)
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
return nil
|
|
}
|