From 39003f006ab15d3edc9b5aae28ff252a42d06122 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sat, 23 Aug 2025 00:02:11 +0000 Subject: [PATCH] Ensure scale set exists Github will remove inactive scale sets after 7 days. This change ensures the scale set exists in github before spinning up the listener. Signed-off-by: Gabriel Adrian Samfira --- cache/entity_cache.go | 56 ++++++++++-- database/sql/scalesets.go | 4 + params/requests.go | 1 + runner/common/mocks/GithubClient.go | 57 ++++++++++++ runner/common/mocks/GithubEntityOperations.go | 57 ++++++++++++ runner/common/util.go | 1 + runner/pool/pool.go | 3 +- runner/pool/stub_client.go | 4 + runner/scalesets.go | 11 +-- util/github/client.go | 38 +++++--- util/github/scalesets/client.go | 9 ++ workers/scaleset/controller.go | 5 +- workers/scaleset/controller_watcher.go | 43 ++++++--- workers/scaleset/scaleset.go | 87 ++++++++++++++++++- 14 files changed, 333 insertions(+), 43 deletions(-) diff --git a/cache/entity_cache.go b/cache/entity_cache.go index 4800dd9c..c676332f 100644 --- a/cache/entity_cache.go +++ b/cache/entity_cache.go @@ -15,6 +15,7 @@ package cache import ( "sync" + "time" "github.com/cloudbase/garm/params" ) @@ -28,10 +29,16 @@ func init() { entityCache = ghEntityCache } +type RunnerGroupEntry struct { + RunnerGroupID int64 + time time.Time +} + type EntityItem struct { - Entity params.ForgeEntity - Pools map[string]params.Pool - ScaleSets map[uint]params.ScaleSet + Entity params.ForgeEntity + Pools map[string]params.Pool + ScaleSets map[uint]params.ScaleSet + RunnerGroups map[string]RunnerGroupEntry } type EntityCache struct { @@ -80,9 +87,10 @@ func (e *EntityCache) SetEntity(entity params.ForgeEntity) { cache, ok := e.entities[entity.ID] if !ok { e.entities[entity.ID] = EntityItem{ - Entity: entity, - Pools: make(map[string]params.Pool), - ScaleSets: make(map[uint]params.ScaleSet), + Entity: entity, + Pools: make(map[string]params.Pool), + ScaleSets: make(map[uint]params.ScaleSet), + 
RunnerGroups: make(map[string]RunnerGroupEntry), } return } @@ -314,6 +322,42 @@ func (e *EntityCache) GetAllScaleSets() []params.ScaleSet { return scaleSets } +func (e *EntityCache) SetEntityRunnerGroup(entityID, runnerGroupName string, runnerGroupID int64) { + e.mux.Lock() + defer e.mux.Unlock() + + if _, ok := e.entities[entityID]; ok { + e.entities[entityID].RunnerGroups[runnerGroupName] = RunnerGroupEntry{ + RunnerGroupID: runnerGroupID, + time: time.Now().UTC(), + } + } +} + +func (e *EntityCache) GetEntityRunnerGroup(entityID, runnerGroupName string) (int64, bool) { + e.mux.Lock() + defer e.mux.Unlock() + + if _, ok := e.entities[entityID]; ok { + if runnerGroup, ok := e.entities[entityID].RunnerGroups[runnerGroupName]; ok { + if time.Now().UTC().After(runnerGroup.time.Add(1 * time.Hour)) { + delete(e.entities[entityID].RunnerGroups, runnerGroupName) + return 0, false + } + return runnerGroup.RunnerGroupID, true + } + } + return 0, false +} + +func SetEntityRunnerGroup(entityID, runnerGroupName string, runnerGroupID int64) { + entityCache.SetEntityRunnerGroup(entityID, runnerGroupName, runnerGroupID) +} + +func GetEntityRunnerGroup(entityID, runnerGroupName string) (int64, bool) { + return entityCache.GetEntityRunnerGroup(entityID, runnerGroupName) +} + func GetEntity(entityID string) (params.ForgeEntity, bool) { return entityCache.GetEntity(entityID) } diff --git a/database/sql/scalesets.go b/database/sql/scalesets.go index b247b7a8..5877ad5c 100644 --- a/database/sql/scalesets.go +++ b/database/sql/scalesets.go @@ -294,6 +294,10 @@ func (s *sqlDatabase) updateScaleSet(tx *gorm.DB, scaleSet ScaleSet, param param scaleSet.ExtendedState = *param.ExtendedState } + if param.ScaleSetID != 0 { + scaleSet.ScaleSetID = param.ScaleSetID + } + if param.Name != "" { scaleSet.Name = param.Name } diff --git a/params/requests.go b/params/requests.go index 7bc17959..c9021434 100644 --- a/params/requests.go +++ b/params/requests.go @@ -636,6 +636,7 @@ type 
UpdateScaleSetParams struct { GitHubRunnerGroup *string `json:"runner_group,omitempty"` State *ScaleSetState `json:"state"` ExtendedState *string `json:"extended_state"` + ScaleSetID int `json:"-"` } // swagger:model CreateGiteaEndpointParams diff --git a/runner/common/mocks/GithubClient.go b/runner/common/mocks/GithubClient.go index 92d4aa06..c1dbeae9 100644 --- a/runner/common/mocks/GithubClient.go +++ b/runner/common/mocks/GithubClient.go @@ -385,6 +385,63 @@ func (_c *GithubClient_GetEntityJITConfig_Call) RunAndReturn(run func(context.Co return _c } +// GetEntityRunnerGroupIDByName provides a mock function with given fields: ctx, runnerGroupName +func (_m *GithubClient) GetEntityRunnerGroupIDByName(ctx context.Context, runnerGroupName string) (int64, error) { + ret := _m.Called(ctx, runnerGroupName) + + if len(ret) == 0 { + panic("no return value specified for GetEntityRunnerGroupIDByName") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, runnerGroupName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, runnerGroupName) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, runnerGroupName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GithubClient_GetEntityRunnerGroupIDByName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityRunnerGroupIDByName' +type GithubClient_GetEntityRunnerGroupIDByName_Call struct { + *mock.Call +} + +// GetEntityRunnerGroupIDByName is a helper method to define mock.On call +// - ctx context.Context +// - runnerGroupName string +func (_e *GithubClient_Expecter) GetEntityRunnerGroupIDByName(ctx interface{}, runnerGroupName interface{}) *GithubClient_GetEntityRunnerGroupIDByName_Call { + return &GithubClient_GetEntityRunnerGroupIDByName_Call{Call: 
_e.mock.On("GetEntityRunnerGroupIDByName", ctx, runnerGroupName)} +} + +func (_c *GithubClient_GetEntityRunnerGroupIDByName_Call) Run(run func(ctx context.Context, runnerGroupName string)) *GithubClient_GetEntityRunnerGroupIDByName_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *GithubClient_GetEntityRunnerGroupIDByName_Call) Return(_a0 int64, _a1 error) *GithubClient_GetEntityRunnerGroupIDByName_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *GithubClient_GetEntityRunnerGroupIDByName_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *GithubClient_GetEntityRunnerGroupIDByName_Call { + _c.Call.Return(run) + return _c +} + // GetWorkflowJobByID provides a mock function with given fields: ctx, owner, repo, jobID func (_m *GithubClient) GetWorkflowJobByID(ctx context.Context, owner string, repo string, jobID int64) (*github.WorkflowJob, *github.Response, error) { ret := _m.Called(ctx, owner, repo, jobID) diff --git a/runner/common/mocks/GithubEntityOperations.go b/runner/common/mocks/GithubEntityOperations.go index 2448df4c..0b3c3f83 100644 --- a/runner/common/mocks/GithubEntityOperations.go +++ b/runner/common/mocks/GithubEntityOperations.go @@ -385,6 +385,63 @@ func (_c *GithubEntityOperations_GetEntityJITConfig_Call) RunAndReturn(run func( return _c } +// GetEntityRunnerGroupIDByName provides a mock function with given fields: ctx, runnerGroupName +func (_m *GithubEntityOperations) GetEntityRunnerGroupIDByName(ctx context.Context, runnerGroupName string) (int64, error) { + ret := _m.Called(ctx, runnerGroupName) + + if len(ret) == 0 { + panic("no return value specified for GetEntityRunnerGroupIDByName") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, runnerGroupName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, 
runnerGroupName) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, runnerGroupName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GithubEntityOperations_GetEntityRunnerGroupIDByName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityRunnerGroupIDByName' +type GithubEntityOperations_GetEntityRunnerGroupIDByName_Call struct { + *mock.Call +} + +// GetEntityRunnerGroupIDByName is a helper method to define mock.On call +// - ctx context.Context +// - runnerGroupName string +func (_e *GithubEntityOperations_Expecter) GetEntityRunnerGroupIDByName(ctx interface{}, runnerGroupName interface{}) *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call { + return &GithubEntityOperations_GetEntityRunnerGroupIDByName_Call{Call: _e.mock.On("GetEntityRunnerGroupIDByName", ctx, runnerGroupName)} +} + +func (_c *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call) Run(run func(ctx context.Context, runnerGroupName string)) *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call) Return(_a0 int64, _a1 error) *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *GithubEntityOperations_GetEntityRunnerGroupIDByName_Call { + _c.Call.Return(run) + return _c +} + // GithubBaseURL provides a mock function with no fields func (_m *GithubEntityOperations) GithubBaseURL() *url.URL { ret := _m.Called() diff --git a/runner/common/util.go b/runner/common/util.go index d8519438..5130dcfd 100644 --- a/runner/common/util.go +++ b/runner/common/util.go @@ -35,6 +35,7 @@ type 
GithubEntityOperations interface { RateLimit(ctx context.Context) (*github.RateLimits, error) CreateEntityRegistrationToken(ctx context.Context) (*github.RegistrationToken, *github.Response, error) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) + GetEntityRunnerGroupIDByName(ctx context.Context, runnerGroupName string) (int64, error) // GetEntity returns the GitHub entity for which the github client was instanciated. GetEntity() params.ForgeEntity diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 8610d4c9..97ffa992 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -771,8 +771,7 @@ func (r *basePoolManager) AddRunner(ctx context.Context, poolID string, aditiona // Attempt to create JIT config jitConfig, runner, err = r.ghcli.GetEntityJITConfig(ctx, name, pool, labels) if err != nil { - slog.With(slog.Any("error", err)).ErrorContext( - ctx, "failed to get JIT config, falling back to registration token") + return fmt.Errorf("failed to generate JIT config: %w", err) } } diff --git a/runner/pool/stub_client.go b/runner/pool/stub_client.go index 6493f7a5..0afd6a52 100644 --- a/runner/pool/stub_client.go +++ b/runner/pool/stub_client.go @@ -82,3 +82,7 @@ func (s *stubGithubClient) GithubBaseURL() *url.URL { func (s *stubGithubClient) RateLimit(_ context.Context) (*github.RateLimits, error) { return nil, s.err } + +func (s *stubGithubClient) GetEntityRunnerGroupIDByName(_ context.Context, _ string) (int64, error) { + return 0, s.err +} diff --git a/runner/scalesets.go b/runner/scalesets.go index d9361698..136ddec2 100644 --- a/runner/scalesets.go +++ b/runner/scalesets.go @@ -225,13 +225,10 @@ func (r *Runner) CreateEntityScaleSet(ctx context.Context, entityType params.For if err != nil { return params.ScaleSet{}, fmt.Errorf("error getting scaleset client: %w", err) } - var runnerGroupID int64 = 1 - if param.GitHubRunnerGroup != "Default" { 
- runnerGroup, err := scalesetCli.GetRunnerGroupByName(ctx, param.GitHubRunnerGroup) - if err != nil { - return params.ScaleSet{}, fmt.Errorf("error getting runner group: %w", err) - } - runnerGroupID = runnerGroup.ID + + runnerGroupID, err := ghCli.GetEntityRunnerGroupIDByName(ctx, param.GitHubRunnerGroup) + if err != nil { + return params.ScaleSet{}, fmt.Errorf("failed to get github runner group for entity %s: %w", entity.ID, err) } createParam := &params.RunnerScaleSet{ diff --git a/util/github/client.go b/util/github/client.go index 19380587..7c1dc11b 100644 --- a/util/github/client.go +++ b/util/github/client.go @@ -28,6 +28,7 @@ import ( "github.com/google/go-github/v72/github" runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/cache" "github.com/cloudbase/garm/metrics" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" @@ -419,22 +420,35 @@ func (g *githubClient) getEnterpriseRunnerGroupIDByName(ctx context.Context, ent return 0, runnerErrors.NewNotFoundError("runner group not found") } -func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) { - // If no runner group is set, use the default runner group ID. This is also the default for - // repository level runners. +func (g *githubClient) GetEntityRunnerGroupIDByName(ctx context.Context, runnerGroupName string) (int64, error) { var rgID int64 = 1 + var ok bool + var err error + // attempt to get the runner group ID from cache. Cache will invalidate after 1 hour. 
+ if runnerGroupName != "" && !strings.EqualFold(runnerGroupName, "default") && g.entity.EntityType != params.ForgeEntityTypeRepository { // repository entities have no runner groups of their own; skip the lookup so the default ID (1) is returned instead of a zero value + rgID, ok = cache.GetEntityRunnerGroup(g.entity.ID, runnerGroupName) + if !ok { + switch g.entity.EntityType { + case params.ForgeEntityTypeOrganization: + rgID, err = g.getOrganizationRunnerGroupIDByName(ctx, g.entity, runnerGroupName) + case params.ForgeEntityTypeEnterprise: + rgID, err = g.getEnterpriseRunnerGroupIDByName(ctx, g.entity, runnerGroupName) + } - if pool.GitHubRunnerGroup != "" { - switch g.entity.EntityType { - case params.ForgeEntityTypeOrganization: - rgID, err = g.getOrganizationRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) - case params.ForgeEntityTypeEnterprise: - rgID, err = g.getEnterpriseRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) + if err != nil { + return 0, fmt.Errorf("getting runner group ID: %w", err) + } } + // set cache. Avoid getting the same runner group for more than once an hour. + cache.SetEntityRunnerGroup(g.entity.ID, runnerGroupName, rgID) + } + return rgID, nil +} - if err != nil { - return nil, nil, fmt.Errorf("getting runner group ID: %w", err) - } +func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) { + rgID, err := g.GetEntityRunnerGroupIDByName(ctx, pool.GitHubRunnerGroup) + if err != nil { + return nil, nil, fmt.Errorf("failed to get runner group: %w", err) } req := github.GenerateJITConfigRequest{ diff --git a/util/github/scalesets/client.go b/util/github/scalesets/client.go index 5b01a539..6b4b1bab 100644 --- a/util/github/scalesets/client.go +++ b/util/github/scalesets/client.go @@ -57,6 +57,15 @@ func (s *ScaleSetClient) SetGithubClient(cli common.GithubClient) { s.ghCli = cli } +func (s *ScaleSetClient) GetGithubClient() (common.GithubClient, error) { + s.mux.Lock() + defer s.mux.Unlock() + if s.ghCli == nil { + return nil, fmt.Errorf("github client is not set in scaleset client") 
+ } + return s.ghCli, nil +} + func (s *ScaleSetClient) Do(req *http.Request) (*http.Response, error) { if s.httpClient == nil { return nil, fmt.Errorf("http client is not initialized") diff --git a/workers/scaleset/controller.go b/workers/scaleset/controller.go index 63112f43..32d3d713 100644 --- a/workers/scaleset/controller.go +++ b/workers/scaleset/controller.go @@ -33,7 +33,10 @@ func NewController(ctx context.Context, store dbCommon.Store, entity params.Forg ctx = garmUtil.WithSlogContext( ctx, - slog.Any("worker", consumerID)) + slog.Any("worker", consumerID), + slog.Any("entity", entity.String()), + slog.Any("endpoint", entity.Credentials.Endpoint), + ) return &Controller{ ctx: ctx, diff --git a/workers/scaleset/controller_watcher.go b/workers/scaleset/controller_watcher.go index 8344cac5..e3c32ea6 100644 --- a/workers/scaleset/controller_watcher.go +++ b/workers/scaleset/controller_watcher.go @@ -65,6 +65,22 @@ func (c *Controller) handleScaleSet(event dbCommon.ChangePayload) { } } +func (c *Controller) createScaleSetWorker(scaleSet params.ScaleSet) (*Worker, error) { + provider, ok := c.providers[scaleSet.ProviderName] + if !ok { + // Providers are currently static, set in the config and cannot be updated without a restart. + // ScaleSets and pools also do not allow updating the provider. This condition is not recoverable + // without a restart, so we don't need to instantiate a worker for this scale set. 
+ return nil, fmt.Errorf("provider %s not found for scale set %s", scaleSet.ProviderName, scaleSet.Name) + } + + worker, err := NewWorker(c.ctx, c.store, scaleSet, provider) + if err != nil { + return nil, fmt.Errorf("creating scale set worker: %w", err) + } + return worker, nil +} + func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet) error { c.mux.Lock() defer c.mux.Unlock() @@ -74,17 +90,9 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet) error { return nil } - provider, ok := c.providers[sSet.ProviderName] - if !ok { - // Providers are currently static, set in the config and cannot be updated without a restart. - // ScaleSets and pools also do not allow updating the provider. This condition is not recoverable - // without a restart, so we don't need to instantiate a worker for this scale set. - return fmt.Errorf("provider %s not found for scale set %s", sSet.ProviderName, sSet.Name) - } - - worker, err := NewWorker(c.ctx, c.store, sSet, provider) + worker, err := c.createScaleSetWorker(sSet) if err != nil { - return fmt.Errorf("creating scale set worker: %w", err) + return fmt.Errorf("error creating scale set worker: %w", err) } if err := worker.Start(); err != nil { // The Start() function should only return an error if an unrecoverable error occurs. @@ -92,7 +100,7 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet) error { // to retry fixing the condition. For example, not being able to retrieve tools due to bad // credentials should not stop the worker. The credentials can be fixed and the worker // can continue to work. - return fmt.Errorf("starting scale set worker: %w", err) + return fmt.Errorf("error starting scale set worker: %w", err) } c.ScaleSets[sSet.ID] = &scaleSet{ scaleSet: sSet, @@ -130,6 +138,19 @@ func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error { // fixing the reason for the failure. 
return c.handleScaleSetCreateOperation(sSet) } + if set.worker != nil && !set.worker.IsRunning() { + worker, err := c.createScaleSetWorker(sSet) + if err != nil { + return fmt.Errorf("creating scale set worker: %w", err) + } + set.worker = worker + defer func() { + if err := worker.Start(); err != nil { + slog.ErrorContext(c.ctx, "failed to start worker", "error", err, "scaleset", sSet.Name) + } + }() + } + set.scaleSet = sSet c.ScaleSets[sSet.ID] = set // We let the watcher in the scale set worker handle the update operation. diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index 48aa8508..8c0abefa 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -68,6 +68,74 @@ type Worker struct { quit chan struct{} } +func (w *Worker) ensureScaleSetInGitHub() error { + entity, err := w.scaleSet.GetEntity() + if err != nil { + return fmt.Errorf("failed to get entity: %w", err) + } + cli, err := w.GetScaleSetClient() + if err != nil { + return fmt.Errorf("failed to get scaleset client: %w", err) + } + + ghCli, err := cli.GetGithubClient() + if err != nil { + return fmt.Errorf("failed to get github client: %w", err) + } + + rgID, err := ghCli.GetEntityRunnerGroupIDByName(w.ctx, w.scaleSet.GitHubRunnerGroup) + if err != nil { + return fmt.Errorf("failed to get github runner group for entity %s: %w", entity.ID, err) + } + scaleSet, err := cli.GetRunnerScaleSetByNameAndRunnerGroup(w.ctx, int(rgID), w.scaleSet.Name) + if err == nil { + // The scale set exists + if scaleSet.ID != w.scaleSet.ScaleSetID { + // The scale set exists in github, but the ID differs from what we know to be true. + // It is possible that the scale set is being managed by some other auto scaler. + // We error here, as there is no way to listen on a scale set that already has a listener + // or is being managed by something else. 
+ return fmt.Errorf("scale set already exists in github and it differs from the ID we know (github: %d vs local: %d)", scaleSet.ID, w.scaleSet.ScaleSetID) + } + return nil + } + if !errors.Is(err, runnerErrors.ErrNotFound) { + return fmt.Errorf("failed to get scale set: %w", err) + } + + createScaleSetParams := &params.RunnerScaleSet{ + Name: w.scaleSet.Name, + RunnerGroupID: rgID, + Labels: []params.Label{ + { + Name: w.scaleSet.Name, + Type: "System", + }, + }, + RunnerSetting: params.RunnerSetting{ + Ephemeral: true, + DisableUpdate: w.scaleSet.DisableUpdate, + }, + Enabled: &w.scaleSet.Enabled, + } + runnerScaleSet, err := cli.CreateRunnerScaleSet(w.ctx, createScaleSetParams) + if err != nil { + return fmt.Errorf("error creating runner scale set: %w", err) + } + + // update the DB scale set + updateParams := params.UpdateScaleSetParams{ + ScaleSetID: runnerScaleSet.ID, + } + _, err = w.store.UpdateEntityScaleSet(w.ctx, entity, w.scaleSet.ID, updateParams, nil) + if err != nil { + return fmt.Errorf("failed to update scale set: %w", err) + } + w.scaleSet.ScaleSetID = runnerScaleSet.ID + + return nil +} + func (w *Worker) Stop() error { slog.DebugContext(w.ctx, "stopping scale set worker", "scale_set", w.consumerID) w.mux.Lock() @@ -86,6 +154,13 @@ func (w *Worker) Stop() error { return nil } +func (w *Worker) IsRunning() bool { + w.mux.Lock() + defer w.mux.Unlock() + + return w.running +} + func (w *Worker) Start() (err error) { slog.DebugContext(w.ctx, "starting scale set worker", "scale_set", w.consumerID) w.mux.Lock() @@ -101,8 +176,8 @@ func (w *Worker) Start() (err error) { } for _, instance := range instances { - switch { - case instance.Status == commonParams.InstanceCreating: + switch instance.Status { + case commonParams.InstanceCreating: // We're just starting up. We found an instance stuck in creating. 
// When a provider creates an instance, it sets the db instance to // creating and then issues an API call to the IaaS to create the @@ -177,7 +252,7 @@ func (w *Worker) Start() (err error) { return fmt.Errorf("updating runner %s: %w", instance.Name, err) } } - case instance.Status == commonParams.InstanceDeleting: + case commonParams.InstanceDeleting: // Set the instance in deleting. It is assumed that the runner was already // removed from github either by github or by garm. Deleting status indicates // that it was already being handled by the provider. There should be no entry on @@ -194,7 +269,7 @@ func (w *Worker) Start() (err error) { return fmt.Errorf("updating runner %s: %w", instance.Name, err) } } - case instance.Status == commonParams.InstanceDeleted: + case commonParams.InstanceDeleted: if err := w.handleInstanceCleanup(instance); err != nil { locking.Unlock(instance.Name, false) return fmt.Errorf("failed to remove database entry for %s: %w", instance.Name, err) @@ -205,6 +280,10 @@ func (w *Worker) Start() (err error) { locking.Unlock(instance.Name, false) } + if err := w.ensureScaleSetInGitHub(); err != nil { + return fmt.Errorf("failed to ensure scale set: %w", err) + } + consumer, err := watcher.RegisterConsumer( w.ctx, w.consumerID, watcher.WithAny(