diff --git a/cache/cache_test.go b/cache/cache_test.go index 43b15953..7a977394 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -55,7 +55,7 @@ func (c *CacheTestSuite) TestSetCacheWorks() { c.Require().Len(githubToolsCache.entities, 0) SetGithubToolsCache(c.entity, tools) c.Require().Len(githubToolsCache.entities, 1) - cachedTools, ok := GetGithubToolsCache(c.entity) + cachedTools, ok := GetGithubToolsCache(c.entity.ID) c.Require().True(ok) c.Require().Len(cachedTools, 1) c.Require().Equal(tools[0].GetDownloadURL(), cachedTools[0].GetDownloadURL()) @@ -72,11 +72,11 @@ func (c *CacheTestSuite) TestTimedOutToolsCache() { c.Require().Len(githubToolsCache.entities, 0) SetGithubToolsCache(c.entity, tools) c.Require().Len(githubToolsCache.entities, 1) - entity := githubToolsCache.entities[c.entity.String()] + entity := githubToolsCache.entities[c.entity.ID] entity.updatedAt = entity.updatedAt.Add(-2 * time.Hour) - githubToolsCache.entities[c.entity.String()] = entity + githubToolsCache.entities[c.entity.ID] = entity - cachedTools, ok := GetGithubToolsCache(c.entity) + cachedTools, ok := GetGithubToolsCache(c.entity.ID) c.Require().False(ok) c.Require().Nil(cachedTools) } @@ -84,7 +84,7 @@ func (c *CacheTestSuite) TestTimedOutToolsCache() { func (c *CacheTestSuite) TestGetInexistentCache() { c.Require().NotNil(githubToolsCache) c.Require().Len(githubToolsCache.entities, 0) - cachedTools, ok := GetGithubToolsCache(c.entity) + cachedTools, ok := GetGithubToolsCache(c.entity.ID) c.Require().False(ok) c.Require().Nil(cachedTools) } diff --git a/cache/credentials_cache.go b/cache/credentials_cache.go index 731d1640..092d2e90 100644 --- a/cache/credentials_cache.go +++ b/cache/credentials_cache.go @@ -26,6 +26,7 @@ func (g *GithubCredentials) SetCredentials(credentials params.GithubCredentials) defer g.mux.Unlock() g.cache[credentials.ID] = credentials + UpdateCredentialsInAffectedEntities(credentials) } func (g *GithubCredentials) GetCredentials(id uint) (params.GithubCredentials, bool) { diff --git a/cache/entity_cache.go b/cache/entity_cache.go index 3e3a1337..0c549498 100644 --- a/cache/entity_cache.go +++ b/cache/entity_cache.go @@ -1,7 +1,6 @@ package cache import ( - "log/slog" "sync" "github.com/cloudbase/garm/params" @@ -28,15 +27,24 @@ type EntityCache struct { entities map[string]EntityItem } +func (e *EntityCache) UpdateCredentialsInAffectedEntities(creds params.GithubCredentials) { + e.mux.Lock() + defer e.mux.Unlock() + + for entityID, cache := range e.entities { + if cache.Entity.Credentials.ID == creds.ID { + cache.Entity.Credentials = creds + e.entities[entityID] = cache + } + } +} + func (e *EntityCache) GetEntity(entityID string) (params.GithubEntity, bool) { e.mux.Lock() defer e.mux.Unlock() if cache, ok := e.entities[entityID]; ok { - // Updating specific credential details will not update entity cache which - // uses those credentials. - // Entity credentials in the cache are only updated if you swap the creds - // on the entity. We get the updated credentials from the credentials cache. + // Get the credentials from the credentials cache. creds, ok := GetGithubCredentials(cache.Entity.Credentials.ID) if ok { cache.Entity.Credentials = creds @@ -173,7 +181,6 @@ func (e *EntityCache) FindPoolsMatchingAllTags(entityID string, tags []string) [ if cache, ok := e.entities[entityID]; ok { var pools []params.Pool - slog.Debug("Finding pools matching all tags", "entityID", entityID, "tags", tags, "pools", cache.Pools) for _, pool := range cache.Pools { if pool.HasRequiredLabels(tags) { pools = append(pools, pool) @@ -212,6 +219,35 @@ func (e *EntityCache) GetEntityScaleSets(entityID string) []params.ScaleSet { return nil } +func (e *EntityCache) GetEntitiesUsingGredentials(credsID uint) []params.GithubEntity { + e.mux.Lock() + defer e.mux.Unlock() + + var entities []params.GithubEntity + for _, cache := range e.entities { + if cache.Entity.Credentials.ID == credsID { + entities = append(entities, cache.Entity) + } + } + return entities +} + +func (e *EntityCache) GetAllEntities() []params.GithubEntity { + e.mux.Lock() + defer e.mux.Unlock() + + var entities []params.GithubEntity + for _, cache := range e.entities { + // Get the credentials from the credentials cache. + creds, ok := GetGithubCredentials(cache.Entity.Credentials.ID) + if ok { + cache.Entity.Credentials = creds + } + entities = append(entities, cache.Entity) + } + return entities +} + func GetEntity(entityID string) (params.GithubEntity, bool) { return entityCache.GetEntity(entityID) } @@ -267,3 +303,15 @@ func GetEntityPools(entityID string) []params.Pool { func GetEntityScaleSets(entityID string) []params.ScaleSet { return entityCache.GetEntityScaleSets(entityID) } + +func UpdateCredentialsInAffectedEntities(creds params.GithubCredentials) { + entityCache.UpdateCredentialsInAffectedEntities(creds) +} + +func GetEntitiesUsingGredentials(credsID uint) []params.GithubEntity { + return entityCache.GetEntitiesUsingGredentials(credsID) +} + +func GetAllEntities() []params.GithubEntity { + return entityCache.GetAllEntities() +} diff --git a/cache/tools_cache.go b/cache/tools_cache.go index 1960de38..233de2c1 100644 --- a/cache/tools_cache.go +++ b/cache/tools_cache.go @@ -29,14 +29,14 @@ type GithubToolsCache struct { entities map[string]GithubEntityTools } -func (g *GithubToolsCache) Get(entity params.GithubEntity) ([]commonParams.RunnerApplicationDownload, bool) { +func (g *GithubToolsCache) Get(entityID string) ([]commonParams.RunnerApplicationDownload, bool) { g.mux.Lock() defer g.mux.Unlock() - if cache, ok := g.entities[entity.String()]; ok { + if cache, ok := g.entities[entityID]; ok { if time.Since(cache.updatedAt) > 1*time.Hour { // Stale cache, remove it. - delete(g.entities, entity.String()) + delete(g.entities, entityID) return nil, false } return cache.tools, true @@ -48,7 +48,7 @@ func (g *GithubToolsCache) Set(entity params.GithubEntity, tools []commonParams. g.mux.Lock() defer g.mux.Unlock() - g.entities[entity.String()] = GithubEntityTools{ + g.entities[entity.ID] = GithubEntityTools{ updatedAt: time.Now(), entity: entity, tools: tools, @@ -59,6 +59,6 @@ func SetGithubToolsCache(entity params.GithubEntity, tools []commonParams.Runner githubToolsCache.Set(entity, tools) } -func GetGithubToolsCache(entity params.GithubEntity) ([]commonParams.RunnerApplicationDownload, bool) { - return githubToolsCache.Get(entity) +func GetGithubToolsCache(entityID string) ([]commonParams.RunnerApplicationDownload, bool) { + return githubToolsCache.Get(entityID) } diff --git a/cmd/garm-cli/cmd/scalesets.go b/cmd/garm-cli/cmd/scalesets.go index 79486a0e..920b60cf 100644 --- a/cmd/garm-cli/cmd/scalesets.go +++ b/cmd/garm-cli/cmd/scalesets.go @@ -436,7 +436,7 @@ func formatScaleSets(scaleSets []params.ScaleSet) { return } t := table.NewWriter() - header := table.Row{"ID", "Scale Set Name", "Image", "Flavor", "Belongs to", "Level", "Enabled", "Runner Prefix", "Provider"} + header := table.Row{"ID", "Scale Set Name", "Image", "Flavor", "Belongs to", "Level", "Runner Group", "Enabled", "Runner Prefix", "Provider"} t.AppendHeader(header) for _, scaleSet := range scaleSets { @@ -454,7 +454,7 @@ func formatScaleSets(scaleSets []params.ScaleSet) { belongsTo = scaleSet.EnterpriseName level = entityTypeEnterprise } - t.AppendRow(table.Row{scaleSet.ID, scaleSet.Name, scaleSet.Image, scaleSet.Flavor, belongsTo, level, scaleSet.Enabled, scaleSet.GetRunnerPrefix(), scaleSet.ProviderName}) + t.AppendRow(table.Row{scaleSet.ID, scaleSet.Name, scaleSet.Image, scaleSet.Flavor, belongsTo, level, scaleSet.GitHubRunnerGroup, scaleSet.Enabled, scaleSet.GetRunnerPrefix(), scaleSet.ProviderName}) t.AppendSeparator() } fmt.Println(t.Render()) diff --git a/go.mod b/go.mod index db57a68b..a0b3901f 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( gorm.io/datatypes v1.2.5 gorm.io/driver/mysql v1.5.7 gorm.io/driver/sqlite v1.5.7 - gorm.io/gorm v1.26.0 + gorm.io/gorm v1.26.1 ) require ( diff --git a/go.sum b/go.sum index 5ca7575d..3c9af9bb 100644 --- a/go.sum +++ b/go.sum @@ -229,5 +229,5 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g= gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.26.0 h1:9lqQVPG5aNNS6AyHdRiwScAVnXHg/L/Srzx55G5fOgs= -gorm.io/gorm v1.26.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw= +gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/vendor/gorm.io/gorm/gorm.go b/vendor/gorm.io/gorm/gorm.go index d253736d..63a28b37 100644 --- a/vendor/gorm.io/gorm/gorm.go +++ b/vendor/gorm.io/gorm/gorm.go @@ -110,8 +110,6 @@ type DB struct { type Session struct { DryRun bool PrepareStmt bool - PrepareStmtMaxSize int - PrepareStmtTTL time.Duration NewDB bool Initialized bool SkipHooks bool @@ -275,7 +273,7 @@ func (db *DB) Session(config *Session) *DB { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt = v.(*PreparedStmtDB) } else { - preparedStmt = NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL) + preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL) db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } diff --git a/vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go b/vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go index 7068419d..a82b2cf5 100644 --- a/vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go +++ b/vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go @@ -3,6 +3,7 @@ package stmt_store import ( "context" "database/sql" + "math" "sync" "time" @@ -73,7 +74,7 @@ type Store interface { // the cache can theoretically store as many elements as possible. // (1 << 63) - 1 is the maximum value that an int64 type can represent. const ( - defaultMaxSize = (1 << 63) - 1 + defaultMaxSize = math.MaxInt // defaultTTL defines the default time-to-live (TTL) for each cache entry. // When the TTL for cache entries is not specified, each cache entry will expire after 24 hours. defaultTTL = time.Hour * 24 diff --git a/vendor/modules.txt b/vendor/modules.txt index 9ca8e528..5cb70bb1 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -362,7 +362,7 @@ gorm.io/driver/mysql # gorm.io/driver/sqlite v1.5.7 ## explicit; go 1.20 gorm.io/driver/sqlite -# gorm.io/gorm v1.26.0 +# gorm.io/gorm v1.26.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/workers/cache/cache.go b/workers/cache/cache.go index 3973e7c7..d19bbbaf 100644 --- a/workers/cache/cache.go +++ b/workers/cache/cache.go @@ -20,10 +20,11 @@ func NewWorker(ctx context.Context, store common.Store) *Worker { slog.Any("worker", consumerID)) return &Worker{ - ctx: ctx, - store: store, - consumerID: consumerID, - quit: make(chan struct{}), + ctx: ctx, + store: store, + consumerID: consumerID, + toolsWorkes: make(map[string]*toolsUpdater), + quit: make(chan struct{}), } } @@ -31,8 +32,9 @@ type Worker struct { ctx context.Context consumerID string - consumer common.Consumer - store common.Store + consumer common.Consumer + store common.Store + toolsWorkes map[string]*toolsUpdater mux sync.Mutex running bool @@ -110,6 +112,13 @@ func (w *Worker) loadAllEntities() error { } } + for _, entity := range cache.GetAllEntities() { + worker := newToolsUpdater(w.ctx, entity) + if err := worker.Start(); err != nil { + return fmt.Errorf("starting tools updater: %w", err) + } + w.toolsWorkes[entity.ID] = worker + } return nil } @@ -181,6 +190,11 @@ func (w *Worker) Stop() error { return nil } + for _, worker := range w.toolsWorkes { + if err := worker.Stop(); err != nil { + slog.ErrorContext(w.ctx, "stopping tools updater", "error", err) + } + } w.consumer.Close() w.running = false close(w.quit) @@ -195,9 +209,31 @@ func (w *Worker) handleEntityEvent(entityGetter params.EntityGetter, op common.O } switch op { case common.CreateOperation, common.UpdateOperation: + old, hasOld := cache.GetEntity(entity.ID) cache.SetEntity(entity) + worker, ok := w.toolsWorkes[entity.ID] + if !ok { + worker = newToolsUpdater(w.ctx, entity) + if err := worker.Start(); err != nil { + slog.ErrorContext(w.ctx, "starting tools updater", "error", err) + return + } + w.toolsWorkes[entity.ID] = worker + } else if hasOld { + // probably an update operation + if old.Credentials.ID != entity.Credentials.ID { + worker.Reset() + } + } case common.DeleteOperation: cache.DeleteEntity(entity.ID) + worker, ok := w.toolsWorkes[entity.ID] + if ok { + if err := worker.Stop(); err != nil { + slog.ErrorContext(w.ctx, "stopping tools updater", "error", err) + } + delete(w.toolsWorkes, entity.ID) + } } } @@ -291,13 +327,20 @@ func (w *Worker) handleCredentialsEvent(event common.ChangePayload) { switch event.Operation { case common.CreateOperation, common.UpdateOperation: cache.SetGithubCredentials(credentials) + entities := cache.GetEntitiesUsingGredentials(credentials.ID) + for _, entity := range entities { + worker, ok := w.toolsWorkes[entity.ID] + if ok { + worker.Reset() + } + } case common.DeleteOperation: cache.DeleteGithubCredentials(credentials.ID) } } func (w *Worker) handleEvent(event common.ChangePayload) { - slog.DebugContext(w.ctx, "handling event", "event", event) + slog.DebugContext(w.ctx, "handling event", "event_entity_type", event.EntityType, "event_operation", event.Operation) switch event.EntityType { case common.PoolEntityType: w.handlePoolEvent(event) diff --git a/workers/cache/tool_cache.go b/workers/cache/tool_cache.go new file mode 100644 index 00000000..6133580d --- /dev/null +++ b/workers/cache/tool_cache.go @@ -0,0 +1,170 @@ +package cache + +import ( + "context" + "crypto/rand" + "fmt" + "log/slog" + "math/big" + "sync" + "time" + + commonParams "github.com/cloudbase/garm-provider-common/params" + "github.com/cloudbase/garm/cache" + "github.com/cloudbase/garm/params" + garmUtil "github.com/cloudbase/garm/util" + "github.com/cloudbase/garm/util/github" +) + +func newToolsUpdater(ctx context.Context, entity params.GithubEntity) *toolsUpdater { + return &toolsUpdater{ + ctx: ctx, + entity: entity, + quit: make(chan struct{}), + } +} + +type toolsUpdater struct { + ctx context.Context + + entity params.GithubEntity + tools []commonParams.RunnerApplicationDownload + lastUpdate time.Time + + mux sync.Mutex + running bool + quit chan struct{} + + reset chan struct{} +} + +func (t *toolsUpdater) Start() error { + t.mux.Lock() + defer t.mux.Unlock() + + if t.running { + return nil + } + + t.running = true + t.quit = make(chan struct{}) + + go t.loop() + return nil +} + +func (t *toolsUpdater) Stop() error { + t.mux.Lock() + defer t.mux.Unlock() + + if !t.running { + return nil + } + + t.running = false + close(t.quit) + + return nil +} + +func (t *toolsUpdater) updateTools() error { + slog.DebugContext(t.ctx, "updating tools", "entity", t.entity.String()) + entity, ok := cache.GetEntity(t.entity.ID) + if !ok { + return fmt.Errorf("getting entity from cache: %s", t.entity.ID) + } + ghCli, err := github.Client(t.ctx, entity) + if err != nil { + return fmt.Errorf("getting github client: %w", err) + } + + tools, err := garmUtil.FetchTools(t.ctx, ghCli) + if err != nil { + return fmt.Errorf("fetching tools: %w", err) + } + t.lastUpdate = time.Now().UTC() + t.tools = tools + + slog.DebugContext(t.ctx, "updating tools cache", "entity", t.entity.String()) + cache.SetGithubToolsCache(entity, tools) + return nil +} + +func (t *toolsUpdater) Reset() { + t.mux.Lock() + defer t.mux.Unlock() + + if !t.running { + return + } + + if t.reset != nil { + close(t.reset) + t.reset = nil + } +} + +func (t *toolsUpdater) loop() { + defer t.Stop() + + // add some jitter. When spinning up multiple entities, we add + // jitter to prevent stampeeding herd. + randInt, err := rand.Int(rand.Reader, big.NewInt(3000)) + if err != nil { + randInt = big.NewInt(0) + } + time.Sleep(time.Duration(randInt.Int64()) * time.Millisecond) + + var resetTime time.Time + now := time.Now().UTC() + if now.After(t.lastUpdate.Add(40 * time.Minute)) { + if err := t.updateTools(); err != nil { + slog.ErrorContext(t.ctx, "initial tools update error", "error", err) + resetTime = now.Add(5 * time.Minute) + slog.ErrorContext(t.ctx, "initial tools update error", "error", err) + } else { + // Tools are usually valid for 1 hour. + resetTime = t.lastUpdate.Add(40 * time.Minute) + } + } + + for { + if t.reset == nil { + t.reset = make(chan struct{}) + } + // add some jitter + randInt, err := rand.Int(rand.Reader, big.NewInt(300)) + if err != nil { + randInt = big.NewInt(0) + } + timer := time.NewTimer(resetTime.Sub(now) + time.Duration(randInt.Int64())*time.Second) + select { + case <-t.quit: + slog.DebugContext(t.ctx, "stopping tools updater") + timer.Stop() + return + case <-timer.C: + slog.DebugContext(t.ctx, "updating tools") + now = time.Now().UTC() + if err := t.updateTools(); err == nil { + slog.ErrorContext(t.ctx, "updating tools", "error", err) + resetTime = now.Add(5 * time.Minute) + } else { + // Tools are usually valid for 1 hour. + resetTime = t.lastUpdate.Add(40 * time.Minute) + } + case <-t.reset: + slog.DebugContext(t.ctx, "resetting tools updater") + timer.Stop() + now = time.Now().UTC() + if err := t.updateTools(); err != nil { + slog.ErrorContext(t.ctx, "updating tools", "error", err) + resetTime = now.Add(5 * time.Minute) + } else { + // Tools are usually valid for 1 hour. + resetTime = t.lastUpdate.Add(40 * time.Minute) + } + } + timer.Stop() + } +} diff --git a/workers/provider/instance_manager.go b/workers/provider/instance_manager.go index dcb10257..37680cd0 100644 --- a/workers/provider/instance_manager.go +++ b/workers/provider/instance_manager.go @@ -148,7 +148,7 @@ func (i *instanceManager) handleCreateInstanceInProvider(instance params.Instanc if err != nil { return fmt.Errorf("creating instance token: %w", err) } - tools, ok := cache.GetGithubToolsCache(entity) + tools, ok := cache.GetGithubToolsCache(entity.ID) if !ok { return fmt.Errorf("tools not found in cache for entity %s", entity.String()) } diff --git a/workers/scaleset/controller.go b/workers/scaleset/controller.go index b6d61f54..5d00471f 100644 --- a/workers/scaleset/controller.go +++ b/workers/scaleset/controller.go @@ -2,7 +2,6 @@ package scaleset import ( "context" - "errors" "fmt" "log/slog" "sync" @@ -10,8 +9,6 @@ import ( "golang.org/x/sync/errgroup" - runnerErrors "github.com/cloudbase/garm-provider-common/errors" - "github.com/cloudbase/garm/cache" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" @@ -29,7 +26,7 @@ const ( ) func NewController(ctx context.Context, store dbCommon.Store, entity params.GithubEntity, providers map[string]common.Provider) (*Controller, error) { - consumerID := fmt.Sprintf("scaleset-worker-%s", entity.String()) + consumerID := fmt.Sprintf("scaleset-controller-%s", entity.String()) ctx = garmUtil.WithSlogContext( ctx, @@ -76,8 +73,7 @@ type Controller struct { store dbCommon.Store providers map[string]common.Provider - ghCli common.GithubClient - forgeCredsAreValid bool + ghCli common.GithubClient mux sync.Mutex running bool @@ -163,29 +159,6 @@ func (c *Controller) Stop() error { return nil } -func (c *Controller) updateTools() error { - c.mux.Lock() - defer c.mux.Unlock() - - slog.DebugContext(c.ctx, "updating tools for entity", "entity", c.Entity.String()) - - tools, err := garmUtil.FetchTools(c.ctx, c.ghCli) - if err != nil { - slog.With(slog.Any("error", err)).ErrorContext( - c.ctx, "failed to update tools for entity", "entity", c.Entity.String()) - if errors.Is(err, runnerErrors.ErrUnauthorized) { - // nolint:golangci-lint,godox - // TODO: block all scale sets - c.forgeCredsAreValid = false - } - return fmt.Errorf("failed to update tools for entity %s: %w", c.Entity.String(), err) - } - slog.DebugContext(c.ctx, "tools successfully updated for entity", "entity", c.Entity.String()) - c.forgeCredsAreValid = true - cache.SetGithubToolsCache(c.Entity, tools) - return nil -} - // consolidateRunnerState will list all runners on GitHub for this entity, sort by // pool or scale set and pass those runners to the appropriate worker. The worker will // then have the responsibility to cross check the runners from github with what it @@ -259,23 +232,10 @@ func (c *Controller) waitForErrorGroupOrContextCancelled(g *errgroup.Group) erro func (c *Controller) loop() { defer c.Stop() - updateToolsTicker := time.NewTicker(common.PoolToolUpdateInterval) - defer updateToolsTicker.Stop() consilidateTicker := time.NewTicker(common.PoolReapTimeoutInterval) defer consilidateTicker.Stop() - initialToolUpdate := make(chan struct{}, 1) - defer close(initialToolUpdate) - - go func() { - slog.InfoContext(c.ctx, "running initial tool update") - if err := c.updateTools(); err != nil { - slog.With(slog.Any("error", err)).Error("failed to update tools") - } - initialToolUpdate <- struct{}{} - }() - for { select { case payload, ok := <-c.consumer.Watch(): @@ -287,25 +247,6 @@ func (c *Controller) loop() { go c.handleWatcherEvent(payload) case <-c.ctx.Done(): return - case <-initialToolUpdate: - case _, ok := <-updateToolsTicker.C: - if !ok { - slog.InfoContext(c.ctx, "update tools ticker closed") - return - } - validCreds := c.forgeCredsAreValid - if err := c.updateTools(); err != nil { - if err := c.store.AddEntityEvent(c.ctx, c.Entity, params.StatusEvent, params.EventError, fmt.Sprintf("failed to update tools: %q", err.Error()), 30); err != nil { - slog.With(slog.Any("error", err)).Error("failed to add entity event") - } - slog.With(slog.Any("error", err)).Error("failed to update tools") - continue - } - if validCreds != c.forgeCredsAreValid && c.forgeCredsAreValid { - if err := c.store.AddEntityEvent(c.ctx, c.Entity, params.StatusEvent, params.EventInfo, "tools updated successfully", 30); err != nil { - slog.With(slog.Any("error", err)).Error("failed to add entity event") - } - } case _, ok := <-consilidateTicker.C: if !ok { slog.InfoContext(c.ctx, "consolidate ticker closed") diff --git a/workers/scaleset/controller_watcher.go b/workers/scaleset/controller_watcher.go index 9d94c794..6702e0f0 100644 --- a/workers/scaleset/controller_watcher.go +++ b/workers/scaleset/controller_watcher.go @@ -31,7 +31,7 @@ func (c *Controller) handleWatcherEvent(event dbCommon.ChangePayload) { func (c *Controller) handleScaleSet(event dbCommon.ChangePayload) { scaleSet, ok := event.Payload.(params.ScaleSet) if !ok { - slog.ErrorContext(c.ctx, "invalid payload for entity type", "entity_type", event.EntityType, "payload", event.Payload) + slog.ErrorContext(c.ctx, "invalid scale set payload for entity type", "entity_type", event.EntityType, "payload", event) return } @@ -131,7 +131,7 @@ func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error { func (c *Controller) handleCredentialsEvent(event dbCommon.ChangePayload) { credentials, ok := event.Payload.(params.GithubCredentials) if !ok { - slog.ErrorContext(c.ctx, "invalid payload for entity type", "entity_type", event.EntityType, "payload", event.Payload) + slog.ErrorContext(c.ctx, "invalid credentials payload for entity type", "entity_type", event.EntityType, "payload", event) return } @@ -158,9 +158,24 @@ func (c *Controller) handleCredentialsEvent(event dbCommon.ChangePayload) { } func (c *Controller) handleEntityEvent(event dbCommon.ChangePayload) { - entity, ok := event.Payload.(params.GithubEntity) + var entityGetter params.EntityGetter + var ok bool + switch c.Entity.EntityType { + case params.GithubEntityTypeRepository: + entityGetter, ok = event.Payload.(params.Repository) + case params.GithubEntityTypeOrganization: + entityGetter, ok = event.Payload.(params.Organization) + case params.GithubEntityTypeEnterprise: + entityGetter, ok = event.Payload.(params.Enterprise) + } if !ok { - slog.ErrorContext(c.ctx, "invalid payload for entity type", "entity_type", event.EntityType, "payload", event.Payload) + slog.ErrorContext(c.ctx, "invalid entity payload for entity type", "entity_type", event.EntityType, "payload", event) + return + } + + entity, err := entityGetter.GetEntity() + if err != nil { + slog.ErrorContext(c.ctx, "invalid GitHub entity payload for entity type", "entity_type", event.EntityType, "payload", event) return } diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index 660bbe97..73d08c98 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -11,6 +11,7 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/cache" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/locking" @@ -769,6 +770,24 @@ func (w *Worker) handleScaleUp(target, current uint) { } } +func (w *Worker) waitForToolsOrCancel() (hasTools, stopped bool) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + select { + case <-ticker.C: + entity, err := w.scaleSet.GetEntity() + if err != nil { + slog.ErrorContext(w.ctx, "error getting entity", "error", err) + } + _, ok := cache.GetGithubToolsCache(entity.ID) + return ok, false + case <-w.quit: + return false, true + case <-w.ctx.Done(): + return false, true + } +} + func (w *Worker) handleScaleDown(target, current uint) { delta := current - target if delta <= 0 { @@ -880,7 +899,19 @@ func (w *Worker) handleAutoScale() { lastMsg = msg } } + for { + hasTools, stopped := w.waitForToolsOrCancel() + if stopped { + slog.DebugContext(w.ctx, "worker is stopped; exiting handleAutoScale") + return + } + + if !hasTools { + time.Sleep(1 * time.Second) + continue + } + select { case <-w.quit: return diff --git a/workers/scaleset/scaleset_listener.go b/workers/scaleset/scaleset_listener.go index 9f2087d7..7a521e46 100644 --- a/workers/scaleset/scaleset_listener.go +++ b/workers/scaleset/scaleset_listener.go @@ -109,7 +109,7 @@ func (l *scaleSetListener) handleSessionMessage(msg params.RunnerScaleSetMessage if err != nil { slog.ErrorContext(l.ctx, "getting jobs from body", "error", err) } - slog.InfoContext(l.ctx, "handling message", "message", msg, "body", body) + if msg.MessageID < l.lastMessageID { slog.DebugContext(l.ctx, "message is older than last message, ignoring") return