From 004ad1f12446252a4ed5bf47e997dd0bddd0e272 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 24 Apr 2025 23:29:40 +0000 Subject: [PATCH] Add provider worker code Runners now get created and cleaned up in scale sets. Signed-off-by: Gabriel Adrian Samfira --- auth/instance_middleware.go | 4 +- cmd/garm/main.go | 22 ++ database/sql/instances.go | 12 +- database/sql/models.go | 17 +- database/sql/scalesets.go | 42 ++- database/sql/sql.go | 1 - database/sql/util.go | 34 --- locking/interface.go | 4 +- locking/local_locker.go | 40 ++- locking/locking.go | 20 +- runner/metadata.go | 48 ++- runner/pool/pool.go | 16 +- util/github/scalesets/util.go | 3 - workers/entity/controller.go | 2 +- workers/entity/worker.go | 2 +- workers/provider/errors.go | 7 + workers/provider/instance_manager.go | 422 ++++++++++++++++++++++++++ workers/provider/provider.go | 117 +++++-- workers/provider/provider_helper.go | 81 +++++ workers/scaleset/controller.go | 15 +- workers/scaleset/scaleset.go | 61 +++- workers/scaleset/scaleset_helper.go | 13 +- workers/scaleset/scaleset_listener.go | 1 + 23 files changed, 837 insertions(+), 147 deletions(-) create mode 100644 workers/provider/errors.go create mode 100644 workers/provider/instance_manager.go create mode 100644 workers/provider/provider_helper.go diff --git a/auth/instance_middleware.go b/auth/instance_middleware.go index b7194d5c..dbd3cfb7 100644 --- a/auth/instance_middleware.go +++ b/auth/instance_middleware.go @@ -60,7 +60,7 @@ type instanceToken struct { jwtSecret string } -func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) { +func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, entityType params.GithubEntityType, ttlMinutes uint) (string, error) { // Token expiration is equal to the bootstrap timeout set on the pool plus the polling // interval garm uses to check for timed out runners. Runners that have not sent their info // by the end of this interval are most likely failed and will be reaped by garm anyway. 
@@ -82,7 +82,7 @@ func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity str ID: instance.ID, Name: instance.Name, PoolID: instance.PoolID, - Scope: poolType, + Scope: entityType, Entity: entity, CreateAttempt: instance.CreateAttempt, } diff --git a/cmd/garm/main.go b/cmd/garm/main.go index 5879fd0a..d117dc6a 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "os/signal" + "runtime" "syscall" "time" @@ -51,6 +52,7 @@ import ( "github.com/cloudbase/garm/util/appdefaults" "github.com/cloudbase/garm/websocket" "github.com/cloudbase/garm/workers/entity" + "github.com/cloudbase/garm/workers/provider" ) var ( @@ -247,6 +249,19 @@ func main() { log.Fatalf("failed to start entity controller: %+v", err) } + instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret) + if err != nil { + log.Fatalf("failed to create instance token getter: %+v", err) + } + + providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter) + if err != nil { + log.Fatalf("failed to create provider worker: %+v", err) + } + if err := providerWorker.Start(); err != nil { + log.Fatalf("failed to start provider worker: %+v", err) + } + runner, err := runner.NewRunner(ctx, *cfg, db) if err != nil { log.Fatalf("failed to create controller: %+v", err) @@ -305,6 +320,8 @@ func main() { } if cfg.Default.DebugServer { + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) slog.InfoContext(ctx, "setting up debug routes") router = routers.WithDebugServer(router) } @@ -348,6 +365,11 @@ func main() { slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller") } + slog.InfoContext(ctx, "shutting down provider worker") + if err := providerWorker.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker") + } + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second) defer shutdownCancel() if err := srv.Shutdown(shutdownCtx); err != nil { diff --git a/database/sql/instances.go b/database/sql/instances.go index cf0020b5..604682e9 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -189,13 +189,19 @@ func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName str if instance.ProviderID != nil { providerID = *instance.ProviderID } - if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{ + payload := params.Instance{ ID: instance.ID.String(), Name: instance.Name, ProviderID: providerID, AgentID: instance.AgentID, - PoolID: instance.PoolID.String(), - }); notifyErr != nil { + } + if instance.PoolID != nil { + payload.PoolID = instance.PoolID.String() + } + if instance.ScaleSetFkID != nil { + payload.ScaleSetID = *instance.ScaleSetFkID + } + if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, payload); notifyErr != nil { slog.With(slog.Any("error", notifyErr)).Error("failed to send notify") } } diff --git a/database/sql/models.go b/database/sql/models.go index 3b1dcc9b..c1b6462d 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -86,17 +86,6 @@ type Pool struct { Priority uint `gorm:"index:idx_pool_priority"` } -type ScaleSetEvent struct { - gorm.Model - - EventType params.EventType - EventLevel params.EventLevel - Message string `gorm:"type:text"` - - ScaleSetID uint `gorm:"index:idx_scale_set_event"` - ScaleSet ScaleSet `gorm:"foreignKey:ScaleSetID"` -} - // ScaleSet represents a github scale set. 
Scale sets are almost identical to pools with a few // notable exceptions: // - Labels are no longer relevant @@ -146,11 +135,7 @@ type ScaleSet struct { EnterpriseID *uuid.UUID `gorm:"index"` Enterprise Enterprise `gorm:"foreignKey:EnterpriseID"` - Status string - StatusReason string `gorm:"type:text"` - - Instances []Instance `gorm:"foreignKey:ScaleSetFkID"` - Events []ScaleSetEvent `gorm:"foreignKey:ScaleSetID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"` + Instances []Instance `gorm:"foreignKey:ScaleSetFkID"` } type RepositoryEvent struct { diff --git a/database/sql/scalesets.go b/database/sql/scalesets.go index 3adc423c..f168813b 100644 --- a/database/sql/scalesets.go +++ b/database/sql/scalesets.go @@ -380,10 +380,25 @@ func (s *sqlDatabase) DeleteScaleSetByID(ctx context.Context, scaleSetID uint) ( return nil } -func (s *sqlDatabase) SetScaleSetLastMessageID(ctx context.Context, scaleSetID uint, lastMessageID int64) error { +func (s *sqlDatabase) SetScaleSetLastMessageID(ctx context.Context, scaleSetID uint, lastMessageID int64) (err error) { + var scaleSet params.ScaleSet + defer func() { + if err == nil && scaleSet.ID != 0 { + s.sendNotify(common.ScaleSetEntityType, common.UpdateOperation, scaleSet) + } + }() if err := s.conn.Transaction(func(tx *gorm.DB) error { - if q := tx.Model(&ScaleSet{}).Where("id = ?", scaleSetID).Update("last_message_id", lastMessageID); q.Error != nil { - return errors.Wrap(q.Error, "saving database entry") + dbSet, err := s.getScaleSetByID(tx, scaleSetID) + if err != nil { + return errors.Wrap(err, "fetching scale set") + } + dbSet.LastMessageID = lastMessageID + if err := tx.Save(&dbSet).Error; err != nil { + return errors.Wrap(err, "saving database entry") + } + scaleSet, err = s.sqlToCommonScaleSet(dbSet) + if err != nil { + return errors.Wrap(err, "converting scale set") } return nil }); err != nil { @@ -392,10 +407,25 @@ func (s *sqlDatabase) SetScaleSetLastMessageID(ctx context.Context, scaleSetID u return nil } -func (s *sqlDatabase) SetScaleSetDesiredRunnerCount(ctx context.Context, scaleSetID uint, desiredRunnerCount int) error { +func (s *sqlDatabase) SetScaleSetDesiredRunnerCount(ctx context.Context, scaleSetID uint, desiredRunnerCount int) (err error) { + var scaleSet params.ScaleSet + defer func() { + if err == nil && scaleSet.ID != 0 { + s.sendNotify(common.ScaleSetEntityType, common.UpdateOperation, scaleSet) + } + }() if err := s.conn.Transaction(func(tx *gorm.DB) error { - if q := tx.Model(&ScaleSet{}).Where("id = ?", scaleSetID).Update("desired_runner_count", desiredRunnerCount); q.Error != nil { - return errors.Wrap(q.Error, "saving database entry") + dbSet, err := s.getScaleSetByID(tx, scaleSetID) + if err != nil { + return errors.Wrap(err, "fetching scale set") + } + dbSet.DesiredRunnerCount = desiredRunnerCount + if err := tx.Save(&dbSet).Error; err != nil { + return errors.Wrap(err, "saving database entry") + } + scaleSet, err = s.sqlToCommonScaleSet(dbSet) + if err != nil { + return errors.Wrap(err, "converting scale set") } return nil }); err != nil { diff --git a/database/sql/sql.go b/database/sql/sql.go index 878224c6..a704d9c3 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -432,7 +432,6 @@ func (s *sqlDatabase) migrateDB() error { &ControllerInfo{}, &WorkflowJob{}, &ScaleSet{}, - &ScaleSetEvent{}, ); err != nil { return errors.Wrap(err, "running auto migrate") } diff --git a/database/sql/util.go b/database/sql/util.go index 5bd8de01..112d0a76 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ 
-634,40 +634,6 @@ func (s *sqlDatabase) GetGithubEntity(_ context.Context, entityType params.Githu return entity, nil } -func (s *sqlDatabase) AddScaleSetEvent(ctx context.Context, scaleSetID uint, event params.EventType, eventLevel params.EventLevel, statusMessage string, maxEvents int) error { - scaleSet, err := s.GetScaleSetByID(ctx, scaleSetID) - if err != nil { - return errors.Wrap(err, "updating instance") - } - - msg := InstanceStatusUpdate{ - Message: statusMessage, - EventType: event, - EventLevel: eventLevel, - } - - if err := s.conn.Model(&scaleSet).Association("Events").Append(&msg); err != nil { - return errors.Wrap(err, "adding status message") - } - - if maxEvents > 0 { - var latestEvents []ScaleSetEvent - q := s.conn.Model(&ScaleSetEvent{}). - Limit(maxEvents).Order("id desc"). - Where("scale_set_id = ?", scaleSetID).Find(&latestEvents) - if q.Error != nil { - return errors.Wrap(q.Error, "fetching latest events") - } - if len(latestEvents) == maxEvents { - lastInList := latestEvents[len(latestEvents)-1] - if err := s.conn.Where("scale_set_id = ? and id < ?", scaleSetID, lastInList.ID).Unscoped().Delete(&ScaleSetEvent{}).Error; err != nil { - return errors.Wrap(err, "deleting old events") - } - } - } - return nil -} - func (s *sqlDatabase) addRepositoryEvent(ctx context.Context, repoID string, event params.EventType, eventLevel params.EventLevel, statusMessage string, maxEvents int) error { repo, err := s.GetRepositoryByID(ctx, repoID) if err != nil { diff --git a/locking/interface.go b/locking/interface.go index 07380a7b..d6a0b62d 100644 --- a/locking/interface.go +++ b/locking/interface.go @@ -4,8 +4,8 @@ import "time" // TODO(gabriel-samfira): needs owner attribute. type Locker interface { - TryLock(key string) bool - Lock(key string) + TryLock(key, identifier string) bool + Lock(key, identifier string) Unlock(key string, remove bool) Delete(key string) } diff --git a/locking/local_locker.go b/locking/local_locker.go index ad41345c..270138ef 100644 --- a/locking/local_locker.go +++ b/locking/local_locker.go @@ -2,6 +2,9 @@ package locking import ( "context" + "fmt" + "log/slog" + "runtime" "sync" "time" @@ -21,18 +24,29 @@ type keyMutex struct { muxes sync.Map } -var _ Locker = &keyMutex{} - -func (k *keyMutex) TryLock(key string) bool { - mux, _ := k.muxes.LoadOrStore(key, &sync.Mutex{}) - keyMux := mux.(*sync.Mutex) - return keyMux.TryLock() +type lockWithIdent struct { + mux sync.Mutex + ident string } -func (k *keyMutex) Lock(key string) { - mux, _ := k.muxes.LoadOrStore(key, &sync.Mutex{}) - keyMux := mux.(*sync.Mutex) - keyMux.Lock() +var _ Locker = &keyMutex{} + +func (k *keyMutex) TryLock(key, identifier string) bool { + mux, _ := k.muxes.LoadOrStore(key, &lockWithIdent{ + mux: sync.Mutex{}, + ident: identifier, + }) + keyMux := mux.(*lockWithIdent) + return keyMux.mux.TryLock() +} + +func (k *keyMutex) Lock(key, identifier string) { + mux, _ := k.muxes.LoadOrStore(key, &lockWithIdent{ + mux: sync.Mutex{}, + ident: identifier, + }) + keyMux := mux.(*lockWithIdent) + keyMux.mux.Lock() } func (k *keyMutex) Unlock(key string, remove bool) { @@ -40,11 +54,13 @@ func (k *keyMutex) Unlock(key string, remove bool) { if !ok { return } - keyMux := mux.(*sync.Mutex) + keyMux := mux.(*lockWithIdent) if remove { k.Delete(key) } - keyMux.Unlock() + _, filename, line, _ := runtime.Caller(1) + slog.Debug("unlocking", "key", key, "identifier", keyMux.ident, "caller", fmt.Sprintf("%s:%d", filename, line)) + keyMux.mux.Unlock() } func (k *keyMutex) Delete(key string) { diff --git 
a/locking/locking.go b/locking/locking.go index 6628d8b1..c7d99b1d 100644 --- a/locking/locking.go +++ b/locking/locking.go @@ -2,29 +2,41 @@ package locking import ( "fmt" + "log/slog" + "runtime" "sync" ) var locker Locker var lockerMux = sync.Mutex{} -func TryLock(key string) (bool, error) { +func TryLock(key, identifier string) (ok bool, err error) { + _, filename, line, _ := runtime.Caller(1) + slog.Debug("attempting to try lock", "key", key, "identifier", identifier, "caller", fmt.Sprintf("%s:%d", filename, line)) + defer slog.Debug("try lock returned", "key", key, "identifier", identifier, "locked", ok, "caller", fmt.Sprintf("%s:%d", filename, line)) if locker == nil { return false, fmt.Errorf("no locker is registered") } - return locker.TryLock(key), nil + ok = locker.TryLock(key, identifier) + return ok, nil } -func Lock(key string) { +func Lock(key, identifier string) { + _, filename, line, _ := runtime.Caller(1) + slog.Debug("attempting to lock", "key", key, "identifier", identifier, "caller", fmt.Sprintf("%s:%d", filename, line)) + defer slog.Debug("lock acquired", "key", key, "identifier", identifier, "caller", fmt.Sprintf("%s:%d", filename, line)) + if locker == nil { panic("no locker is registered") } - locker.Lock(key) + locker.Lock(key, identifier) } func Unlock(key string, remove bool) error { + _, filename, line, _ := runtime.Caller(1) + slog.Debug("attempting to unlock", "key", key, "remove", remove, "caller", fmt.Sprintf("%s:%d", filename, line)) if locker == nil { return fmt.Errorf("no locker is registered") } diff --git a/runner/metadata.go b/runner/metadata.go index 6b19c0d5..0be41fc7 100644 --- a/runner/metadata.go +++ b/runner/metadata.go @@ -7,7 +7,6 @@ import ( "fmt" "html/template" "log/slog" - "strings" "github.com/pkg/errors" @@ -57,24 +56,51 @@ func (r *Runner) GetRunnerServiceName(ctx context.Context) (string, error) { ctx, "failed to get instance params") return "", runnerErrors.ErrUnauthorized } + var entity params.GithubEntity - pool, err := r.store.GetPoolByID(r.ctx, instance.PoolID) - if err != nil { - slog.With(slog.Any("error", err)).ErrorContext( - ctx, "failed to get pool", - "pool_id", instance.PoolID) - return "", errors.Wrap(err, "fetching pool") + if instance.PoolID != "" { + pool, err := r.store.GetPoolByID(r.ctx, instance.PoolID) + if err != nil { + slog.With(slog.Any("error", err)).ErrorContext( + ctx, "failed to get pool", + "pool_id", instance.PoolID) + return "", errors.Wrap(err, "fetching pool") + } + entity, err = pool.GithubEntity() + if err != nil { + slog.With(slog.Any("error", err)).ErrorContext( + ctx, "failed to get pool entity", + "pool_id", instance.PoolID) + return "", errors.Wrap(err, "fetching pool entity") + } + } else if instance.ScaleSetID != 0 { + scaleSet, err := r.store.GetScaleSetByID(r.ctx, instance.ScaleSetID) + if err != nil { + slog.With(slog.Any("error", err)).ErrorContext( + ctx, "failed to get scale set", + "scale_set_id", instance.ScaleSetID) + return "", errors.Wrap(err, "fetching scale set") + } + entity, err = scaleSet.GithubEntity() + if err != nil { + slog.With(slog.Any("error", err)).ErrorContext( + ctx, "failed to get scale set entity", + "scale_set_id", instance.ScaleSetID) + return "", errors.Wrap(err, "fetching scale set entity") + } + } else { + return "", errors.New("instance not associated with a pool or scale set") } tpl := "actions.runner.%s.%s" var serviceName string - switch pool.PoolType() { + switch entity.EntityType { case params.GithubEntityTypeEnterprise: - serviceName = fmt.Sprintf(tpl, 
pool.EnterpriseName, instance.Name) + serviceName = fmt.Sprintf(tpl, entity.Owner, instance.Name) case params.GithubEntityTypeOrganization: - serviceName = fmt.Sprintf(tpl, pool.OrgName, instance.Name) + serviceName = fmt.Sprintf(tpl, entity.Owner, instance.Name) case params.GithubEntityTypeRepository: - serviceName = fmt.Sprintf(tpl, strings.ReplaceAll(pool.RepoName, "/", "-"), instance.Name) + serviceName = fmt.Sprintf(tpl, fmt.Sprintf("%s-%s", entity.Owner, entity.Name), instance.Name) } return serviceName, nil } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 3ec72dad..88be9e97 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -100,6 +100,7 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, insta repo := &basePoolManager{ ctx: ctx, + consumerID: consumerID, entity: entity, ghcli: ghc, controllerInfo: controllerInfo, @@ -117,6 +118,7 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, insta type basePoolManager struct { ctx context.Context + consumerID string entity params.GithubEntity ghcli common.GithubClient controllerInfo params.ControllerInfo @@ -420,7 +422,7 @@ func (r *basePoolManager) cleanupOrphanedProviderRunners(runners []*github.Runne continue } - lockAcquired, err := locking.TryLock(instance.Name) + lockAcquired, err := locking.TryLock(instance.Name, r.consumerID) if !lockAcquired || err != nil { slog.DebugContext( r.ctx, "failed to acquire lock for instance", @@ -499,7 +501,7 @@ func (r *basePoolManager) reapTimedOutRunners(runners []*github.Runner) error { slog.DebugContext( r.ctx, "attempting to lock instance", "runner_name", instance.Name) - lockAcquired, err := locking.TryLock(instance.Name) + lockAcquired, err := locking.TryLock(instance.Name, r.consumerID) if !lockAcquired || err != nil { slog.DebugContext( r.ctx, "failed to acquire lock for instance", @@ -626,7 +628,7 @@ func (r *basePoolManager) cleanupOrphanedGithubRunners(runners []*github.Runner) poolInstanceCache[dbInstance.PoolID] = poolInstances } - lockAcquired, err := locking.TryLock(dbInstance.Name) + lockAcquired, err := locking.TryLock(dbInstance.Name, r.consumerID) if !lockAcquired || err != nil { slog.DebugContext( r.ctx, "failed to acquire lock for instance", @@ -1064,7 +1066,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool for _, instanceToDelete := range idleWorkers[:numScaleDown] { instanceToDelete := instanceToDelete - lockAcquired, err := locking.TryLock(instanceToDelete.Name) + lockAcquired, err := locking.TryLock(instanceToDelete.Name, r.consumerID) if !lockAcquired || err != nil { slog.With(slog.Any("error", err)).ErrorContext( ctx, "failed to acquire lock for instance", @@ -1217,7 +1219,7 @@ func (r *basePoolManager) retryFailedInstancesForOnePool(ctx context.Context, po slog.DebugContext( ctx, "attempting to retry failed instance", "runner_name", instance.Name) - lockAcquired, err := locking.TryLock(instance.Name) + lockAcquired, err := locking.TryLock(instance.Name, r.consumerID) if !lockAcquired || err != nil { slog.DebugContext( ctx, "failed to acquire lock for instance", @@ -1401,7 +1403,7 @@ func (r *basePoolManager) deletePendingInstances() error { r.ctx, "removing instance from pool", "runner_name", instance.Name, "pool_id", instance.PoolID) - lockAcquired, err := locking.TryLock(instance.Name) + lockAcquired, err := locking.TryLock(instance.Name, r.consumerID) if !lockAcquired || err != nil { slog.InfoContext( r.ctx, "failed to acquire lock for instance", @@ -1513,7 
+1515,7 @@ func (r *basePoolManager) addPendingInstances() error { r.ctx, "attempting to acquire lock for instance", "runner_name", instance.Name, "action", "create_pending") - lockAcquired, err := locking.TryLock(instance.Name) + lockAcquired, err := locking.TryLock(instance.Name, r.consumerID) if !lockAcquired || err != nil { slog.DebugContext( r.ctx, "failed to acquire lock for instance", diff --git a/util/github/scalesets/util.go b/util/github/scalesets/util.go index 66171dd6..15c3a5cf 100644 --- a/util/github/scalesets/util.go +++ b/util/github/scalesets/util.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "io" - "log/slog" "net/http" ) @@ -51,7 +50,5 @@ func (s *ScaleSetClient) newActionsRequest(ctx context.Context, method, path str req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.actionsServiceInfo.Token)) - slog.DebugContext(ctx, "newActionsRequest", "method", method, "url", uri.String(), "body", body, "headers", req.Header) - return req, nil } diff --git a/workers/entity/controller.go b/workers/entity/controller.go index bfdcabfe..424f9099 100644 --- a/workers/entity/controller.go +++ b/workers/entity/controller.go @@ -183,7 +183,7 @@ func (c *Controller) loop() { for { select { case payload := <-c.consumer.Watch(): - slog.InfoContext(c.ctx, "received payload", slog.Any("payload", payload)) + slog.InfoContext(c.ctx, "received payload") go c.handleWatcherEvent(payload) case <-c.ctx.Done(): return diff --git a/workers/entity/worker.go b/workers/entity/worker.go index 49fb75cb..070a9711 100644 --- a/workers/entity/worker.go +++ b/workers/entity/worker.go @@ -113,7 +113,7 @@ func (w *Worker) loop() { for { select { case payload := <-w.consumer.Watch(): - slog.InfoContext(w.ctx, "received payload", slog.Any("payload", payload)) + slog.InfoContext(w.ctx, "received payload") go w.handleWorkerWatcherEvent(payload) case <-w.ctx.Done(): return diff --git a/workers/provider/errors.go b/workers/provider/errors.go new file mode 100644 index 00000000..d46a721b --- /dev/null +++ b/workers/provider/errors.go @@ -0,0 +1,7 @@ +package provider + +import "fmt" + +var ( + ErrInstanceDeleted = fmt.Errorf("instance deleted") +) diff --git a/workers/provider/instance_manager.go b/workers/provider/instance_manager.go new file mode 100644 index 00000000..c20c75ae --- /dev/null +++ b/workers/provider/instance_manager.go @@ -0,0 +1,422 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + commonParams "github.com/cloudbase/garm-provider-common/params" + + "github.com/cloudbase/garm/cache" + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/runner/common" + garmUtil "github.com/cloudbase/garm/util" +) + +func NewInstanceManager(ctx context.Context, instance params.Instance, scaleSet params.ScaleSet, provider common.Provider, helper providerHelper) (*instanceManager, error) { + ctx = garmUtil.WithSlogContext(ctx, slog.Any("instance", instance.Name)) + + githubEntity, err := scaleSet.GithubEntity() + if err != nil { + return nil, fmt.Errorf("getting github entity: %w", err) + } + return &instanceManager{ + ctx: ctx, + instance: instance, + provider: provider, + deleteBackoff: time.Second * 0, + scaleSet: scaleSet, + helper: helper, + scaleSetEntity: githubEntity, + }, nil +} + +// instanceManager handles the lifecycle of a single instance. 
+// When an instance is created, a new instance manager is created +// for it. When the instance is placed in pending_create, the manager +// will attempt to create a new compute resource in the designated +// provider. Finally, when an instance is marked as pending_delete, it is removed +// from the provider and on success the instance is marked as deleted. Failure to +// delete will place the instance back in pending_delete. The removal process is +// retried after a backoff period. Instances placed in pending_force_delete will +// ignore provider errors and exit. +type instanceManager struct { + ctx context.Context + + instance params.Instance + provider common.Provider + helper providerHelper + + scaleSet params.ScaleSet + scaleSetEntity params.GithubEntity + + deleteBackoff time.Duration + + updates chan dbCommon.ChangePayload + mux sync.Mutex + running bool + quit chan struct{} +} + +func (i *instanceManager) Start() error { + i.mux.Lock() + defer i.mux.Unlock() + + if i.running { + return nil + } + + // switch i.instance.Status { + // case commonParams.InstancePendingCreate, + // commonParams.InstancePendingDelete, + // commonParams.InstancePendingForceDelete: + // if err := i.consolidateState(); err != nil { + // return fmt.Errorf("consolidating state: %w", err) + // } + // case commonParams.InstanceDeleted: + // return ErrInstanceDeleted + // } + i.running = true + i.quit = make(chan struct{}) + i.updates = make(chan dbCommon.ChangePayload) + + go i.loop() + return nil +} + +func (i *instanceManager) Stop() error { + i.mux.Lock() + defer i.mux.Unlock() + + if !i.running { + return nil + } + + i.running = false + close(i.quit) + close(i.updates) + return nil +} + +func (i *instanceManager) sleepForBackOffOrCanceled() bool { + timer := time.NewTimer(i.deleteBackoff) + defer timer.Stop() + + select { + case <-timer.C: + return false + case <-i.quit: + return true + case <-i.ctx.Done(): + return true + } +} + +func (i *instanceManager) incrementBackOff() { + if i.deleteBackoff == 0 { + i.deleteBackoff = time.Second * 1 + } else { + i.deleteBackoff *= 2 + } + if i.deleteBackoff > time.Minute*5 { + i.deleteBackoff = time.Minute * 5 + } +} + +func (i *instanceManager) getEntity() (params.GithubEntity, error) { + entity, err := i.scaleSet.GithubEntity() + if err != nil { + return params.GithubEntity{}, fmt.Errorf("getting entity: %w", err) + } + ghEntity, err := i.helper.GetGithubEntity(entity) + if err != nil { + return params.GithubEntity{}, fmt.Errorf("getting entity: %w", err) + } + return ghEntity, nil +} + +func (i *instanceManager) pseudoPoolID() string { + // This is temporary. We need to extend providers to know about scale sets. + return fmt.Sprintf("%s-%s", i.scaleSet.Name, i.scaleSetEntity.ID) +} + +func (i *instanceManager) handleCreateInstanceInProvider(instance params.Instance) error { + // TODO(gabriel-samfira): implement the creation of the instance in the provider.
+ entity, err := i.getEntity() + if err != nil { + return fmt.Errorf("getting entity: %w", err) + } + jwtValidity := instance.RunnerTimeout() + token, err := i.helper.InstanceTokenGetter().NewInstanceJWTToken( + instance, entity.String(), entity.EntityType, jwtValidity) + if err != nil { + return fmt.Errorf("creating instance token: %w", err) + } + tools, ok := cache.GetGithubToolsCache(entity) + if !ok { + return fmt.Errorf("tools not found in cache for entity %s", entity.String()) + } + + bootstrapArgs := commonParams.BootstrapInstance{ + Name: instance.Name, + Tools: tools, + RepoURL: entity.GithubURL(), + MetadataURL: instance.MetadataURL, + CallbackURL: instance.CallbackURL, + InstanceToken: token, + OSArch: i.scaleSet.OSArch, + OSType: i.scaleSet.OSType, + Flavor: i.scaleSet.Flavor, + Image: i.scaleSet.Image, + ExtraSpecs: i.scaleSet.ExtraSpecs, + // This is temporary. We need to extend providers to know about scale sets. + PoolID: i.pseudoPoolID(), + CACertBundle: entity.Credentials.CABundle, + GitHubRunnerGroup: i.scaleSet.GitHubRunnerGroup, + JitConfigEnabled: true, + } + + var instanceIDToDelete string + baseParams, err := i.getProviderBaseParams() + if err != nil { + return fmt.Errorf("getting provider base params: %w", err) + } + + defer func() { + if instanceIDToDelete != "" { + deleteInstanceParams := common.DeleteInstanceParams{ + DeleteInstanceV011: common.DeleteInstanceV011Params{ + ProviderBaseParams: baseParams, + }, + } + if err := i.provider.DeleteInstance(i.ctx, instanceIDToDelete, deleteInstanceParams); err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + slog.With(slog.Any("error", err)).ErrorContext( + i.ctx, "failed to cleanup instance", + "provider_id", instanceIDToDelete) + } + } + } + }() + + createInstanceParams := common.CreateInstanceParams{ + CreateInstanceV011: common.CreateInstanceV011Params{ + ProviderBaseParams: baseParams, + }, + } + + providerInstance, err := i.provider.CreateInstance(i.ctx, bootstrapArgs, createInstanceParams) + if err != nil { + instanceIDToDelete = instance.Name + return fmt.Errorf("creating instance in provider: %w", err) + } + + if providerInstance.Status == commonParams.InstanceError { + instanceIDToDelete = instance.ProviderID + if instanceIDToDelete == "" { + instanceIDToDelete = instance.Name + } + } + + updated, err := i.helper.updateArgsFromProviderInstance(instance.Name, providerInstance) + if err != nil { + return fmt.Errorf("updating instance args: %w", err) + } + i.instance = updated + + return nil +} + +func (i *instanceManager) getProviderBaseParams() (common.ProviderBaseParams, error) { + info, err := i.helper.GetControllerInfo() + if err != nil { + return common.ProviderBaseParams{}, fmt.Errorf("getting controller info: %w", err) + } + + return common.ProviderBaseParams{ + ControllerInfo: info, + }, nil +} + +func (i *instanceManager) handleDeleteInstanceInProvider(instance params.Instance) error { + slog.InfoContext(i.ctx, "deleting instance in provider", "runner_name", instance.Name) + identifier := instance.ProviderID + if identifier == "" { + // provider did not return a provider ID? 
+ // try with name + identifier = instance.Name + } + + baseParams, err := i.getProviderBaseParams() + if err != nil { + return fmt.Errorf("getting provider base params: %w", err) + } + + slog.DebugContext( + i.ctx, "calling delete instance on provider", + "runner_name", instance.Name, + "provider_id", identifier) + + deleteInstanceParams := common.DeleteInstanceParams{ + DeleteInstanceV011: common.DeleteInstanceV011Params{ + ProviderBaseParams: baseParams, + }, + } + if err := i.provider.DeleteInstance(i.ctx, identifier, deleteInstanceParams); err != nil { + return fmt.Errorf("deleting instance in provider: %w", err) + } + return nil +} + +func (i *instanceManager) consolidateState() error { + i.mux.Lock() + defer i.mux.Unlock() + + switch i.instance.Status { + case commonParams.InstancePendingCreate: + // kick off the creation process + if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstanceCreating, nil); err != nil { + return fmt.Errorf("setting instance status to creating: %w", err) + } + if err := i.handleCreateInstanceInProvider(i.instance); err != nil { + slog.ErrorContext(i.ctx, "creating instance in provider", "error", err) + if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstanceError, []byte(err.Error())); err != nil { + return fmt.Errorf("setting instance status to error: %w", err) + } + } + case commonParams.InstanceRunning: + // Nothing to do. The provider finished creating the instance. + case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete: + // Remove or force remove the runner. When force remove is specified, we ignore + // IaaS errors. + if i.instance.Status == commonParams.InstancePendingDelete { + // invoke backoff sleep. We only do this for non forced removals, + // as force delete will always return, regardless of whether or not + // the remove operation succeeded in the provider. A user may decide + // to force delete a runner if GARM fails to remove it normally. + if canceled := i.sleepForBackOffOrCanceled(); canceled { + // the worker is shutting down. Return here. + return nil + } + } + + if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstanceDeleting, nil); err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + return nil + } + return fmt.Errorf("setting instance status to deleting: %w", err) + } + + if err := i.handleDeleteInstanceInProvider(i.instance); err != nil { + slog.ErrorContext(i.ctx, "deleting instance in provider", "error", err, "forced", i.instance.Status == commonParams.InstancePendingForceDelete) + if i.instance.Status == commonParams.InstancePendingDelete { + i.incrementBackOff() + if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstancePendingDelete, []byte(err.Error())); err != nil { + return fmt.Errorf("setting instance status to error: %w", err) + } + + return fmt.Errorf("error removing instance. Will retry: %w", err) + } + } + if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstanceDeleted, nil); err != nil { + return fmt.Errorf("setting instance status to deleted: %w", err) + } + case commonParams.InstanceError: + // Instance is in error state. We wait for next status or potentially retry + // spawning the instance with a backoff timer. 
+ if err := i.helper.SetInstanceStatus(i.instance.Name, commonParams.InstancePendingDelete, nil); err != nil { + return fmt.Errorf("setting instance status to error: %w", err) + } + case commonParams.InstanceDeleted: + return ErrInstanceDeleted + } + return nil +} + +func (i *instanceManager) handleUpdate(update dbCommon.ChangePayload) error { + // We need a better way to handle instance state. Database updates may fail, and we + // end up with an inconsistent state between what we know about the instance and what + // is reflected in the database. + i.mux.Lock() + + if !i.running { + i.mux.Unlock() + return nil + } + + instance, ok := update.Payload.(params.Instance) + if !ok { + i.mux.Unlock() + return runnerErrors.NewBadRequestError("invalid payload type") + } + + oldStatus := i.instance.Status + i.instance = instance + if oldStatus == instance.Status { + // The status did not change. Nothing of interest happened. + i.mux.Unlock() + return nil + } + i.mux.Unlock() + return i.consolidateState() +} + +func (i *instanceManager) Update(instance dbCommon.ChangePayload) error { + i.mux.Lock() + defer i.mux.Unlock() + + if !i.running { + return runnerErrors.NewBadRequestError("instance manager is not running") + } + + timer := time.NewTimer(60 * time.Second) + defer timer.Stop() + + select { + case i.updates <- instance: + case <-i.quit: + return nil + case <-i.ctx.Done(): + return nil + case <-timer.C: + return fmt.Errorf("timeout while sending update to instance manager") + } + return nil +} + +func (i *instanceManager) loop() { + defer i.Stop() + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-i.quit: + return + case <-i.ctx.Done(): + return + case <-ticker.C: + if err := i.consolidateState(); err != nil { + if errors.Is(err, ErrInstanceDeleted) { + // instance had been deleted, we can exit the loop. + return + } + slog.ErrorContext(i.ctx, "consolidating state", "error", err) + } + case update, ok := <-i.updates: + if !ok { + return + } + if err := i.handleUpdate(update); err != nil { + if errors.Is(err, ErrInstanceDeleted) { + // instance had been deleted, we can exit the loop. + return + } + slog.ErrorContext(i.ctx, "handling update", "error", err) + } + } + } +} diff --git a/workers/provider/provider.go b/workers/provider/provider.go index 969a373d..07f65b26 100644 --- a/workers/provider/provider.go +++ b/workers/provider/provider.go @@ -6,19 +6,25 @@ import ( "log/slog" "sync" + commonParams "github.com/cloudbase/garm-provider-common/params" + + "github.com/cloudbase/garm/auth" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" ) -func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider) (*provider, error) { +func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider, tokenGetter auth.InstanceTokenGetter) (*provider, error) { consumerID := "provider-worker" return &provider{ - ctx: context.Background(), - store: store, - consumerID: consumerID, - providers: providers, + ctx: context.Background(), + store: store, + consumerID: consumerID, + providers: providers, + tokenGetter: tokenGetter, + scaleSets: make(map[uint]params.ScaleSet), + runners: make(map[string]*instanceManager), }, nil } @@ -31,13 +37,14 @@ type provider struct { // We need to implement way to RPC from workers to controllers // and abstract that into something we can use to eventually // scale out.
- store dbCommon.Store + store dbCommon.Store + tokenGetter auth.InstanceTokenGetter providers map[string]common.Provider // A cache of all scale sets kept updated by the watcher. // This helps us avoid a bunch of queries to the database. scaleSets map[uint]params.ScaleSet - runners map[string]params.Instance + runners map[string]*instanceManager mux sync.Mutex running bool @@ -45,9 +52,6 @@ type provider struct { } func (p *provider) loadAllScaleSets() error { - p.mux.Lock() - defer p.mux.Unlock() - scaleSets, err := p.store.ListAllScaleSets(p.ctx) if err != nil { return fmt.Errorf("fetching scale sets: %w", err) @@ -64,16 +68,46 @@ func (p *provider) loadAllScaleSets() error { // care about runners created by scale sets, but in the future, we will migrate // the pool manager to the same model. func (p *provider) loadAllRunners() error { - p.mux.Lock() - defer p.mux.Unlock() - runners, err := p.store.ListAllInstances(p.ctx) if err != nil { return fmt.Errorf("fetching runners: %w", err) } for _, runner := range runners { - p.runners[runner.Name] = runner + // Skip non scale set instances for now. This condition needs to be + // removed once we replace the current pool manager. + if runner.ScaleSetID == 0 { + continue + } + // Ignore runners in "creating" state. If we're just starting up and + // we find a runner in "creating" it was most likely interrupted while + // creating. It is unlikely that it is still usable. We allow the scale set + // worker to clean it up. It will eventually be marked as pending delete and + // this worker will get an update to clean up any resources left behind by + // an incomplete creation event. + if runner.Status == commonParams.InstanceCreating { + continue + } + scaleSet, ok := p.scaleSets[runner.ScaleSetID] + if !ok { + slog.ErrorContext(p.ctx, "scale set not found", "scale_set_id", runner.ScaleSetID) + continue + } + provider, ok := p.providers[scaleSet.ProviderName] + if !ok { + slog.ErrorContext(p.ctx, "provider not found", "provider_name", runner.ProviderName) + continue + } + instanceManager, err := NewInstanceManager( + p.ctx, runner, scaleSet, provider, p) + if err != nil { + return fmt.Errorf("creating instance manager: %w", err) + } + if err := instanceManager.Start(); err != nil { + return fmt.Errorf("starting instance manager: %w", err) + } + + p.runners[runner.Name] = instanceManager } return nil @@ -127,8 +161,12 @@ func (p *provider) loop() { defer p.Stop() for { select { - case payload := <-p.consumer.Watch(): - slog.InfoContext(p.ctx, "received payload", slog.Any("payload", payload)) + case payload, ok := <-p.consumer.Watch(): + if !ok { + slog.ErrorContext(p.ctx, "watcher channel closed") + return + } + slog.InfoContext(p.ctx, "received payload") go p.handleWatcherEvent(payload) case <-p.ctx.Done(): return @@ -172,6 +210,23 @@ func (p *provider) handleScaleSetEvent(event dbCommon.ChangePayload) { } } +func (p *provider) handleInstanceAdded(instance params.Instance) error { + scaleSet, ok := p.scaleSets[instance.ScaleSetID] + if !ok { + return fmt.Errorf("scale set not found for instance %s", instance.Name) + } + instanceManager, err := NewInstanceManager( + p.ctx, instance, scaleSet, p.providers[instance.ProviderName], p) + if err != nil { + return fmt.Errorf("creating instance manager: %w", err) + } + if err := instanceManager.Start(); err != nil { + return fmt.Errorf("starting instance manager: %w", err) + } + p.runners[instance.Name] = instanceManager + return nil +} + +func (p *provider) handleInstanceEvent(event
dbCommon.ChangePayload) { p.mux.Lock() defer p.mux.Unlock() @@ -183,11 +238,35 @@ func (p *provider) handleInstanceEvent(event dbCommon.ChangePayload) { } switch event.Operation { - case dbCommon.CreateOperation, dbCommon.UpdateOperation: - slog.DebugContext(p.ctx, "got create/update operation") - p.runners[instance.Name] = instance + case dbCommon.CreateOperation: + slog.DebugContext(p.ctx, "got create operation") + if err := p.handleInstanceAdded(instance); err != nil { + slog.ErrorContext(p.ctx, "failed to handle instance added", "error", err) + return + } + case dbCommon.UpdateOperation: + slog.DebugContext(p.ctx, "got update operation") + existingInstance, ok := p.runners[instance.Name] + if !ok { + if err := p.handleInstanceAdded(instance); err != nil { + slog.ErrorContext(p.ctx, "failed to handle instance added", "error", err) + return + } + } else { + if err := existingInstance.Update(event); err != nil { + slog.ErrorContext(p.ctx, "failed to update instance", "error", err) + return + } + } case dbCommon.DeleteOperation: slog.DebugContext(p.ctx, "got delete operation") + existingInstance, ok := p.runners[instance.Name] + if ok { + if err := existingInstance.Stop(); err != nil { + slog.ErrorContext(p.ctx, "failed to stop instance", "error", err) + return + } + } delete(p.runners, instance.Name) default: slog.ErrorContext(p.ctx, "invalid operation type", "operation_type", event.Operation) diff --git a/workers/provider/provider_helper.go b/workers/provider/provider_helper.go new file mode 100644 index 00000000..d420cdad --- /dev/null +++ b/workers/provider/provider_helper.go @@ -0,0 +1,81 @@ +package provider + +import ( + "fmt" + + "github.com/cloudbase/garm-provider-common/errors" + commonParams "github.com/cloudbase/garm-provider-common/params" + "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/params" +) + +type providerHelper interface { + SetInstanceStatus(instanceName string, status commonParams.InstanceStatus, providerFault []byte) error + InstanceTokenGetter() auth.InstanceTokenGetter + updateArgsFromProviderInstance(instanceName string, providerInstance commonParams.ProviderInstance) (params.Instance, error) + GetControllerInfo() (params.ControllerInfo, error) + GetGithubEntity(entity params.GithubEntity) (params.GithubEntity, error) +} + +func (p *provider) updateArgsFromProviderInstance(instanceName string, providerInstance commonParams.ProviderInstance) (params.Instance, error) { + updateParams := params.UpdateInstanceParams{ + ProviderID: providerInstance.ProviderID, + OSName: providerInstance.OSName, + OSVersion: providerInstance.OSVersion, + Addresses: providerInstance.Addresses, + Status: providerInstance.Status, + ProviderFault: providerInstance.ProviderFault, + } + + updated, err := p.store.UpdateInstance(p.ctx, instanceName, updateParams) + if err != nil { + return params.Instance{}, fmt.Errorf("updating instance %s: %w", instanceName, err) + } + return updated, nil +} + +func (p *provider) GetControllerInfo() (params.ControllerInfo, error) { + p.mux.Lock() + defer p.mux.Unlock() + + info, err := p.store.ControllerInfo() + if err != nil { + return params.ControllerInfo{}, fmt.Errorf("getting controller info: %w", err) + } + + return info, nil +} + +func (p *provider) SetInstanceStatus(instanceName string, status commonParams.InstanceStatus, providerFault []byte) error { + p.mux.Lock() + defer p.mux.Unlock() + + if _, ok := p.runners[instanceName]; !ok { + return errors.ErrNotFound + } + + updateParams := params.UpdateInstanceParams{ + Status: status, + 
ProviderFault: providerFault, + } + + _, err := p.store.UpdateInstance(p.ctx, instanceName, updateParams) + if err != nil { + return fmt.Errorf("updating instance %s: %w", instanceName, err) + } + + return nil +} + +func (p *provider) InstanceTokenGetter() auth.InstanceTokenGetter { + return p.tokenGetter +} + +func (p *provider) GetGithubEntity(entity params.GithubEntity) (params.GithubEntity, error) { + ghEntity, err := p.store.GetGithubEntity(p.ctx, entity.EntityType, entity.ID) + if err != nil { + return params.GithubEntity{}, fmt.Errorf("getting github entity: %w", err) + } + + return ghEntity, nil +} diff --git a/workers/scaleset/controller.go b/workers/scaleset/controller.go index 809a2cba..24d1aad3 100644 --- a/workers/scaleset/controller.go +++ b/workers/scaleset/controller.go @@ -210,6 +210,7 @@ func (c *Controller) loop() { defer c.Stop() updateToolsTicker := time.NewTicker(common.PoolToolUpdateInterval) initialToolUpdate := make(chan struct{}, 1) + defer close(initialToolUpdate) go func() { slog.InfoContext(c.ctx, "running initial tool update") if err := c.updateTools(); err != nil { @@ -225,21 +226,21 @@ func (c *Controller) loop() { slog.InfoContext(c.ctx, "consumer channel closed") return } - slog.InfoContext(c.ctx, "received payload", slog.Any("payload", payload)) + slog.InfoContext(c.ctx, "received payload") go c.handleWatcherEvent(payload) case <-c.ctx.Done(): return - case _, ok := <-initialToolUpdate: - if ok { - // channel received the initial update slug. We can close it now. - close(initialToolUpdate) - } + case <-initialToolUpdate: case update, ok := <-c.statusUpdates: if !ok { return } go c.handleScaleSetStatusUpdates(update) - case <-updateToolsTicker.C: + case _, ok := <-updateToolsTicker.C: + if !ok { + slog.InfoContext(c.ctx, "update tools ticker closed") + return + } if err := c.updateTools(); err != nil { slog.With(slog.Any("error", err)).Error("failed to update tools") } diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index 012a41d1..ba7701d7 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -110,7 +110,7 @@ func (w *Worker) Start() (err error) { // mid boot before it reached the phase where it runs the metadtata, or // if it already failed). instanceState := commonParams.InstancePendingDelete - locking.Lock(instance.Name) + locking.Lock(instance.Name, w.consumerID) if instance.AgentID != 0 { if err := w.scaleSetCli.RemoveRunner(w.ctx, instance.AgentID); err != nil { // scale sets use JIT runners. This means that we create the runner in github @@ -119,9 +119,9 @@ if !errors.Is(err, runnerErrors.ErrNotFound) { if errors.Is(err, runnerErrors.ErrUnauthorized) { // we don't have access to remove the runner. This implies that our - // credentials may have expired. + // credentials may have expired or are incorrect. // - // TODO: we need to set the scale set as inactive and stop the listener (if any). + // TODO(gabriel-samfira): we need to set the scale set as inactive and stop the listener (if any). slog.ErrorContext(w.ctx, "error removing runner", "runner_name", instance.Name, "error", err) w.runners[instance.ID] = instance locking.Unlock(instance.Name, false) @@ -168,7 +168,6 @@ return fmt.Errorf("updating runner %s: %w", instance.Name, err) } } - locking.Unlock(instance.Name, false) } else if instance.Status == commonParams.InstanceDeleting { // Set the instance in deleting.
It is assumed that the runner was already // removed from github either by github or by garm. Deleting status indicates @@ -309,6 +308,13 @@ func (w *Worker) handleInstanceEntityEvent(event dbCommon.ChangePayload) { case dbCommon.UpdateOperation: slog.DebugContext(w.ctx, "got update operation") w.mux.Lock() + if instance.Status == commonParams.InstanceDeleted { + if err := w.handleInstanceCleanup(instance); err != nil { + slog.ErrorContext(w.ctx, "error cleaning up instance", "instance_id", instance.ID, "error", err) + } + w.mux.Unlock() + return + } oldInstance, ok := w.runners[instance.ID] w.runners[instance.ID] = instance @@ -351,10 +357,10 @@ func (w *Worker) handleInstanceEntityEvent(event dbCommon.ChangePayload) { func (w *Worker) handleEvent(event dbCommon.ChangePayload) { switch event.EntityType { case dbCommon.ScaleSetEntityType: - slog.DebugContext(w.ctx, "got scaleset event", "event", event) + slog.DebugContext(w.ctx, "got scaleset event") w.handleScaleSetEvent(event) case dbCommon.InstanceEntityType: - slog.DebugContext(w.ctx, "got instance event", "event", event) + slog.DebugContext(w.ctx, "got instance event") w.handleInstanceEntityEvent(event) default: slog.DebugContext(w.ctx, "invalid entity type; ignoring", "entity_type", event.EntityType) @@ -509,12 +515,11 @@ func (w *Worker) handleScaleUp(target, current uint) { continue } - runnerDetails, err := w.scaleSetCli.GetRunner(w.ctx, jitConfig.Runner.ID) + _, err = w.scaleSetCli.GetRunner(w.ctx, jitConfig.Runner.ID) if err != nil { slog.ErrorContext(w.ctx, "error getting runner details", "error", err) continue } - slog.DebugContext(w.ctx, "runner details", "runner_details", runnerDetails) } } @@ -523,15 +528,42 @@ func (w *Worker) handleScaleDown(target, current uint) { if delta <= 0 { return } - w.mux.Lock() - defer w.mux.Unlock() removed := 0 + candidates := []params.Instance{} for _, runner := range w.runners { + locked, err := locking.TryLock(runner.Name, w.consumerID) + if err != nil || !locked { + slog.DebugContext(w.ctx, "runner is locked; skipping", "runner_name", runner.Name) + continue + } + switch runner.Status { + case commonParams.InstanceRunning: + if runner.RunnerStatus != params.RunnerActive { + candidates = append(candidates, runner) + } + case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete, + commonParams.InstanceDeleting, commonParams.InstanceDeleted: + removed++ + locking.Unlock(runner.Name, true) + continue + default: + slog.DebugContext(w.ctx, "runner is not in a valid state; skipping", "runner_name", runner.Name, "runner_status", runner.Status) + locking.Unlock(runner.Name, false) + continue + } + locking.Unlock(runner.Name, false) + } + + if removed >= int(delta) { + return + } + + for _, runner := range candidates { if removed >= int(delta) { break } - locked, err := locking.TryLock(runner.Name) + locked, err := locking.TryLock(runner.Name, w.consumerID) if err != nil || !locked { slog.DebugContext(w.ctx, "runner is locked; skipping", "runner_name", runner.Name) continue @@ -539,7 +571,8 @@ func (w *Worker) handleScaleDown(target, current uint) { switch runner.Status { case commonParams.InstancePendingCreate, commonParams.InstanceRunning: - case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete: + case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete, + commonParams.InstanceDeleting, commonParams.InstanceDeleted: removed++ locking.Unlock(runner.Name, true) continue @@ -613,8 +646,6 @@ func (w *Worker) handleAutoScale() { 
slog.ErrorContext(w.ctx, "error cleaning up instance", "instance_id", instance.ID, "error", err) } } - w.mux.Unlock() - var desiredRunners uint if w.scaleSet.DesiredRunnerCount > 0 { desiredRunners = uint(w.scaleSet.DesiredRunnerCount) @@ -624,6 +655,7 @@ func (w *Worker) handleAutoScale() { currentRunners := uint(len(w.runners)) if currentRunners == targetRunners { lastMsgDebugLog("desired runner count reached", targetRunners, currentRunners) + w.mux.Unlock() continue } @@ -634,6 +666,7 @@ func (w *Worker) handleAutoScale() { lastMsgDebugLog("attempting to scale down", targetRunners, currentRunners) w.handleScaleDown(targetRunners, currentRunners) } + w.mux.Unlock() } } } diff --git a/workers/scaleset/scaleset_helper.go b/workers/scaleset/scaleset_helper.go index e6ae9197..0cf01025 100644 --- a/workers/scaleset/scaleset_helper.go +++ b/workers/scaleset/scaleset_helper.go @@ -35,7 +35,10 @@ func (w *Worker) SetLastMessageID(id int64) error { // HandleJobCompleted handles a job completed message. If a job had a runner // assigned and was not canceled before it had a chance to run, then we mark // that runner as pending_delete. -func (w *Worker) HandleJobsCompleted(jobs []params.ScaleSetJobMessage) error { +func (w *Worker) HandleJobsCompleted(jobs []params.ScaleSetJobMessage) (err error) { + slog.DebugContext(w.ctx, "handling job completed", "jobs", jobs) + defer slog.DebugContext(w.ctx, "finished handling job completed", "jobs", jobs, "error", err) + for _, job := range jobs { if job.RunnerName == "" { // This job was not assigned to a runner, so we can skip it. @@ -47,7 +50,7 @@ func (w *Worker) HandleJobsCompleted(jobs []params.ScaleSetJobMessage) error { RunnerStatus: params.RunnerTerminated, } - locking.Lock(job.RunnerName) + locking.Lock(job.RunnerName, w.consumerID) _, err := w.store.UpdateInstance(w.ctx, job.RunnerName, runnerUpdateParams) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { @@ -62,7 +65,9 @@ func (w *Worker) HandleJobsCompleted(jobs []params.ScaleSetJobMessage) error { // HandleJobStarted updates the runners from idle to active in the DB and // assigns the job to them. -func (w *Worker) HandleJobsStarted(jobs []params.ScaleSetJobMessage) error { +func (w *Worker) HandleJobsStarted(jobs []params.ScaleSetJobMessage) (err error) { + slog.DebugContext(w.ctx, "handling job started", "jobs", jobs) + defer slog.DebugContext(w.ctx, "finished handling job started", "jobs", jobs, "error", err) for _, job := range jobs { if job.RunnerName == "" { // This should not happen, but just in case. @@ -73,7 +78,7 @@ func (w *Worker) HandleJobsStarted(jobs []params.ScaleSetJobMessage) error { RunnerStatus: params.RunnerActive, } - locking.Lock(job.RunnerName) + locking.Lock(job.RunnerName, w.consumerID) _, err := w.store.UpdateInstance(w.ctx, job.RunnerName, updateParams) if err != nil { if errors.Is(err, runnerErrors.ErrNotFound) { diff --git a/workers/scaleset/scaleset_listener.go b/workers/scaleset/scaleset_listener.go index 43a2e5c1..9fbf9a7e 100644 --- a/workers/scaleset/scaleset_listener.go +++ b/workers/scaleset/scaleset_listener.go @@ -232,6 +232,7 @@ func (l *scaleSetListener) loop() { if !msg.IsNil() { // Longpoll returns after 50 seconds. If no message arrives during that interval // we get a nil message. We can simply ignore it and continue. + slog.DebugContext(l.ctx, "handling message", "message_id", msg.MessageID) l.handleSessionMessage(msg) } }