diff --git a/database/common/store.go b/database/common/store.go index a2b2cf77..2ac55a4b 100644 --- a/database/common/store.go +++ b/database/common/store.go @@ -142,10 +142,14 @@ type ScaleSetsStore interface { UpdateEntityScaleSet(_ context.Context, entity params.GithubEntity, scaleSetID uint, param params.UpdateScaleSetParams, callback func(old, new params.ScaleSet) error) (updatedScaleSet params.ScaleSet, err error) GetScaleSetByID(ctx context.Context, scaleSet uint) (params.ScaleSet, error) DeleteScaleSetByID(ctx context.Context, scaleSetID uint) (err error) - ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) SetScaleSetLastMessageID(ctx context.Context, scaleSetID uint, lastMessageID int64) error } +type ScaleSetInstanceStore interface { + ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) + CreateScaleSetInstance(_ context.Context, scaleSetID uint, param params.CreateInstanceParams) (instance params.Instance, err error) +} + //go:generate mockery --name=Store type Store interface { RepoStore @@ -160,6 +164,7 @@ type Store interface { ControllerStore EntityPoolStore ScaleSetsStore + ScaleSetInstanceStore ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/sql/instances.go b/database/sql/instances.go index d4bfd019..f88cd33b 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -136,7 +136,7 @@ func (s *sqlDatabase) GetPoolInstanceByName(_ context.Context, poolID string, in } func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) { - instance, err := s.getInstanceByName(ctx, instanceName, "StatusMessages", "Pool") + instance, err := s.getInstanceByName(ctx, instanceName, "StatusMessages", "Pool", "ScaleSet") if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } @@ -196,7 +196,7 @@ func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, } func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) { - instance, err := s.getInstanceByName(ctx, instanceName, "Pool") + instance, err := s.getInstanceByName(ctx, instanceName, "Pool", "ScaleSet") if err != nil { return params.Instance{}, errors.Wrap(err, "updating instance") } @@ -290,25 +290,6 @@ func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]par return ret, nil } -func (s *sqlDatabase) ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) { - var instances []Instance - query := s.conn.Model(&Instance{}).Preload("Job", "ScaleSet").Where("scale_set_fk_id = ?", scalesetID) - - if err := query.Find(&instances); err.Error != nil { - return nil, errors.Wrap(err.Error, "fetching instances") - } - - var err error - ret := make([]params.Instance, len(instances)) - for idx, inst := range instances { - ret[idx], err = s.sqlToParamsInstance(inst) - if err != nil { - return nil, errors.Wrap(err, "converting instance") - } - } - return ret, nil -} - func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, error) { var instances []Instance diff --git a/database/sql/scaleset_instances.go b/database/sql/scaleset_instances.go new file mode 100644 index 00000000..3278b934 --- /dev/null +++ b/database/sql/scaleset_instances.go @@ -0,0 +1,69 @@ +package sql + +import ( + "context" + + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + "github.com/pkg/errors" +) + +func (s *sqlDatabase) CreateScaleSetInstance(_ context.Context, scaleSetID uint, param params.CreateInstanceParams) (instance params.Instance, err error) { + scaleSet, err := s.getScaleSetByID(s.conn, scaleSetID) + if err != nil { + return params.Instance{}, errors.Wrap(err, "fetching scale set") + } + + defer func() { + if err == nil { + s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance) + } + }() + + var secret []byte + if len(param.JitConfiguration) > 0 { + secret, err = s.marshalAndSeal(param.JitConfiguration) + if err != nil { + return params.Instance{}, errors.Wrap(err, "marshalling jit config") + } + } + + newInstance := Instance{ + ScaleSet: scaleSet, + Name: param.Name, + Status: param.Status, + RunnerStatus: param.RunnerStatus, + OSType: param.OSType, + OSArch: param.OSArch, + CallbackURL: param.CallbackURL, + MetadataURL: param.MetadataURL, + GitHubRunnerGroup: param.GitHubRunnerGroup, + JitConfiguration: secret, + AgentID: param.AgentID, + } + q := s.conn.Create(&newInstance) + if q.Error != nil { + return params.Instance{}, errors.Wrap(q.Error, "creating instance") + } + + return s.sqlToParamsInstance(newInstance) +} + +func (s *sqlDatabase) ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) { + var instances []Instance + query := s.conn.Model(&Instance{}).Preload("Job", "ScaleSet").Where("scale_set_fk_id = ?", scalesetID) + + if err := query.Find(&instances); err.Error != nil { + return nil, errors.Wrap(err.Error, "fetching instances") + } + + var err error + ret := make([]params.Instance, len(instances)) + for idx, inst := range instances { + ret[idx], err = s.sqlToParamsInstance(inst) + if err != nil { + return nil, errors.Wrap(err, "converting instance") + } + } + return ret, nil +} diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index c392d5cd..f2fc36af 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -30,7 +30,7 @@ func NewWorker(ctx context.Context, store dbCommon.Store, scaleSet params.ScaleS consumerID: consumerID, store: store, provider: provider, - Entity: scaleSet, + scaleSet: scaleSet, ghCli: ghCli, scaleSetCli: scaleSetCli, }, nil @@ -43,7 +43,7 @@ type Worker struct { provider common.Provider store dbCommon.Store - Entity params.ScaleSet + scaleSet params.ScaleSet ghCli common.GithubClient scaleSetCli *scalesets.ScaleSetClient @@ -88,7 +88,7 @@ func (w *Worker) Start() (err error) { consumer, err := watcher.RegisterConsumer( w.ctx, w.consumerID, watcher.WithAll( - watcher.WithScaleSetFilter(w.Entity), + watcher.WithScaleSetFilter(w.scaleSet), watcher.WithOperationTypeFilter(dbCommon.UpdateOperation), ), ) @@ -148,13 +148,13 @@ func (w *Worker) handleEvent(event dbCommon.ChangePayload) { case dbCommon.UpdateOperation: slog.DebugContext(w.ctx, "got update operation") w.mux.Lock() - if scaleSet.MaxRunners < w.Entity.MaxRunners { + if scaleSet.MaxRunners < w.scaleSet.MaxRunners { slog.DebugContext(w.ctx, "max runners changed; stopping listener") if err := w.listener.Stop(); err != nil { slog.ErrorContext(w.ctx, "error stopping listener", "error", err) } } - w.Entity = scaleSet + w.scaleSet = scaleSet w.mux.Unlock() default: slog.DebugContext(w.ctx, "invalid operation type; ignoring", "operation_type", event.Operation) diff --git a/workers/scaleset/scaleset_helper.go b/workers/scaleset/scaleset_helper.go index 4604a919..4d84a76b 100644 --- a/workers/scaleset/scaleset_helper.go +++ b/workers/scaleset/scaleset_helper.go @@ -12,7 +12,7 @@ func (w *Worker) ScaleSetCLI() *scalesets.ScaleSetClient { } func (w *Worker) GetScaleSet() params.ScaleSet { - return w.Entity + return w.scaleSet } func (w *Worker) Owner() string { @@ -20,7 +20,7 @@ func (w *Worker) Owner() string { } func (w *Worker) SetLastMessageID(id int64) error { - if err := w.store.SetScaleSetLastMessageID(w.ctx, w.Entity.ID, id); err != nil { + if err := w.store.SetScaleSetLastMessageID(w.ctx, w.scaleSet.ID, id); err != nil { return fmt.Errorf("setting last message ID: %w", err) } return nil