diff --git a/auth/instance_middleware.go b/auth/instance_middleware.go index dc31327e..6d1d66e4 100644 --- a/auth/instance_middleware.go +++ b/auth/instance_middleware.go @@ -120,7 +120,7 @@ func (amw *instanceMiddleware) claimsToContext(ctx context.Context, claims *Inst return nil, runnerErrors.ErrUnauthorized } - instanceInfo, err := amw.store.GetInstanceByName(ctx, claims.Name) + instanceInfo, err := amw.store.GetInstance(ctx, claims.Name) if err != nil { return ctx, runnerErrors.ErrUnauthorized } diff --git a/database/common/store.go b/database/common/store.go index d768f159..0cf5d929 100644 --- a/database/common/store.go +++ b/database/common/store.go @@ -75,7 +75,6 @@ type PoolStore interface { ListPoolInstances(ctx context.Context, poolID string) ([]params.Instance, error) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) - GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) FindPoolsMatchingAllTags(ctx context.Context, entityType params.ForgeEntityType, entityID string, tags []string) ([]params.Pool, error) } @@ -91,9 +90,9 @@ type UserStore interface { type InstanceStore interface { CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) - DeleteInstance(ctx context.Context, poolID string, instanceName string) error + DeleteInstance(ctx context.Context, poolID string, instanceNameOrID string) error DeleteInstanceByName(ctx context.Context, instanceName string) error - UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) + UpdateInstance(ctx context.Context, instanceNameOrID string, param params.UpdateInstanceParams) (params.Instance, error) // Probably a bad idea without some king of filter or at least pagination // @@ -101,8 +100,8 @@ type InstanceStore interface { // TODO: add filter/pagination ListAllInstances(ctx context.Context) ([]params.Instance, error) - GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) - AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error + GetInstance(ctx context.Context, instanceNameOrID string) (params.Instance, error) + AddInstanceEvent(ctx context.Context, instanceNameOrID string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error } type JobsStore interface { diff --git a/database/sql/instances.go b/database/sql/instances.go index 92194c5e..5f9d018e 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -103,9 +103,16 @@ func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) return instance, nil } -func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, preload ...string) (Instance, error) { +func (s *sqlDatabase) getInstance(_ context.Context, instanceNameOrID string, preload ...string) (Instance, error) { var instance Instance + var whereArg any = instanceNameOrID + whereClause := "name = ?" + id, err := uuid.Parse(instanceNameOrID) + if err == nil { + whereArg = id + whereClause = "id = ?" + } q := s.conn if len(preload) > 0 { @@ -116,7 +123,7 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, q = q.Model(&Instance{}). Preload(clause.Associations). - Where("name = ?", instanceName). + Where(whereClause, whereArg). First(&instance) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { @@ -127,17 +134,8 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, return instance, nil } -func (s *sqlDatabase) GetPoolInstanceByName(_ context.Context, poolID string, instanceName string) (params.Instance, error) { - instance, err := s.getPoolInstanceByName(poolID, instanceName) - if err != nil { - return params.Instance{}, fmt.Errorf("error fetching instance: %w", err) - } - - return s.sqlToParamsInstance(instance) -} - -func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) { - instance, err := s.getInstanceByName(ctx, instanceName, "StatusMessages", "Pool", "ScaleSet") +func (s *sqlDatabase) GetInstance(ctx context.Context, instanceName string) (params.Instance, error) { + instance, err := s.getInstance(ctx, instanceName, "StatusMessages", "Pool", "ScaleSet") if err != nil { return params.Instance{}, fmt.Errorf("error fetching instance: %w", err) } @@ -189,7 +187,7 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN } func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName string) error { - instance, err := s.getInstanceByName(ctx, instanceName, "Pool", "ScaleSet") + instance, err := s.getInstance(ctx, instanceName, "Pool", "ScaleSet") if err != nil { if errors.Is(err, runnerErrors.ErrNotFound) { return nil @@ -231,7 +229,7 @@ func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName str } func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error { - instance, err := s.getInstanceByName(ctx, instanceName) + instance, err := s.getInstance(ctx, instanceName) if err != nil { return fmt.Errorf("error updating instance: %w", err) } @@ -249,7 +247,7 @@ func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, } func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) { - instance, err := s.getInstanceByName(ctx, instanceName, "Pool", "ScaleSet") + instance, err := s.getInstance(ctx, instanceName, "Pool", "ScaleSet") if err != nil { return params.Instance{}, fmt.Errorf("error updating instance: %w", err) } diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index c6093327..5ec55107 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -196,7 +196,7 @@ func (s *InstancesTestSuite) TestCreateInstance() { // assertions s.Require().Nil(err) - storeInstance, err := s.Store.GetInstanceByName(s.adminCtx, s.Fixtures.CreateInstanceParams.Name) + storeInstance, err := s.Store.GetInstance(s.adminCtx, s.Fixtures.CreateInstanceParams.Name) if err != nil { s.FailNow(fmt.Sprintf("failed to get instance: %v", err)) } @@ -236,29 +236,10 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { s.Require().Equal("error creating instance: mocked insert instance error", err.Error()) } -func (s *InstancesTestSuite) TestGetPoolInstanceByName() { - storeInstance := s.Fixtures.Instances[0] // this is already created in `SetupTest()` - - instance, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) - - s.Require().Nil(err) - s.Require().Equal(storeInstance.Name, instance.Name) - s.Require().Equal(storeInstance.PoolID, instance.PoolID) - s.Require().Equal(storeInstance.OSArch, instance.OSArch) - s.Require().Equal(storeInstance.OSType, instance.OSType) - s.Require().Equal(storeInstance.CallbackURL, instance.CallbackURL) -} - -func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() { - _, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, "not-existent-instance-name") - - s.Require().Equal("error fetching instance: error fetching pool instance by name: not found", err.Error()) -} - func (s *InstancesTestSuite) TestGetInstanceByName() { storeInstance := s.Fixtures.Instances[1] - instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name) + instance, err := s.Store.GetInstance(s.adminCtx, storeInstance.Name) s.Require().Nil(err) s.Require().Equal(storeInstance.Name, instance.Name) @@ -269,7 +250,7 @@ func (s *InstancesTestSuite) TestGetInstanceByName() { } func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() { - _, err := s.Store.GetInstanceByName(s.adminCtx, "not-existent-instance-name") + _, err := s.Store.GetInstance(s.adminCtx, "not-existent-instance-name") s.Require().Equal("error fetching instance: error fetching instance by name: not found", err.Error()) } @@ -281,8 +262,8 @@ func (s *InstancesTestSuite) TestDeleteInstance() { s.Require().Nil(err) - _, err = s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) - s.Require().Equal("error fetching instance: error fetching pool instance by name: not found", err.Error()) + _, err = s.Store.GetInstance(s.adminCtx, storeInstance.Name) + s.Require().Equal("error fetching instance: error fetching instance by name: not found", err.Error()) err = s.Store.DeleteInstance(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) s.Require().Nil(err) @@ -295,8 +276,8 @@ func (s *InstancesTestSuite) TestDeleteInstanceByName() { s.Require().Nil(err) - _, err = s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) - s.Require().Equal("error fetching instance: error fetching pool instance by name: not found", err.Error()) + _, err = s.Store.GetInstance(s.adminCtx, storeInstance.Name) + s.Require().Equal("error fetching instance: error fetching instance by name: not found", err.Error()) err = s.Store.DeleteInstanceByName(s.adminCtx, storeInstance.Name) s.Require().Nil(err) @@ -390,7 +371,7 @@ func (s *InstancesTestSuite) TestAddInstanceEvent() { err := s.Store.AddInstanceEvent(s.adminCtx, storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg) s.Require().Nil(err) - instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name) + instance, err := s.Store.GetInstance(s.adminCtx, storeInstance.Name) if err != nil { s.FailNow(fmt.Sprintf("failed to get db instance: %s", err)) } diff --git a/database/sql/jobs.go b/database/sql/jobs.go index f4d24e42..5740052a 100644 --- a/database/sql/jobs.go +++ b/database/sql/jobs.go @@ -100,7 +100,7 @@ func (s *sqlDatabase) paramsJobToWorkflowJob(ctx context.Context, job params.Job } if job.RunnerName != "" { - instance, err := s.getInstanceByName(s.ctx, job.RunnerName) + instance, err := s.getInstance(s.ctx, job.RunnerName) if err != nil { // This usually is very normal as not all jobs run on our runners. slog.DebugContext(ctx, "failed to get instance by name", "instance_name", job.RunnerName) @@ -282,7 +282,7 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa } if job.RunnerName != "" { - instance, err := s.getInstanceByName(ctx, job.RunnerName) + instance, err := s.getInstance(ctx, job.RunnerName) if err == nil { workflowJob.InstanceID = &instance.ID } else { diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 690fed93..8610d4c9 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -557,7 +557,7 @@ func (r *basePoolManager) cleanupOrphanedGithubRunners(runners []*github.Runner) continue } - dbInstance, err := r.store.GetInstanceByName(r.ctx, *runner.Name) + dbInstance, err := r.store.GetInstance(r.ctx, *runner.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return fmt.Errorf("error fetching instance from DB: %w", err) diff --git a/runner/runner.go b/runner/runner.go index 2c12071d..bf081522 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -731,7 +731,7 @@ func (r *Runner) GetInstance(ctx context.Context, instanceName string) (params.I return params.Instance{}, runnerErrors.ErrUnauthorized } - instance, err := r.store.GetInstanceByName(ctx, instanceName) + instance, err := r.store.GetInstance(ctx, instanceName) if err != nil { return params.Instance{}, fmt.Errorf("error fetching instance: %w", err) } @@ -852,7 +852,7 @@ func (r *Runner) DeleteRunner(ctx context.Context, instanceName string, forceDel return runnerErrors.ErrUnauthorized } - instance, err := r.store.GetInstanceByName(ctx, instanceName) + instance, err := r.store.GetInstance(ctx, instanceName) if err != nil { return fmt.Errorf("error fetching instance: %w", err) }