diff --git a/database/common/store.go b/database/common/store.go index c732400a..a2b2cf77 100644 --- a/database/common/store.go +++ b/database/common/store.go @@ -143,6 +143,7 @@ type ScaleSetsStore interface { GetScaleSetByID(ctx context.Context, scaleSet uint) (params.ScaleSet, error) DeleteScaleSetByID(ctx context.Context, scaleSetID uint) (err error) ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) + SetScaleSetLastMessageID(ctx context.Context, scaleSetID uint, lastMessageID int64) error } //go:generate mockery --name=Store diff --git a/database/sql/models.go b/database/sql/models.go index e443e75a..5b4d86f9 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -118,6 +118,7 @@ type ScaleSet struct { OSType commonParams.OSType OSArch commonParams.OSArch Enabled bool + LastMessageID int64 // ExtraSpecs is an opaque json that gets sent to the provider // as part of the bootstrap params for instances. It can contain // any kind of data needed by providers. diff --git a/database/sql/scalesets.go b/database/sql/scalesets.go index 3a5d8431..7a67f2d6 100644 --- a/database/sql/scalesets.go +++ b/database/sql/scalesets.go @@ -379,3 +379,15 @@ func (s *sqlDatabase) DeleteScaleSetByID(ctx context.Context, scaleSetID uint) ( } return nil } + +func (s *sqlDatabase) SetScaleSetLastMessageID(ctx context.Context, scaleSetID uint, lastMessageID int64) error { + if err := s.conn.Transaction(func(tx *gorm.DB) error { + if q := tx.Model(&ScaleSet{}).Where("id = ?", scaleSetID).Update("last_message_id", lastMessageID); q.Error != nil { + return errors.Wrap(q.Error, "saving database entry") + } + return nil + }); err != nil { + return errors.Wrap(err, "setting last message ID") + } + return nil +} diff --git a/database/sql/util.go b/database/sql/util.go index c1a44cb8..c7b64961 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -309,6 +309,7 @@ func (s *sqlDatabase) sqlToCommonScaleSet(scaleSet ScaleSet) (params.ScaleSet, e GitHubRunnerGroup: scaleSet.GitHubRunnerGroup, State: scaleSet.State, ExtendedState: scaleSet.ExtendedState, + LastMessageID: scaleSet.LastMessageID, } if scaleSet.RepoID != nil { diff --git a/params/github.go b/params/github.go index 9eec6e8c..b609e682 100644 --- a/params/github.go +++ b/params/github.go @@ -402,6 +402,10 @@ type RunnerScaleSetMessage struct { Statistics *RunnerScaleSetStatistic `json:"statistics"` } +func (r RunnerScaleSetMessage) IsNil() bool { + return r.MessageID == 0 && r.MessageType == "" && r.Body == "" && r.Statistics == nil +} + func (r RunnerScaleSetMessage) GetJobsFromBody() ([]ScaleSetJobMessage, error) { var body []ScaleSetJobMessage if r.Body == "" { diff --git a/params/params.go b/params/params.go index b0a6492e..3ac0c0c5 100644 --- a/params/params.go +++ b/params/params.go @@ -472,6 +472,8 @@ type ScaleSet struct { EnterpriseID string `json:"enterprise_id,omitempty"` EnterpriseName string `json:"enterprise_name,omitempty"` + + LastMessageID int64 `json:"-"` } func (p ScaleSet) GithubEntity() (GithubEntity, error) { diff --git a/util/github/scalesets/message_sessions.go b/util/github/scalesets/message_sessions.go index e4152e08..5e260b96 100644 --- a/util/github/scalesets/message_sessions.go +++ b/util/github/scalesets/message_sessions.go @@ -31,6 +31,7 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/params" + garmUtil "github.com/cloudbase/garm/util" ) const maxCapacityHeader = "X-ScaleSetMaxCapacity" @@ -63,16 +64,22 @@ func (m *MessageSession) LastError() error { } func (m *MessageSession) loop() { - timer := time.NewTimer(1 * time.Minute) + slog.DebugContext(m.ctx, "starting message session refresh loop", "session_id", m.session.SessionID.String()) + timer := time.NewTicker(1 * time.Minute) defer timer.Stop() + defer m.Close() + if m.closed { + slog.DebugContext(m.ctx, "message session refresh loop closed") return } for { select { case <-m.ctx.Done(): + slog.DebugContext(m.ctx, "message session refresh loop context done") return case <-m.done: + slog.DebugContext(m.ctx, "message session refresh loop done") return case <-timer.C: if err := m.maybeRefreshToken(m.ctx); err != nil { @@ -99,6 +106,7 @@ func (m *MessageSession) SessionsRelativeURL() (string, error) { } func (m *MessageSession) Refresh(ctx context.Context) error { + slog.DebugContext(ctx, "refreshing message session token", "session_id", m.session.SessionID.String()) m.mux.Lock() defer m.mux.Unlock() @@ -114,13 +122,15 @@ func (m *MessageSession) Refresh(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to delete message session: %w", err) } + defer resp.Body.Close() var refreshedSession params.RunnerScaleSetSession if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil { return fmt.Errorf("failed to decode response: %w", err) } - - m.session = &refreshedSession + slog.DebugContext(ctx, "refreshed message session token", "session_id", refreshedSession.SessionID.String()) + m.session.MessageQueueAccessToken = refreshedSession.MessageQueueAccessToken + m.session.Statistics = refreshedSession.Statistics return nil } @@ -129,16 +139,23 @@ func (m *MessageSession) maybeRefreshToken(ctx context.Context) error { return fmt.Errorf("session is nil") } // add some jitter - randInt, err := rand.Int(rand.Reader, big.NewInt(1000)) + randInt, err := rand.Int(rand.Reader, big.NewInt(5000)) if err != nil { return fmt.Errorf("failed to get a random number") } - jitter := time.Duration(randInt.Int64()) * time.Millisecond - if m.session.ExpiresIn(2*time.Minute + jitter) { + expiresAt, err := m.session.ExiresAt() + if err != nil { + return fmt.Errorf("failed to get expires at: %w", err) + } + expiresIn := time.Duration(randInt.Int64())*time.Millisecond + 10*time.Minute + slog.DebugContext(ctx, "checking if message session token needs refresh", "expires_at", expiresAt) + if m.session.ExpiresIn(expiresIn) { + slog.DebugContext(ctx, "refreshing message session token") if err := m.Refresh(ctx); err != nil { return fmt.Errorf("failed to refresh message queue token: %w", err) } } + return nil } @@ -170,6 +187,7 @@ func (m *MessageSession) GetMessage(ctx context.Context, lastMessageID int64, ma defer resp.Body.Close() if resp.StatusCode == http.StatusAccepted { + slog.DebugContext(ctx, "no messages available in queue") return params.RunnerScaleSetMessage{}, nil } @@ -200,8 +218,8 @@ func (m *MessageSession) DeleteMessage(ctx context.Context, messageID int64) err if err != nil { return err } - resp.Body.Close() + return nil } @@ -233,10 +251,13 @@ func (s *ScaleSetClient) CreateMessageSession(ctx context.Context, runnerScaleSe return nil, fmt.Errorf("failed to decode response: %w", err) } + msgSessionCtx := garmUtil.WithSlogContext( + ctx, + slog.Any("session_id", createdSession.SessionID.String())) sess := &MessageSession{ ssCli: s, session: &createdSession, - ctx: ctx, + ctx: msgSessionCtx, done: make(chan struct{}), closed: false, } @@ -256,11 +277,12 @@ func (s *ScaleSetClient) DeleteMessageSession(ctx context.Context, session *Mess return fmt.Errorf("failed to create message delete request: %w", err) } - _, err = s.Do(req) + resp, err := s.Do(req) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return fmt.Errorf("failed to delete message session: %w", err) } } + defer resp.Body.Close() return nil } diff --git a/util/github/scalesets/scalesets.go b/util/github/scalesets/scalesets.go index f7ef2763..2aae493a 100644 --- a/util/github/scalesets/scalesets.go +++ b/util/github/scalesets/scalesets.go @@ -47,6 +47,7 @@ func (s *ScaleSetClient) GetRunnerScaleSetByNameAndRunnerGroup(ctx context.Conte if err != nil { return params.RunnerScaleSet{}, err } + defer resp.Body.Close() var runnerScaleSetList *params.RunnerScaleSetsResponse if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil { @@ -72,6 +73,7 @@ func (s *ScaleSetClient) GetRunnerScaleSetByID(ctx context.Context, runnerScaleS if err != nil { return params.RunnerScaleSet{}, fmt.Errorf("failed to get runner scaleset with ID %d: %w", runnerScaleSetID, err) } + defer resp.Body.Close() var runnerScaleSet params.RunnerScaleSet if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSet); err != nil { @@ -94,6 +96,7 @@ func (s *ScaleSetClient) ListRunnerScaleSets(ctx context.Context) (*params.Runne if err != nil { return nil, fmt.Errorf("failed to list runner scale sets: %w", err) } + defer resp.Body.Close() var runnerScaleSetList params.RunnerScaleSetsResponse if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil { @@ -119,6 +122,7 @@ func (s *ScaleSetClient) CreateRunnerScaleSet(ctx context.Context, runnerScaleSe if err != nil { return params.RunnerScaleSet{}, fmt.Errorf("failed to create runner scale set: %w", err) } + defer resp.Body.Close() var createdRunnerScaleSet params.RunnerScaleSet if err := json.NewDecoder(resp.Body).Decode(&createdRunnerScaleSet); err != nil { @@ -144,6 +148,7 @@ func (s *ScaleSetClient) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSe if err != nil { return params.RunnerScaleSet{}, fmt.Errorf("failed to make request: %w", err) } + defer resp.Body.Close() var ret params.RunnerScaleSet if err := json.NewDecoder(resp.Body).Decode(&ret); err != nil { @@ -164,12 +169,12 @@ func (s *ScaleSetClient) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSe if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusNoContent { return fmt.Errorf("failed to delete scale set with code %d", resp.StatusCode) } - resp.Body.Close() return nil } diff --git a/workers/scaleset/interfaces.go b/workers/scaleset/interfaces.go index 365ac0be..7b96168d 100644 --- a/workers/scaleset/interfaces.go +++ b/workers/scaleset/interfaces.go @@ -8,5 +8,6 @@ import ( type scaleSetHelper interface { ScaleSetCLI() *scalesets.ScaleSetClient GetScaleSet() params.ScaleSet + SetLastMessageID(id int64) error Owner() string } diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index c5e31b5d..c392d5cd 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -148,6 +148,12 @@ func (w *Worker) handleEvent(event dbCommon.ChangePayload) { case dbCommon.UpdateOperation: slog.DebugContext(w.ctx, "got update operation") w.mux.Lock() + if scaleSet.MaxRunners < w.Entity.MaxRunners { + slog.DebugContext(w.ctx, "max runners changed; stopping listener") + if err := w.listener.Stop(); err != nil { + slog.ErrorContext(w.ctx, "error stopping listener", "error", err) + } + } w.Entity = scaleSet w.mux.Unlock() default: diff --git a/workers/scaleset/scaleset_helper.go b/workers/scaleset/scaleset_helper.go index 8cfa9264..abfd37c4 100644 --- a/workers/scaleset/scaleset_helper.go +++ b/workers/scaleset/scaleset_helper.go @@ -18,3 +18,10 @@ func (w *Worker) GetScaleSet() params.ScaleSet { func (w *Worker) Owner() string { return fmt.Sprintf("garm-%s", w.controllerInfo.ControllerID) } + +func (w *Worker) SetLastMessageID(id int64) error { + if err := w.store.SetScaleSetLastMessageID(w.ctx, w.Entity.ID, id); err != nil { + return fmt.Errorf("setting last message ID: %w", err) + } + return nil +} diff --git a/workers/scaleset/scaleset_listener.go b/workers/scaleset/scaleset_listener.go index f92eaff1..bf22b61b 100644 --- a/workers/scaleset/scaleset_listener.go +++ b/workers/scaleset/scaleset_listener.go @@ -7,6 +7,7 @@ import ( "log/slog" "sync" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/util/github/scalesets" ) @@ -15,6 +16,7 @@ func newListener(ctx context.Context, scaleSetHelper scaleSetHelper) *scaleSetLi return &scaleSetListener{ ctx: ctx, scaleSetHelper: scaleSetHelper, + lastMessageID: scaleSetHelper.GetScaleSet().LastMessageID, } } @@ -33,9 +35,10 @@ type scaleSetListener struct { scaleSetHelper scaleSetHelper messageSession *scalesets.MessageSession - mux sync.Mutex - running bool - quit chan struct{} + mux sync.Mutex + running bool + quit chan struct{} + loopExited chan struct{} } func (l *scaleSetListener) Start() error { @@ -56,6 +59,7 @@ func (l *scaleSetListener) Start() error { l.messageSession = session l.quit = make(chan struct{}) l.running = true + l.loopExited = make(chan struct{}) go l.loop() return nil @@ -78,10 +82,12 @@ func (l *scaleSetListener) Stop() error { slog.ErrorContext(l.ctx, "error deleting message session", "error", err) } } - l.cancelFunc() + l.messageSession.Close() l.running = false + l.listenerCtx = nil close(l.quit) + l.cancelFunc() return nil } @@ -91,14 +97,22 @@ func (l *scaleSetListener) handleSessionMessage(msg params.RunnerScaleSetMessage body, err := msg.GetJobsFromBody() if err != nil { slog.ErrorContext(l.ctx, "getting jobs from body", "error", err) - return } slog.InfoContext(l.ctx, "handling message", "message", msg, "body", body) - l.lastMessageID = msg.MessageID + if msg.MessageID < l.lastMessageID { + slog.DebugContext(l.ctx, "message is older than last message, ignoring") + } else { + l.lastMessageID = msg.MessageID + if err := l.scaleSetHelper.SetLastMessageID(msg.MessageID); err != nil { + slog.ErrorContext(l.ctx, "setting last message ID", "error", err) + } + } } func (l *scaleSetListener) loop() { + defer close(l.loopExited) defer l.Stop() + retryAfterUnauthorized := false slog.DebugContext(l.ctx, "starting scale set listener loop", "scale_set", l.scaleSetHelper.GetScaleSet().ScaleSetID) for { @@ -112,23 +126,46 @@ func (l *scaleSetListener) loop() { slog.DebugContext(l.ctx, "scaleset worker has stopped") return default: - slog.DebugContext(l.ctx, "getting message") + slog.DebugContext(l.ctx, "getting message", "last_message_id", l.lastMessageID, "max_runners", l.scaleSetHelper.GetScaleSet().MaxRunners) msg, err := l.messageSession.GetMessage( l.listenerCtx, l.lastMessageID, l.scaleSetHelper.GetScaleSet().MaxRunners) if err != nil { + if errors.Is(err, runnerErrors.ErrUnauthorized) { + if retryAfterUnauthorized { + slog.DebugContext(l.ctx, "unauthorized, stopping listener") + return + } + // The session manager refreshes the token automatically, but once we call + // GetMessage(), it blocks until a new message is sent on the longpoll. + // If there are no messages for a while, the token used to longpoll expires + // and we get an unauthorized error. We simply need to retry the request + // and it should use the refreshed token. If we fail a second time, we can + // return and the scaleset worker will attempt to restart the listener. + retryAfterUnauthorized = true + slog.DebugContext(l.ctx, "got unauthorized error, retrying") + continue + } if !errors.Is(err, context.Canceled) { slog.ErrorContext(l.ctx, "getting message", "error", err) } + slog.DebugContext(l.ctx, "stopping scale set listener") return } - l.handleSessionMessage(msg) + retryAfterUnauthorized = false + if !msg.IsNil() { + l.handleSessionMessage(msg) + } } } } func (l *scaleSetListener) Wait() <-chan struct{} { + l.mux.Lock() if !l.running { + slog.DebugContext(l.ctx, "scale set listener is not running") + l.mux.Unlock() return nil } - return l.listenerCtx.Done() + l.mux.Unlock() + return l.loopExited }