diff --git a/cmd/garm/main.go b/cmd/garm/main.go
index 3ffcdc1f..5879fd0a 100644
--- a/cmd/garm/main.go
+++ b/cmd/garm/main.go
@@ -46,6 +46,7 @@ import (
 	"github.com/cloudbase/garm/params"
 	"github.com/cloudbase/garm/runner" //nolint:typecheck
 	runnerMetrics "github.com/cloudbase/garm/runner/metrics"
+	"github.com/cloudbase/garm/runner/providers"
 	garmUtil "github.com/cloudbase/garm/util"
 	"github.com/cloudbase/garm/util/appdefaults"
 	"github.com/cloudbase/garm/websocket"
@@ -62,16 +63,17 @@ var signals = []os.Signal{
 	syscall.SIGTERM,
 }
 
-func maybeInitController(db common.Store) error {
-	if _, err := db.ControllerInfo(); err == nil {
-		return nil
+func maybeInitController(db common.Store) (params.ControllerInfo, error) {
+	if info, err := db.ControllerInfo(); err == nil {
+		return info, nil
 	}
 
-	if _, err := db.InitController(); err != nil {
-		return errors.Wrap(err, "initializing controller")
+	info, err := db.InitController()
+	if err != nil {
+		return params.ControllerInfo{}, errors.Wrap(err, "initializing controller")
 	}
 
-	return nil
+	return info, nil
 }
 
 func setupLogging(ctx context.Context, logCfg config.Logging, hub *websocket.Hub) {
@@ -212,7 +214,8 @@ func main() {
 		log.Fatal(err)
 	}
 
-	if err := maybeInitController(db); err != nil {
+	controllerInfo, err := maybeInitController(db)
+	if err != nil {
 		log.Fatal(err)
 	}
 
@@ -231,7 +234,12 @@ func main() {
 		log.Fatal(err)
 	}
 
-	entityController, err := entity.NewController(ctx, db, *cfg)
+	providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String())
+	if err != nil {
+		log.Fatalf("loading providers: %+v", err)
+	}
+
+	entityController, err := entity.NewController(ctx, db, providers)
 	if err != nil {
 		log.Fatalf("failed to create entity controller: %+v", err)
 	}
diff --git a/database/common/store.go b/database/common/store.go
index 82c5e4c0..87804281 100644
--- a/database/common/store.go
+++ b/database/common/store.go
@@ -92,6 +92,7 @@ type UserStore interface {
 type InstanceStore interface {
 	CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error)
 	DeleteInstance(ctx context.Context, poolID string, instanceName string) error
+	DeleteInstanceByName(ctx context.Context, instanceName string) error
 	UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error)
 
 	// Probably a bad idea without some king of filter or at least pagination
diff --git a/database/sql/instances.go b/database/sql/instances.go
index f88cd33b..cf0020b5 100644
--- a/database/sql/instances.go
+++ b/database/sql/instances.go
@@ -177,6 +177,44 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN
 	return nil
 }
 
+func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName string) error {
+	instance, err := s.getInstanceByName(ctx, instanceName)
+	if err != nil {
+		return errors.Wrap(err, "deleting instance")
+	}
+
+	defer func() {
+		if err == nil {
+			var providerID string
+			if instance.ProviderID != nil {
+				providerID = *instance.ProviderID
+			}
+			var poolID string
+			if instance.PoolID != nil {
+				poolID = instance.PoolID.String()
+			}
+			if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{
+				ID:         instance.ID.String(),
+				Name:       instance.Name,
+				ProviderID: providerID,
+				AgentID:    instance.AgentID,
+				PoolID:     poolID,
+			}); notifyErr != nil {
+				slog.With(slog.Any("error", notifyErr)).Error("failed to send notify")
+			}
+		}
+	}()
+
+	if q := s.conn.Unscoped().Delete(&instance); q.Error != nil {
+		if errors.Is(q.Error, gorm.ErrRecordNotFound) {
+			return nil
+		}
+		err = errors.Wrap(q.Error, "deleting instance")
+		return err
+	}
+	return nil
+}
+
 func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error {
 	instance, err := s.getInstanceByName(ctx, instanceName)
 	if err != nil {
@@ -293,7 +331,7 @@ func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]par
 
 func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, error) {
 	var instances []Instance
-	q := s.conn.Model(&Instance{}).Preload("Job", "Pool", "ScaleSet").Find(&instances)
+	q := s.conn.Model(&Instance{}).Preload("Job").Find(&instances)
 	if q.Error != nil {
 		return nil, errors.Wrap(q.Error, "fetching instances")
 	}
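A note on intended semantics: `DeleteInstanceByName` swallows `gorm.ErrRecordNotFound`, so deletion is meant to be idempotent. A minimal caller-side sketch (the `cleanupInstance` helper is hypothetical; it mirrors the `handleInstanceCleanup` handler added to the scaleset worker later in this diff):

```go
package main

import (
	"context"
	"errors"
	"fmt"

	runnerErrors "github.com/cloudbase/garm-provider-common/errors"
	dbCommon "github.com/cloudbase/garm/database/common"
)

// cleanupInstance shows the intended, idempotent usage: a record that is
// already gone is treated as success, not as a failure. ErrNotFound can
// still surface from getInstanceByName if the instance disappeared between
// listing and deletion.
func cleanupInstance(ctx context.Context, store dbCommon.Store, name string) error {
	if err := store.DeleteInstanceByName(ctx, name); err != nil {
		if errors.Is(err, runnerErrors.ErrNotFound) {
			return nil
		}
		return fmt.Errorf("deleting instance %s: %w", name, err)
	}
	return nil
}
```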
diff --git a/database/sql/models.go b/database/sql/models.go
index 45e329f6..c1b6462d 100644
--- a/database/sql/models.go
+++ b/database/sql/models.go
@@ -277,7 +277,7 @@ type Instance struct {
 	GitHubRunnerGroup string
 	AditionalLabels   datatypes.JSON
 
-	PoolID uuid.UUID
+	PoolID *uuid.UUID
 	Pool   Pool `gorm:"foreignKey:PoolID"`
 
 	ScaleSetFkID *uint
diff --git a/database/sql/scaleset_instances.go b/database/sql/scaleset_instances.go
index 3278b934..106df956 100644
--- a/database/sql/scaleset_instances.go
+++ b/database/sql/scaleset_instances.go
@@ -51,7 +51,7 @@ func (s *sqlDatabase) CreateScaleSetInstance(_ context.Context, scaleSetID uint,
 func (s *sqlDatabase) ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) {
 	var instances []Instance
 
-	query := s.conn.Model(&Instance{}).Preload("Job", "ScaleSet").Where("scale_set_fk_id = ?", scalesetID)
+	query := s.conn.Model(&Instance{}).Preload("Job").Where("scale_set_fk_id = ?", scalesetID)
 
 	if err := query.Find(&instances); err.Error != nil {
 		return nil, errors.Wrap(err.Error, "fetching instances")
diff --git a/database/sql/util.go b/database/sql/util.go
index dda3e9cf..112d0a76 100644
--- a/database/sql/util.go
+++ b/database/sql/util.go
@@ -79,7 +79,7 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) (params.Instance, e
 		ret.RunnerBootstrapTimeout = instance.ScaleSet.RunnerBootstrapTimeout
 	}
 
-	if instance.PoolID != uuid.Nil {
+	if instance.PoolID != nil {
 		ret.PoolID = instance.PoolID.String()
 		ret.ProviderName = instance.Pool.ProviderName
 		ret.RunnerBootstrapTimeout = instance.Pool.RunnerBootstrapTimeout
diff --git a/database/watcher/filters.go b/database/watcher/filters.go
index 0c259bce..6a7e8abf 100644
--- a/database/watcher/filters.go
+++ b/database/watcher/filters.go
@@ -1,6 +1,8 @@
 package watcher
 
 import (
+	commonParams "github.com/cloudbase/garm-provider-common/params"
+
 	dbCommon "github.com/cloudbase/garm/database/common"
 	"github.com/cloudbase/garm/params"
 )
@@ -281,3 +283,24 @@ func WithEntityTypeAndCallbackFilter(entityType dbCommon.DatabaseEntityType, cal
 		return ok
 	}
 }
+
+func WithInstanceStatusFilter(statuses ...commonParams.InstanceStatus) dbCommon.PayloadFilterFunc {
+	return func(payload dbCommon.ChangePayload) bool {
+		if payload.EntityType != dbCommon.InstanceEntityType {
+			return false
+		}
+		instance, ok := payload.Payload.(params.Instance)
+		if !ok {
+			return false
+		}
+		if len(statuses) == 0 {
+			return false
+		}
+		for _, status := range statuses {
+			if instance.Status == status {
+				return true
+			}
+		}
+		return false
+	}
+}
diff --git a/params/github.go b/params/github.go
index 9b0a1e43..e0ad0452 100644
--- a/params/github.go
+++ b/params/github.go
@@ -419,18 +419,18 @@ func (r RunnerScaleSetMessage) GetJobsFromBody() ([]ScaleSetJobMessage, error) {
 }
 
 type RunnerReference struct {
-	ID                int          `json:"id"`
-	Name              string       `json:"name"`
-	RunnerScaleSetID  int          `json:"runnerScaleSetId"`
-	CreatedOn         time.Time    `json:"createdOn"`
-	RunnerGroupID     uint64       `json:"runnerGroupId"`
-	RunnerGroupName   string       `json:"runnerGroupName"`
-	Version           string       `json:"version"`
-	Enabled           bool         `json:"enabled"`
-	Ephemeral         bool         `json:"ephemeral"`
-	Status            RunnerStatus `json:"status"`
-	DisableUpdate     bool         `json:"disableUpdate"`
-	ProvisioningState string       `json:"provisioningState"`
+	ID                int64       `json:"id"`
+	Name              string      `json:"name"`
+	RunnerScaleSetID  int         `json:"runnerScaleSetId"`
+	CreatedOn         interface{} `json:"createdOn"`
+	RunnerGroupID     uint64      `json:"runnerGroupId"`
+	RunnerGroupName   string      `json:"runnerGroupName"`
+	Version           string      `json:"version"`
+	Enabled           bool        `json:"enabled"`
+	Ephemeral         bool        `json:"ephemeral"`
+	Status            interface{} `json:"status"`
+	DisableUpdate     bool        `json:"disableUpdate"`
+	ProvisioningState string      `json:"provisioningState"`
 }
 
 type RunnerScaleSetJitRunnerConfig struct {
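Loosening `CreatedOn` and `Status` to `interface{}` keeps decoding resilient to the mixed representations the actions service returns, but it shifts type checking onto callers. A hedged sketch of what caller-side normalization might look like (`runnerStatusOf` is hypothetical, and it assumes `RunnerStatus` is a string-based type, as its previous use as the field type suggests):

```go
import "github.com/cloudbase/garm/params"

// runnerStatusOf is a hypothetical helper: after json.Unmarshal into an
// interface{}, a JSON string arrives as a Go string; anything else (null,
// numbers) is treated as unknown and mapped to the zero value.
func runnerStatusOf(r params.RunnerReference) params.RunnerStatus {
	switch v := r.Status.(type) {
	case string:
		return params.RunnerStatus(v)
	case params.RunnerStatus:
		return v
	default:
		return ""
	}
}
```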
diff --git a/util/github/scalesets/runners.go b/util/github/scalesets/runners.go
index d4d2b3f6..4d6434eb 100644
--- a/util/github/scalesets/runners.go
+++ b/util/github/scalesets/runners.go
@@ -30,7 +30,7 @@ type scaleSetJitRunnerConfig struct {
 	WorkFolder string `json:"workFolder"`
 }
 
-func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSet params.RunnerScaleSet) (params.RunnerScaleSetJitRunnerConfig, error) {
+func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSetID int) (params.RunnerScaleSetJitRunnerConfig, error) {
 	runnerSettings := scaleSetJitRunnerConfig{
 		Name:       runnerName,
 		WorkFolder: "_work",
@@ -41,7 +41,14 @@ func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName
 		return params.RunnerScaleSetJitRunnerConfig{}, err
 	}
 
-	req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSet.RunnerJitConfigURL, bytes.NewBuffer(body))
+	serviceUrl, err := s.actionsServiceInfo.GetURL()
+	if err != nil {
+		return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to get pipeline URL: %w", err)
+	}
+	jitConfigPath := fmt.Sprintf("/%s/%d/generatejitconfig", scaleSetEndpoint, scaleSetID)
+	jitConfigURL := serviceUrl.JoinPath(jitConfigPath)
+
+	req, err := s.newActionsRequest(ctx, http.MethodPost, jitConfigURL.String(), bytes.NewBuffer(body))
 	if err != nil {
 		return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to create request: %w", err)
 	}
@@ -81,6 +88,26 @@ func (s *ScaleSetClient) GetRunner(ctx context.Context, runnerID int64) (params.
 	return runnerReference, nil
 }
 
+func (s *ScaleSetClient) ListAllRunners(ctx context.Context) (params.RunnerReferenceList, error) {
+	req, err := s.newActionsRequest(ctx, http.MethodGet, runnerEndpoint, nil)
+	if err != nil {
+		return params.RunnerReferenceList{}, fmt.Errorf("failed to construct request: %w", err)
+	}
+
+	resp, err := s.Do(req)
+	if err != nil {
+		return params.RunnerReferenceList{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
+	}
+	defer resp.Body.Close()
+
+	var runnerList params.RunnerReferenceList
+	if err := json.NewDecoder(resp.Body).Decode(&runnerList); err != nil {
+		return params.RunnerReferenceList{}, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return runnerList, nil
+}
+
 func (s *ScaleSetClient) GetRunnerByName(ctx context.Context, runnerName string) (params.RunnerReference, error) {
 	path := fmt.Sprintf("%s?agentName=%s", runnerEndpoint, runnerName)
 
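`ListAllRunners` gives workers a way to reconcile service-side state against the database. A sketch of one possible use, under stated assumptions: `params.RunnerReferenceList` is assumed to expose the decoded runners as a `RunnerReferences` slice (the field name is not shown in this diff), and `RemoveRunner` is the same client method used elsewhere in this change:

```go
import (
	"context"
	"fmt"
	"log/slog"

	"github.com/cloudbase/garm/util/github/scalesets"
)

// removeOrphanedRunners is a hypothetical reconciliation helper: it removes
// every runner the actions service still tracks that no longer has a
// matching instance in the database ("known" is keyed by runner name).
func removeOrphanedRunners(ctx context.Context, cli *scalesets.ScaleSetClient, known map[string]bool) error {
	runners, err := cli.ListAllRunners(ctx)
	if err != nil {
		return fmt.Errorf("listing runners: %w", err)
	}
	for _, runner := range runners.RunnerReferences {
		if known[runner.Name] {
			continue
		}
		if err := cli.RemoveRunner(ctx, runner.ID); err != nil {
			slog.ErrorContext(ctx, "removing orphaned runner", "runner_name", runner.Name, "error", err)
		}
	}
	return nil
}
```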
diff --git a/util/github/scalesets/util.go b/util/github/scalesets/util.go
index 15c3a5cf..66171dd6 100644
--- a/util/github/scalesets/util.go
+++ b/util/github/scalesets/util.go
@@ -18,6 +18,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"log/slog"
 	"net/http"
 )
 
@@ -50,5 +51,7 @@ func (s *ScaleSetClient) newActionsRequest(ctx context.Context, method, path str
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.actionsServiceInfo.Token))
 
+	slog.DebugContext(ctx, "newActionsRequest", "method", method, "url", uri.String(), "body", body, "headers", req.Header)
+
 	return req, nil
 }
diff --git a/workers/entity/controller.go b/workers/entity/controller.go
index 1e0035c0..bfdcabfe 100644
--- a/workers/entity/controller.go
+++ b/workers/entity/controller.go
@@ -7,31 +7,19 @@ import (
 	"sync"
 
 	"github.com/cloudbase/garm/auth"
-	"github.com/cloudbase/garm/config"
 	dbCommon "github.com/cloudbase/garm/database/common"
 	"github.com/cloudbase/garm/database/watcher"
 	"github.com/cloudbase/garm/runner/common"
-	"github.com/cloudbase/garm/runner/providers"
 	garmUtil "github.com/cloudbase/garm/util"
 )
 
-func NewController(ctx context.Context, store dbCommon.Store, cfg config.Config) (*Controller, error) {
+func NewController(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider) (*Controller, error) {
 	consumerID := "entity-controller"
-	ctrlID, err := store.ControllerInfo()
-	if err != nil {
-		return nil, fmt.Errorf("getting controller info: %w", err)
-	}
-
 	ctx = garmUtil.WithSlogContext(
 		ctx,
 		slog.Any("worker", consumerID))
 	ctx = auth.GetAdminContext(ctx)
 
-	providers, err := providers.LoadProvidersFromConfig(ctx, cfg, ctrlID.ControllerID.String())
-	if err != nil {
-		return nil, fmt.Errorf("loading providers: %w", err)
-	}
-
 	return &Controller{
 		consumerID: consumerID,
 		ctx:        ctx,
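With provider loading hoisted out of `NewController`, `main()` owns the provider map and can hand the same instance to every worker that needs one. The diff adds a provider worker below but does not show it wired into `main()` yet; presumably that ends up looking something like this sketch (the `provider.NewWorker` call and the `Start()` lifecycle are assumptions based on the new package, not code from this change):

```go
// Sketch of main() wiring under the new layout; error handling follows the
// surrounding main.go style. The provider worker hookup is an assumption.
providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String())
if err != nil {
	log.Fatalf("loading providers: %+v", err)
}

entityController, err := entity.NewController(ctx, db, providers)
if err != nil {
	log.Fatalf("failed to create entity controller: %+v", err)
}

providerWorker, err := provider.NewWorker(ctx, db, providers)
if err != nil {
	log.Fatalf("failed to create provider worker: %+v", err)
}
if err := providerWorker.Start(); err != nil {
	log.Fatalf("failed to start provider worker: %+v", err)
}
```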
diff --git a/workers/provider/provider.go b/workers/provider/provider.go
new file mode 100644
index 00000000..7f0784e9
--- /dev/null
+++ b/workers/provider/provider.go
@@ -0,0 +1,73 @@
+package provider
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	dbCommon "github.com/cloudbase/garm/database/common"
+	"github.com/cloudbase/garm/database/watcher"
+	"github.com/cloudbase/garm/runner/common"
+)
+
+func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider) (*provider, error) {
+	consumerID := "provider-worker"
+	return &provider{
+		ctx:        ctx,
+		store:      store,
+		consumerID: consumerID,
+		providers:  providers,
+	}, nil
+}
+
+type provider struct {
+	ctx        context.Context
+	consumerID string
+
+	consumer dbCommon.Consumer
+	// TODO: not all workers should have access to the store.
+	// We need to implement a way to RPC from workers to controllers
+	// and abstract that into something we can use to eventually
+	// scale out.
+	store dbCommon.Store
+
+	providers map[string]common.Provider
+
+	mux     sync.Mutex
+	running bool
+	quit    chan struct{}
+}
+
+func (p *provider) Start() error {
+	p.mux.Lock()
+	defer p.mux.Unlock()
+
+	if p.running {
+		return nil
+	}
+
+	consumer, err := watcher.RegisterConsumer(
+		p.ctx, p.consumerID, composeProviderWatcher())
+	if err != nil {
+		return fmt.Errorf("registering consumer: %w", err)
+	}
+	p.consumer = consumer
+
+	p.quit = make(chan struct{})
+	p.running = true
+	return nil
+}
+
+func (p *provider) Stop() error {
+	p.mux.Lock()
+	defer p.mux.Unlock()
+
+	if !p.running {
+		return nil
+	}
+
+	p.consumer.Close()
+	close(p.quit)
+	p.running = false
+	return nil
+}
diff --git a/workers/provider/util.go b/workers/provider/util.go
new file mode 100644
index 00000000..2d84e25e
--- /dev/null
+++ b/workers/provider/util.go
@@ -0,0 +1,18 @@
+package provider
+
+import (
+	commonParams "github.com/cloudbase/garm-provider-common/params"
+
+	dbCommon "github.com/cloudbase/garm/database/common"
+	"github.com/cloudbase/garm/database/watcher"
+)
+
+func composeProviderWatcher() dbCommon.PayloadFilterFunc {
+	return watcher.WithAny(
+		watcher.WithInstanceStatusFilter(
+			commonParams.InstancePendingCreate,
+			commonParams.InstancePendingDelete,
+			commonParams.InstancePendingForceDelete,
+		),
+	)
+}
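`Start()` registers a consumer filtered to pending create/delete events, but nothing drains it yet. A follow-up loop would presumably look like the sketch below; it assumes `dbCommon.Consumer` exposes a `Watch()` channel of `ChangePayload` (consistent with how other garm workers consume watcher events) and that `handleInstanceEvent` is a yet-to-be-written handler that dispatches the instance to the matching provider:

```go
// loop is a sketch of the missing consumer drain. Both Watch() and
// handleInstanceEvent are assumptions, not part of this diff; slog would
// need to be added to the imports.
func (p *provider) loop() {
	for {
		select {
		case <-p.quit:
			return
		case <-p.ctx.Done():
			return
		case event, ok := <-p.consumer.Watch():
			if !ok {
				return
			}
			if err := p.handleInstanceEvent(event); err != nil {
				slog.ErrorContext(p.ctx, "handling instance event", "error", err)
			}
		}
	}
}
```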
diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go
index 7e134adb..24df1cbb 100644
--- a/workers/scaleset/scaleset.go
+++ b/workers/scaleset/scaleset.go
@@ -2,13 +2,19 @@ package scaleset
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"log/slog"
 	"sync"
 	"time"
 
+	runnerErrors "github.com/cloudbase/garm-provider-common/errors"
+	commonParams "github.com/cloudbase/garm-provider-common/params"
+
+	"github.com/cloudbase/garm-provider-common/util"
 	dbCommon "github.com/cloudbase/garm/database/common"
 	"github.com/cloudbase/garm/database/watcher"
+	"github.com/cloudbase/garm/locking"
 	"github.com/cloudbase/garm/params"
 	"github.com/cloudbase/garm/runner/common"
 	"github.com/cloudbase/garm/util/github/scalesets"
@@ -188,6 +194,17 @@ func (w *Worker) handleScaleSetEvent(event dbCommon.ChangePayload) {
 	}
 }
 
+func (w *Worker) handleInstanceCleanup(instance params.Instance) error {
+	if instance.Status == commonParams.InstanceDeleted {
+		if err := w.store.DeleteInstanceByName(w.ctx, instance.Name); err != nil {
+			if !errors.Is(err, runnerErrors.ErrNotFound) {
+				return fmt.Errorf("deleting instance %s: %w", instance.ID, err)
+			}
+		}
+	}
+	return nil
+}
+
 func (w *Worker) handleInstanceEntityEvent(event dbCommon.ChangePayload) {
 	instance, ok := event.Payload.(params.Instance)
 	if !ok {
@@ -319,6 +336,138 @@ func (w *Worker) keepListenerAlive() {
 	}
 }
 
+func (w *Worker) handleScaleUp(target, current uint) {
+	if !w.scaleSet.Enabled {
+		slog.DebugContext(w.ctx, "scale set is disabled; not scaling up")
+		return
+	}
+
+	if target <= current {
+		slog.DebugContext(w.ctx, "target is less than or equal to current; not scaling up")
+		return
+	}
+
+	controllerConfig, err := w.store.ControllerInfo()
+	if err != nil {
+		slog.ErrorContext(w.ctx, "error getting controller config", "error", err)
+		return
+	}
+
+	for i := current; i < target; i++ {
+		newRunnerName := fmt.Sprintf("%s-%s", w.scaleSet.GetRunnerPrefix(), util.NewID())
+		jitConfig, err := w.scaleSetCli.GenerateJitRunnerConfig(w.ctx, newRunnerName, w.scaleSet.ScaleSetID)
+		if err != nil {
+			slog.ErrorContext(w.ctx, "error generating jit config", "error", err)
+			continue
+		}
+		slog.DebugContext(w.ctx, "creating new runner", "runner_name", newRunnerName)
+		decodedJit, err := jitConfig.DecodedJITConfig()
+		if err != nil {
+			slog.ErrorContext(w.ctx, "error decoding jit config", "error", err)
+			continue
+		}
+		runnerParams := params.CreateInstanceParams{
+			Name:              newRunnerName,
+			Status:            commonParams.InstancePendingCreate,
+			RunnerStatus:      params.RunnerPending,
+			OSArch:            w.scaleSet.OSArch,
+			OSType:            w.scaleSet.OSType,
+			CallbackURL:       controllerConfig.CallbackURL,
+			MetadataURL:       controllerConfig.MetadataURL,
+			CreateAttempt:     1,
+			GitHubRunnerGroup: w.scaleSet.GitHubRunnerGroup,
+			JitConfiguration:  decodedJit,
+			AgentID:           int64(jitConfig.Runner.ID),
+		}
+
+		if _, err := w.store.CreateScaleSetInstance(w.ctx, w.scaleSet.ID, runnerParams); err != nil {
+			slog.ErrorContext(w.ctx, "error creating instance", "error", err)
+			if err := w.scaleSetCli.RemoveRunner(w.ctx, jitConfig.Runner.ID); err != nil {
+				slog.ErrorContext(w.ctx, "error deleting runner", "error", err)
+			}
+			continue
+		}
+
+		runnerDetails, err := w.scaleSetCli.GetRunner(w.ctx, jitConfig.Runner.ID)
+		if err != nil {
+			slog.ErrorContext(w.ctx, "error getting runner details", "error", err)
+			continue
+		}
+		slog.DebugContext(w.ctx, "runner details", "runner_details", runnerDetails)
+	}
+}
+
+func (w *Worker) handleScaleDown(target, current uint) {
+	if current <= target {
+		return
+	}
+	delta := current - target
+	w.mux.Lock()
+	defer w.mux.Unlock()
+	removed := 0
+	for _, runner := range w.runners {
+		if removed >= int(delta) {
+			break
+		}
+
+		locked, err := locking.TryLock(runner.Name)
+		if err != nil || !locked {
+			slog.DebugContext(w.ctx, "runner is locked; skipping", "runner_name", runner.Name)
+			continue
+		}
+
+		switch runner.Status {
+		case commonParams.InstancePendingCreate, commonParams.InstanceRunning:
+		case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete:
+			removed++
+			locking.Unlock(runner.Name, true)
+			continue
+		default:
+			slog.DebugContext(w.ctx, "runner is not in a valid state; skipping", "runner_name", runner.Name, "runner_status", runner.Status)
+			locking.Unlock(runner.Name, false)
+			continue
+		}
+
+		switch runner.RunnerStatus {
+		case params.RunnerTerminated, params.RunnerActive:
+			slog.DebugContext(w.ctx, "runner is not in a valid state; skipping", "runner_name", runner.Name, "runner_status", runner.RunnerStatus)
+			locking.Unlock(runner.Name, false)
+			continue
+		}
+
+		slog.DebugContext(w.ctx, "removing runner", "runner_name", runner.Name)
+		if err := w.scaleSetCli.RemoveRunner(w.ctx, runner.AgentID); err != nil {
+			if !errors.Is(err, runnerErrors.ErrNotFound) {
+				slog.ErrorContext(w.ctx, "error removing runner", "runner_name", runner.Name, "error", err)
+				locking.Unlock(runner.Name, false)
+				continue
+			}
+		}
+		runnerUpdateParams := params.UpdateInstanceParams{
+			Status: commonParams.InstancePendingDelete,
+		}
+		if _, err := w.store.UpdateInstance(w.ctx, runner.Name, runnerUpdateParams); err != nil {
+			if errors.Is(err, runnerErrors.ErrNotFound) {
+				// The instance was removed from the database, but we still had it in our
+				// state, so either the update never came from the watcher or something else happened.
+				// Remove it from the local cache.
+				delete(w.runners, runner.ID)
+				removed++
+				locking.Unlock(runner.Name, true)
+				continue
+			}
+			// TODO: This should not happen, unless there is some issue with the database.
+			// The UpdateInstance() function should retry transient failures, but even in that case, if it
+			// still errors out, we need to handle it somehow.
+			slog.ErrorContext(w.ctx, "error updating runner", "runner_name", runner.Name, "error", err)
+			locking.Unlock(runner.Name, false)
+			continue
+		}
+		removed++
+		locking.Unlock(runner.Name, false)
+	}
+}
+
 func (w *Worker) handleAutoScale() {
 	ticker := time.NewTicker(5 * time.Second)
 	defer ticker.Stop()
@@ -337,6 +486,14 @@ func (w *Worker) handleAutoScale() {
 		case <-w.ctx.Done():
 			return
 		case <-ticker.C:
+			w.mux.Lock()
+			for _, instance := range w.runners {
+				if err := w.handleInstanceCleanup(instance); err != nil {
+					slog.ErrorContext(w.ctx, "error cleaning up instance", "instance_id", instance.ID, "error", err)
+				}
+			}
+			w.mux.Unlock()
+
 			var desiredRunners uint
 			if w.scaleSet.DesiredRunnerCount > 0 {
 				desiredRunners = uint(w.scaleSet.DesiredRunnerCount)
@@ -351,8 +508,10 @@ func (w *Worker) handleAutoScale() {
 
 			if currentRunners < targetRunners {
 				lastMsgDebugLog("scaling up", targetRunners, currentRunners)
+				w.handleScaleUp(targetRunners, currentRunners)
 			} else {
 				lastMsgDebugLog("attempting to scale down", targetRunners, currentRunners)
+				w.handleScaleDown(targetRunners, currentRunners)
 			}
 		}
 	}
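The new watcher filter is straightforward to cover with a unit test. This sketch only relies on types visible in this diff; the `dbCommon.UpdateOperation` constant is assumed to exist alongside the `DeleteOperation` constant used above:

```go
package watcher_test

import (
	"testing"

	commonParams "github.com/cloudbase/garm-provider-common/params"

	dbCommon "github.com/cloudbase/garm/database/common"
	"github.com/cloudbase/garm/database/watcher"
	"github.com/cloudbase/garm/params"
)

func TestWithInstanceStatusFilter(t *testing.T) {
	filter := watcher.WithInstanceStatusFilter(commonParams.InstancePendingCreate)

	payload := dbCommon.ChangePayload{
		EntityType: dbCommon.InstanceEntityType,
		Operation:  dbCommon.UpdateOperation, // assumed constant
		Payload: params.Instance{
			Name:   "runner-abc",
			Status: commonParams.InstancePendingCreate,
		},
	}
	if !filter(payload) {
		t.Fatal("expected payload with a matching status to pass the filter")
	}

	// A status outside the filter list must be rejected.
	payload.Payload = params.Instance{Name: "runner-abc", Status: commonParams.InstanceDeleted}
	if filter(payload) {
		t.Fatal("expected payload with a non-matching status to be rejected")
	}
}
```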