diff --git a/cache/credentials_cache.go b/cache/credentials_cache.go new file mode 100644 index 00000000..731d1640 --- /dev/null +++ b/cache/credentials_cache.go @@ -0,0 +1,73 @@ +package cache + +import ( + "sync" + + "github.com/cloudbase/garm/params" +) + +var credentialsCache *GithubCredentials + +func init() { + ghCredentialsCache := &GithubCredentials{ + cache: make(map[uint]params.GithubCredentials), + } + credentialsCache = ghCredentialsCache +} + +type GithubCredentials struct { + mux sync.Mutex + + cache map[uint]params.GithubCredentials +} + +func (g *GithubCredentials) SetCredentials(credentials params.GithubCredentials) { + g.mux.Lock() + defer g.mux.Unlock() + + g.cache[credentials.ID] = credentials +} + +func (g *GithubCredentials) GetCredentials(id uint) (params.GithubCredentials, bool) { + g.mux.Lock() + defer g.mux.Unlock() + + if creds, ok := g.cache[id]; ok { + return creds, true + } + return params.GithubCredentials{}, false +} + +func (g *GithubCredentials) DeleteCredentials(id uint) { + g.mux.Lock() + defer g.mux.Unlock() + + delete(g.cache, id) +} + +func (g *GithubCredentials) GetAllCredentials() []params.GithubCredentials { + g.mux.Lock() + defer g.mux.Unlock() + + creds := make([]params.GithubCredentials, 0, len(g.cache)) + for _, cred := range g.cache { + creds = append(creds, cred) + } + return creds +} + +func SetGithubCredentials(credentials params.GithubCredentials) { + credentialsCache.SetCredentials(credentials) +} + +func GetGithubCredentials(id uint) (params.GithubCredentials, bool) { + return credentialsCache.GetCredentials(id) +} + +func DeleteGithubCredentials(id uint) { + credentialsCache.DeleteCredentials(id) +} + +func GetAllGithubCredentials() []params.GithubCredentials { + return credentialsCache.GetAllCredentials() +} diff --git a/cache/entity_cache.go b/cache/entity_cache.go new file mode 100644 index 00000000..920b9a9b --- /dev/null +++ b/cache/entity_cache.go @@ -0,0 +1,189 @@ +package cache + +import ( + "sync" + + "github.com/cloudbase/garm/params" +) + +var entityCache *EntityCache + +func init() { + ghEntityCache := &EntityCache{ + entities: make(map[string]EntityItem), + } + entityCache = ghEntityCache +} + +type EntityItem struct { + Entity params.GithubEntity + Pools map[string]params.Pool + ScaleSets map[uint]params.ScaleSet +} + +type EntityCache struct { + mux sync.Mutex + // entity IDs are UUID4s. It is highly unlikely they will collide (🤞). + entities map[string]EntityItem +} + +func (e *EntityCache) GetEntity(entity params.GithubEntity) (EntityItem, bool) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entity.ID]; ok { + // Updating specific credential details will not update entity cache which + // uses those credentials. + // Entity credentials in the cache are only updated if you swap the creds + // on the entity. We get the updated credentials from the credentials cache. + creds, ok := GetGithubCredentials(cache.Entity.Credentials.ID) + if ok { + cache.Entity.Credentials = creds + } + return cache, true + } + return EntityItem{}, false +} + +func (e *EntityCache) SetEntity(entity params.GithubEntity) { + e.mux.Lock() + defer e.mux.Unlock() + + e.entities[entity.ID] = EntityItem{ + Entity: entity, + } +} + +func (e *EntityCache) ReplaceEntityPools(entityID string, pools map[string]params.Pool) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + cache.Pools = pools + e.entities[entityID] = cache + } +} + +func (e *EntityCache) ReplaceEntityScaleSets(entityID string, scaleSets map[uint]params.ScaleSet) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + cache.ScaleSets = scaleSets + e.entities[entityID] = cache + } +} + +func (e *EntityCache) DeleteEntity(entityID string) { + e.mux.Lock() + defer e.mux.Unlock() + delete(e.entities, entityID) +} + +func (e *EntityCache) SetEntityPool(entityID string, pool params.Pool) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + cache.Pools[pool.ID] = pool + e.entities[entityID] = cache + } +} + +func (e *EntityCache) SetEntityScaleSet(entityID string, scaleSet params.ScaleSet) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + cache.ScaleSets[scaleSet.ID] = scaleSet + e.entities[entityID] = cache + } +} + +func (e *EntityCache) DeleteEntityPool(entityID string, poolID string) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + delete(cache.Pools, poolID) + e.entities[entityID] = cache + } +} + +func (e *EntityCache) DeleteEntityScaleSet(entityID string, scaleSetID uint) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + delete(cache.ScaleSets, scaleSetID) + e.entities[entityID] = cache + } +} + +func (e *EntityCache) GetEntityPool(entityID string, poolID string) (params.Pool, bool) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + if pool, ok := cache.Pools[poolID]; ok { + return pool, true + } + } + return params.Pool{}, false +} + +func (e *EntityCache) GetEntityScaleSet(entityID string, scaleSetID uint) (params.ScaleSet, bool) { + e.mux.Lock() + defer e.mux.Unlock() + + if cache, ok := e.entities[entityID]; ok { + if scaleSet, ok := cache.ScaleSets[scaleSetID]; ok { + return scaleSet, true + } + } + return params.ScaleSet{}, false +} + +func GetEntity(entity params.GithubEntity) (EntityItem, bool) { + return entityCache.GetEntity(entity) +} + +func SetEntity(entity params.GithubEntity) { + entityCache.SetEntity(entity) +} + +func ReplaceEntityPools(entityID string, pools map[string]params.Pool) { + entityCache.ReplaceEntityPools(entityID, pools) +} + +func ReplaceEntityScaleSets(entityID string, scaleSets map[uint]params.ScaleSet) { + entityCache.ReplaceEntityScaleSets(entityID, scaleSets) +} + +func DeleteEntity(entityID string) { + entityCache.DeleteEntity(entityID) +} + +func SetEntityPool(entityID string, pool params.Pool) { + entityCache.SetEntityPool(entityID, pool) +} + +func SetEntityScaleSet(entityID string, scaleSet params.ScaleSet) { + entityCache.SetEntityScaleSet(entityID, scaleSet) +} + +func DeleteEntityPool(entityID string, poolID string) { + entityCache.DeleteEntityPool(entityID, poolID) +} + +func DeleteEntityScaleSet(entityID string, scaleSetID uint) { + entityCache.DeleteEntityScaleSet(entityID, scaleSetID) +} + +func GetEntityPool(entityID string, poolID string) (params.Pool, bool) { + return entityCache.GetEntityPool(entityID, poolID) +} + +func GetEntityScaleSet(entityID string, scaleSetID uint) (params.ScaleSet, bool) { + return entityCache.GetEntityScaleSet(entityID, scaleSetID) +} diff --git a/cache/cache.go b/cache/tools_cache.go similarity index 100% rename from cache/cache.go rename to cache/tools_cache.go diff --git a/cmd/garm/main.go b/cmd/garm/main.go index c43e3c93..15ba7069 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -51,6 +51,7 @@ import ( garmUtil "github.com/cloudbase/garm/util" "github.com/cloudbase/garm/util/appdefaults" "github.com/cloudbase/garm/websocket" + "github.com/cloudbase/garm/workers/credentials" "github.com/cloudbase/garm/workers/entity" "github.com/cloudbase/garm/workers/provider" ) @@ -180,7 +181,127 @@ func maybeUpdateURLsFromConfig(cfg config.Config, store common.Store) error { return nil } -//gocyclo:ignore +func configureRouter(ctx context.Context, cfg config.Config, db common.Store, hub *websocket.Hub, runner *runner.Runner) (http.Handler, error) { + authenticator := auth.NewAuthenticator(cfg.JWTAuth, db) + controller, err := controllers.NewAPIController(runner, authenticator, hub) + if err != nil { + return nil, fmt.Errorf("creating controller: %w", err) + } + + instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth) + if err != nil { + return nil, fmt.Errorf("creating instance middleware: %w", err) + } + + jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth) + if err != nil { + return nil, fmt.Errorf("creating jwt middleware: %w", err) + } + + initMiddleware, err := auth.NewInitRequiredMiddleware(db) + if err != nil { + return nil, fmt.Errorf("creating init required middleware: %w", err) + } + + urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db) + if err != nil { + return nil, fmt.Errorf("creating urls required middleware: %w", err) + } + + metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth) + if err != nil { + return nil, fmt.Errorf("creating metrics middleware: %w", err) + } + + router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement) + + // start the metrics collector + if cfg.Metrics.Enable { + slog.InfoContext(ctx, "setting up metric routes") + router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware) + + slog.InfoContext(ctx, "register metrics") + if err := metrics.RegisterMetrics(); err != nil { + return nil, fmt.Errorf("registering metrics: %w", err) + } + + slog.InfoContext(ctx, "start metrics collection") + runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration()) + } + + if cfg.Default.DebugServer { + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) + slog.InfoContext(ctx, "setting up debug routes") + router = routers.WithDebugServer(router) + } + + corsMw := mux.CORSMethodMiddleware(router) + router.Use(corsMw) + + allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins) + methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"}) + headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}) + + handler := handlers.CORS(methodsOk, headersOk, allowedOrigins)(router) + return handler, nil +} + +func startWorkers(ctx context.Context, cfg config.Config, db common.Store, controllerID string) (func() error, error) { + credsWorker, err := credentials.NewWorker(ctx, db) + if err != nil { + return nil, fmt.Errorf("failed to create credentials worker: %+v", err) + } + + if err := credsWorker.Start(); err != nil { + return nil, fmt.Errorf("failed to start credentials worker: %+v", err) + } + + providers, err := providers.LoadProvidersFromConfig(ctx, cfg, controllerID) + if err != nil { + return nil, fmt.Errorf("loading providers: %+v", err) + } + + entityController, err := entity.NewController(ctx, db, providers) + if err != nil { + return nil, fmt.Errorf("failed to create entity controller: %+v", err) + } + if err := entityController.Start(); err != nil { + return nil, fmt.Errorf("failed to start entity controller: %+v", err) + } + + instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create instance token getter: %+v", err) + } + + providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter) + if err != nil { + return nil, fmt.Errorf("failed to create provider worker: %+v", err) + } + if err := providerWorker.Start(); err != nil { + return nil, fmt.Errorf("failed to start provider worker: %+v", err) + } + + return func() error { + slog.InfoContext(ctx, "shutting down credentials worker") + if err := credsWorker.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop credentials worker") + } + + slog.InfoContext(ctx, "shutting down entity controller") + if err := entityController.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller") + } + + slog.InfoContext(ctx, "shutting down provider worker") + if err := providerWorker.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker") + } + return nil + }, nil +} + func main() { flag.Parse() if *version { @@ -192,7 +313,6 @@ func main() { watcher.InitWatcher(ctx) ctx = auth.GetAdminContext(ctx) - cfg, err := config.NewConfig(*conf) if err != nil { log.Fatalf("Fetching config: %+v", err) //nolint:gocritic @@ -237,30 +357,9 @@ func main() { log.Fatal(err) } - providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String()) + stopWorkersFn, err := startWorkers(ctx, *cfg, db, controllerInfo.ControllerID.String()) if err != nil { - log.Fatalf("loading providers: %+v", err) - } - - entityController, err := entity.NewController(ctx, db, providers) - if err != nil { - log.Fatalf("failed to create entity controller: %+v", err) - } - if err := entityController.Start(); err != nil { - log.Fatalf("failed to start entity controller: %+v", err) - } - - instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret) - if err != nil { - log.Fatalf("failed to create instance token getter: %+v", err) - } - - providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter) - if err != nil { - log.Fatalf("failed to create provider worker: %+v", err) - } - if err := providerWorker.Start(); err != nil { - log.Fatalf("failed to start provider worker: %+v", err) + log.Fatalf("failed to start workers: %+v", err) } runner, err := runner.NewRunner(ctx, *cfg, db) @@ -273,73 +372,17 @@ func main() { log.Fatal(err) } - authenticator := auth.NewAuthenticator(cfg.JWTAuth, db) - controller, err := controllers.NewAPIController(runner, authenticator, hub) + handler, err := configureRouter(ctx, *cfg, db, hub, runner) if err != nil { - log.Fatalf("failed to create controller: %+v", err) + log.Fatalf("failed to configure router: %+v", err) } - instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth) - if err != nil { - log.Fatal(err) - } - - jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth) - if err != nil { - log.Fatal(err) - } - - initMiddleware, err := auth.NewInitRequiredMiddleware(db) - if err != nil { - log.Fatal(err) - } - - urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db) - if err != nil { - log.Fatal(err) - } - - metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth) - if err != nil { - log.Fatal(err) - } - - router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement) - - // start the metrics collector - if cfg.Metrics.Enable { - slog.InfoContext(ctx, "setting up metric routes") - router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware) - - slog.InfoContext(ctx, "register metrics") - if err := metrics.RegisterMetrics(); err != nil { - log.Fatal(err) - } - - slog.InfoContext(ctx, "start metrics collection") - runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration()) - } - - if cfg.Default.DebugServer { - runtime.SetBlockProfileRate(1) - runtime.SetMutexProfileFraction(1) - slog.InfoContext(ctx, "setting up debug routes") - router = routers.WithDebugServer(router) - } - - corsMw := mux.CORSMethodMiddleware(router) - router.Use(corsMw) - - allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins) - methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"}) - headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}) - // nolint:golangci-lint,gosec // G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server srv := &http.Server{ Addr: cfg.APIServer.BindAddress(), // Pass our instance of gorilla/mux in. - Handler: handlers.CORS(methodsOk, headersOk, allowedOrigins)(router), + Handler: handler, } listener, err := net.Listen("tcp", srv.Addr) @@ -361,22 +404,16 @@ func main() { <-ctx.Done() - slog.InfoContext(ctx, "shutting down entity controller") - if err := entityController.Stop(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller") - } - - slog.InfoContext(ctx, "shutting down provider worker") - if err := providerWorker.Stop(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker") - } - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second) defer shutdownCancel() if err := srv.Shutdown(shutdownCtx); err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed") } + if err := stopWorkersFn(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop workers") + } + slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop") if err := runner.Wait(); err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers") diff --git a/workers/credentials/credentials.go b/workers/credentials/credentials.go new file mode 100644 index 00000000..7c590401 --- /dev/null +++ b/workers/credentials/credentials.go @@ -0,0 +1,133 @@ +package credentials + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/cloudbase/garm/cache" + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + "github.com/cloudbase/garm/params" + garmUtil "github.com/cloudbase/garm/util" +) + +func NewWorker(ctx context.Context, store dbCommon.Store) (*Worker, error) { + consumerID := "credentials-worker" + + ctx = garmUtil.WithSlogContext( + ctx, + slog.Any("worker", consumerID)) + + return &Worker{ + ctx: ctx, + consumerID: consumerID, + store: store, + running: false, + quit: make(chan struct{}), + credentials: make(map[uint]params.GithubCredentials), + }, nil +} + +// Worker is responsible for maintaining the credentials cache. +type Worker struct { + consumerID string + ctx context.Context + + consumer dbCommon.Consumer + store dbCommon.Store + + credentials map[uint]params.GithubCredentials + + running bool + quit chan struct{} + + mux sync.Mutex +} + +func (w *Worker) loadAllCredentials() error { + creds, err := w.store.ListGithubCredentials(w.ctx) + if err != nil { + return err + } + + for _, cred := range creds { + w.credentials[cred.ID] = cred + cache.SetGithubCredentials(cred) + } + + return nil +} + +func (w *Worker) Start() error { + w.mux.Lock() + defer w.mux.Unlock() + + if w.running { + return nil + } + slog.DebugContext(w.ctx, "starting credentials worker") + if err := w.loadAllCredentials(); err != nil { + return fmt.Errorf("loading credentials: %w", err) + } + + consumer, err := watcher.RegisterConsumer( + w.ctx, w.consumerID, + watcher.WithEntityTypeFilter(dbCommon.GithubCredentialsEntityType), + ) + if err != nil { + return fmt.Errorf("failed to create consumer for entity controller: %w", err) + } + w.consumer = consumer + + w.running = true + go w.loop() + return nil +} + +func (w *Worker) Stop() error { + w.mux.Lock() + defer w.mux.Unlock() + + if !w.running { + return nil + } + + close(w.quit) + w.running = false + + return nil +} + +func (w *Worker) loop() { + defer w.Stop() + + for { + select { + case <-w.quit: + return + case event, ok := <-w.consumer.Watch(): + if !ok { + slog.ErrorContext(w.ctx, "consumer channel closed") + return + } + creds, ok := event.Payload.(params.GithubCredentials) + if !ok { + slog.ErrorContext(w.ctx, "invalid payload for entity type", "entity_type", event.EntityType, "payload", event.Payload) + continue + } + w.mux.Lock() + switch event.Operation { + case dbCommon.DeleteOperation: + slog.DebugContext(w.ctx, "got delete operation") + delete(w.credentials, creds.ID) + cache.DeleteGithubCredentials(creds.ID) + default: + w.credentials[creds.ID] = creds + cache.SetGithubCredentials(creds) + } + w.mux.Unlock() + } + } +} diff --git a/workers/entity/controller.go b/workers/entity/controller.go index 41708ec2..066bdfe3 100644 --- a/workers/entity/controller.go +++ b/workers/entity/controller.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/cache" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/runner/common" @@ -66,6 +67,9 @@ func (c *Controller) loadAllRepositories() error { return fmt.Errorf("starting worker: %w", err) } c.Entities[entity.ID] = worker + // take advantage of the fact that we're loading all entities + // and set the cache. + cache.SetEntity(entity) } return nil } @@ -90,6 +94,9 @@ func (c *Controller) loadAllOrganizations() error { return fmt.Errorf("starting worker: %w", err) } c.Entities[entity.ID] = worker + // take advantage of the fact that we're loading all entities + // and set the cache. + cache.SetEntity(entity) } return nil } @@ -114,6 +121,9 @@ func (c *Controller) loadAllEnterprises() error { return fmt.Errorf("starting worker: %w", err) } c.Entities[entity.ID] = worker + // take advantage of the fact that we're loading all entities + // and set the cache. + cache.SetEntity(entity) } return nil } @@ -126,14 +136,14 @@ func (c *Controller) Start() error { } c.mux.Unlock() - if err := c.loadAllRepositories(); err != nil { - return fmt.Errorf("loading repositories: %w", err) + if err := c.loadAllEnterprises(); err != nil { + return fmt.Errorf("loading enterprises: %w", err) } if err := c.loadAllOrganizations(); err != nil { return fmt.Errorf("loading organizations: %w", err) } - if err := c.loadAllEnterprises(); err != nil { - return fmt.Errorf("loading enterprises: %w", err) + if err := c.loadAllRepositories(); err != nil { + return fmt.Errorf("loading repositories: %w", err) } consumer, err := watcher.RegisterConsumer( diff --git a/workers/entity/controller_watcher.go b/workers/entity/controller_watcher.go index ace63702..dcd6ee9a 100644 --- a/workers/entity/controller_watcher.go +++ b/workers/entity/controller_watcher.go @@ -3,6 +3,7 @@ package entity import ( "log/slog" + "github.com/cloudbase/garm/cache" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) @@ -95,4 +96,5 @@ func (c *Controller) handleWatcherDeleteOperation(entityGetter params.EntityGett return } delete(c.Entities, entity.ID) + cache.DeleteEntity(entity.ID) } diff --git a/workers/scaleset/controller_watcher.go b/workers/scaleset/controller_watcher.go index 04cfe1cd..131cb56c 100644 --- a/workers/scaleset/controller_watcher.go +++ b/workers/scaleset/controller_watcher.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" + "github.com/cloudbase/garm/cache" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" @@ -63,6 +64,7 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet, ghCli c if _, ok := c.ScaleSets[sSet.ID]; ok { slog.DebugContext(c.ctx, "scale set already exists in worker list", "scale_set_id", sSet.ID) + cache.SetEntityScaleSet(c.Entity.ID, sSet) return nil } @@ -88,9 +90,9 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet, ghCli c } c.ScaleSets[sSet.ID] = &scaleSet{ scaleSet: sSet, - // status: scaleSetStatus{}, - worker: worker, + worker: worker, } + cache.SetEntityScaleSet(c.Entity.ID, sSet) return nil } @@ -109,6 +111,7 @@ func (c *Controller) handleScaleSetDeleteOperation(sSet params.ScaleSet) error { return fmt.Errorf("stopping scale set worker: %w", err) } delete(c.ScaleSets, sSet.ID) + cache.DeleteEntityScaleSet(c.Entity.ID, sSet.ID) return nil } @@ -116,12 +119,16 @@ func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error { c.mux.Lock() defer c.mux.Unlock() - if _, ok := c.ScaleSets[sSet.ID]; !ok { + set, ok := c.ScaleSets[sSet.ID] + if !ok { // Some error may have occurred when the scale set was first created, so we // attempt to create it after the user updated the scale set, hopefully // fixing the reason for the failure. return c.handleScaleSetCreateOperation(sSet, c.ghCli) } + set.scaleSet = sSet + c.ScaleSets[sSet.ID] = set + cache.SetEntityScaleSet(c.Entity.ID, sSet) // We let the watcher in the scale set worker handle the update operation. return nil } @@ -139,6 +146,7 @@ func (c *Controller) handleCredentialsEvent(event dbCommon.ChangePayload) { c.mux.Lock() defer c.mux.Unlock() + cache.SetGithubCredentials(credentials) if c.Entity.Credentials.ID != credentials.ID { // stale update event. return @@ -177,6 +185,7 @@ func (c *Controller) handleEntityEvent(event dbCommon.ChangePayload) { } } c.Entity = entity + cache.SetEntity(c.Entity) default: slog.ErrorContext(c.ctx, "invalid operation type", "operation_type", event.Operation) return