diff --git a/cmd/garm-cli/cmd/github_credentials.go b/cmd/garm-cli/cmd/github_credentials.go index bd3521bf..2b2128d0 100644 --- a/cmd/garm-cli/cmd/github_credentials.go +++ b/cmd/garm-cli/cmd/github_credentials.go @@ -388,6 +388,7 @@ func formatOneGithubCredential(cred params.GithubCredentials) { t.AppendRow(table.Row{"Type", cred.AuthType}) t.AppendRow(table.Row{"Endpoint", cred.Endpoint.Name}) if resetMinutes > 0 { + t.AppendRow(table.Row{"", ""}) t.AppendRow(table.Row{"Remaining API requests", cred.RateLimit.Remaining}) t.AppendRow(table.Row{"Rate limit reset", fmt.Sprintf("%d minutes", int64(resetMinutes))}) } diff --git a/workers/cache/cache.go b/workers/cache/cache.go index 315876d6..13400a3a 100644 --- a/workers/cache/cache.go +++ b/workers/cache/cache.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "golang.org/x/sync/errgroup" + "github.com/cloudbase/garm/cache" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" @@ -148,6 +150,27 @@ func (w *Worker) loadAllCredentials() error { return nil } +func (w *Worker) waitForErrorGroupOrContextCancelled(g *errgroup.Group) error { + if g == nil { + return nil + } + + done := make(chan error, 1) + go func() { + waitErr := g.Wait() + done <- waitErr + }() + + select { + case err := <-done: + return err + case <-w.ctx.Done(): + return w.ctx.Err() + case <-w.quit: + return nil + } +} + func (w *Worker) Start() error { slog.DebugContext(w.ctx, "starting cache worker") w.mux.Lock() @@ -157,16 +180,31 @@ func (w *Worker) Start() error { return nil } - if err := w.loadAllEntities(); err != nil { - return fmt.Errorf("loading all entities: %w", err) - } + g, _ := errgroup.WithContext(w.ctx) - if err := w.loadAllInstances(); err != nil { - return fmt.Errorf("loading all instances: %w", err) - } + g.Go(func() error { + if err := w.loadAllEntities(); err != nil { + return fmt.Errorf("loading all entities: %w", err) + } + return nil + }) - if err := w.loadAllCredentials(); err != nil { - return fmt.Errorf("loading all credentials: %w", err) + g.Go(func() error { + if err := w.loadAllInstances(); err != nil { + return fmt.Errorf("loading all instances: %w", err) + } + return nil + }) + + g.Go(func() error { + if err := w.loadAllCredentials(); err != nil { + return fmt.Errorf("loading all credentials: %w", err) + } + return nil + }) + + if err := w.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } consumer, err := watcher.RegisterConsumer( diff --git a/workers/entity/controller.go b/workers/entity/controller.go index 07fb38ce..db353f0e 100644 --- a/workers/entity/controller.go +++ b/workers/entity/controller.go @@ -6,6 +6,8 @@ import ( "log/slog" "sync" + "golang.org/x/sync/errgroup" + "github.com/cloudbase/garm/auth" dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" @@ -53,19 +55,26 @@ func (c *Controller) loadAllRepositories() error { return fmt.Errorf("fetching repositories: %w", err) } + g, _ := errgroup.WithContext(c.ctx) for _, repo := range repos { - entity, err := repo.GetEntity() - if err != nil { - return fmt.Errorf("getting entity: %w", err) - } - worker, err := NewWorker(c.ctx, c.store, entity, c.providers) - if err != nil { - return fmt.Errorf("creating worker: %w", err) - } - if err := worker.Start(); err != nil { - return fmt.Errorf("starting worker: %w", err) - } - c.Entities[entity.ID] = worker + g.Go(func() error { + entity, err := repo.GetEntity() + if err != nil { + return fmt.Errorf("getting entity: %w", err) + } + worker, err := NewWorker(c.ctx, c.store, entity, c.providers) + if err != nil { + return fmt.Errorf("creating worker: %w", err) + } + if err := worker.Start(); err != nil { + return fmt.Errorf("starting worker: %w", err) + } + c.Entities[entity.ID] = worker + return nil + }) + } + if err := c.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } return nil } @@ -77,19 +86,27 @@ func (c *Controller) loadAllOrganizations() error { if err != nil { return fmt.Errorf("fetching organizations: %w", err) } + + g, _ := errgroup.WithContext(c.ctx) for _, org := range orgs { - entity, err := org.GetEntity() - if err != nil { - return fmt.Errorf("getting entity: %w", err) - } - worker, err := NewWorker(c.ctx, c.store, entity, c.providers) - if err != nil { - return fmt.Errorf("creating worker: %w", err) - } - if err := worker.Start(); err != nil { - return fmt.Errorf("starting worker: %w", err) - } - c.Entities[entity.ID] = worker + g.Go(func() error { + entity, err := org.GetEntity() + if err != nil { + return fmt.Errorf("getting entity: %w", err) + } + worker, err := NewWorker(c.ctx, c.store, entity, c.providers) + if err != nil { + return fmt.Errorf("creating worker: %w", err) + } + if err := worker.Start(); err != nil { + return fmt.Errorf("starting worker: %w", err) + } + c.Entities[entity.ID] = worker + return nil + }) + } + if err := c.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } return nil } @@ -101,19 +118,28 @@ func (c *Controller) loadAllEnterprises() error { if err != nil { return fmt.Errorf("fetching enterprises: %w", err) } + + g, _ := errgroup.WithContext(c.ctx) + for _, enterprise := range enterprises { - entity, err := enterprise.GetEntity() - if err != nil { - return fmt.Errorf("getting entity: %w", err) - } - worker, err := NewWorker(c.ctx, c.store, entity, c.providers) - if err != nil { - return fmt.Errorf("creating worker: %w", err) - } - if err := worker.Start(); err != nil { - return fmt.Errorf("starting worker: %w", err) - } - c.Entities[entity.ID] = worker + g.Go(func() error { + entity, err := enterprise.GetEntity() + if err != nil { + return fmt.Errorf("getting entity: %w", err) + } + worker, err := NewWorker(c.ctx, c.store, entity, c.providers) + if err != nil { + return fmt.Errorf("creating worker: %w", err) + } + if err := worker.Start(); err != nil { + return fmt.Errorf("starting worker: %w", err) + } + c.Entities[entity.ID] = worker + return nil + }) + } + if err := c.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } return nil } @@ -126,14 +152,30 @@ func (c *Controller) Start() error { } c.mux.Unlock() - if err := c.loadAllEnterprises(); err != nil { - return fmt.Errorf("loading enterprises: %w", err) - } - if err := c.loadAllOrganizations(); err != nil { - return fmt.Errorf("loading organizations: %w", err) - } - if err := c.loadAllRepositories(); err != nil { - return fmt.Errorf("loading repositories: %w", err) + g, _ := errgroup.WithContext(c.ctx) + g.Go(func() error { + if err := c.loadAllEnterprises(); err != nil { + return fmt.Errorf("loading enterprises: %w", err) + } + return nil + }) + + g.Go(func() error { + if err := c.loadAllOrganizations(); err != nil { + return fmt.Errorf("loading organizations: %w", err) + } + return nil + }) + + g.Go(func() error { + if err := c.loadAllRepositories(); err != nil { + return fmt.Errorf("loading repositories: %w", err) + } + return nil + }) + + if err := c.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } consumer, err := watcher.RegisterConsumer( diff --git a/workers/entity/util.go b/workers/entity/util.go index 28b9f955..4912beba 100644 --- a/workers/entity/util.go +++ b/workers/entity/util.go @@ -1,6 +1,8 @@ package entity import ( + "golang.org/x/sync/errgroup" + dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" @@ -33,3 +35,24 @@ func composeWorkerWatcherFilters(entity params.GithubEntity) dbCommon.PayloadFil ), ) } + +func (c *Controller) waitForErrorGroupOrContextCancelled(g *errgroup.Group) error { + if g == nil { + return nil + } + + done := make(chan error, 1) + go func() { + waitErr := g.Wait() + done <- waitErr + }() + + select { + case err := <-done: + return err + case <-c.ctx.Done(): + return c.ctx.Err() + case <-c.quit: + return nil + } +} diff --git a/workers/provider/provider.go b/workers/provider/provider.go index b1ab1220..ffc5183d 100644 --- a/workers/provider/provider.go +++ b/workers/provider/provider.go @@ -6,6 +6,8 @@ import ( "log/slog" "sync" + "golang.org/x/sync/errgroup" + commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/auth" dbCommon "github.com/cloudbase/garm/database/common" @@ -131,12 +133,24 @@ func (p *Provider) Start() error { return nil } - if err := p.loadAllScaleSets(); err != nil { - return fmt.Errorf("loading all scale sets: %w", err) - } + g, _ := errgroup.WithContext(p.ctx) - if err := p.loadAllRunners(); err != nil { - return fmt.Errorf("loading all runners: %w", err) + g.Go(func() error { + if err := p.loadAllScaleSets(); err != nil { + return fmt.Errorf("loading all scale sets: %w", err) + } + return nil + }) + + g.Go(func() error { + if err := p.loadAllRunners(); err != nil { + return fmt.Errorf("loading all runners: %w", err) + } + return nil + }) + + if err := p.waitForErrorGroupOrContextCancelled(g); err != nil { + return fmt.Errorf("waiting for error group: %w", err) } consumer, err := watcher.RegisterConsumer( diff --git a/workers/provider/util.go b/workers/provider/util.go index 8cd33525..ca2626c0 100644 --- a/workers/provider/util.go +++ b/workers/provider/util.go @@ -1,6 +1,8 @@ package provider import ( + "golang.org/x/sync/errgroup" + dbCommon "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" ) @@ -11,3 +13,24 @@ func composeProviderWatcher() dbCommon.PayloadFilterFunc { watcher.WithEntityTypeFilter(dbCommon.ScaleSetEntityType), ) } + +func (p *Provider) waitForErrorGroupOrContextCancelled(g *errgroup.Group) error { + if g == nil { + return nil + } + + done := make(chan error, 1) + go func() { + waitErr := g.Wait() + done <- waitErr + }() + + select { + case err := <-done: + return err + case <-p.ctx.Done(): + return p.ctx.Err() + case <-p.quit: + return nil + } +}