From 0152b215294320e763eeb368604fc00c727da109 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 28 Mar 2024 10:08:19 +0000 Subject: [PATCH] Implement some common logic for pool creation Signed-off-by: Gabriel Adrian Samfira --- database/common/common.go | 15 +++-- database/common/mocks/Store.go | 84 ----------------------- database/sql/enterprise.go | 12 +--- database/sql/enterprise_test.go | 23 ------- database/sql/instances.go | 12 ++-- database/sql/organizations.go | 36 +--------- database/sql/organizations_test.go | 23 ------- database/sql/pools.go | 103 ++++++++++++++++++++++++++++- database/sql/repositories.go | 37 ++++++++--- database/sql/repositories_test.go | 22 ------ database/sql/util.go | 34 ++++++++-- runner/enterprises.go | 16 ++--- runner/organizations.go | 16 ++--- runner/pool/pool.go | 6 +- runner/repositories.go | 16 ++--- 15 files changed, 201 insertions(+), 254 deletions(-) diff --git a/database/common/common.go b/database/common/common.go index 8ca57ac2..1bd4c437 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -29,11 +29,9 @@ type RepoStore interface { UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) - GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) @@ -52,11 +50,20 @@ type OrgStore interface { DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) } +type EntityPools interface { + CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) + GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) + DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error + UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) + + ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) + ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) +} + type EnterpriseStore interface { CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) @@ -70,7 +77,6 @@ type EnterpriseStore interface { DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) } @@ -139,6 +145,7 @@ type Store interface { UserStore InstanceStore JobsStore + EntityPools ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 81e47799..e6f6d815 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -510,62 +510,6 @@ func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID return r0 } -// FindEnterprisePoolByTags provides a mock function with given fields: ctx, enterpriseID, tags -func (_m *Store) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindEnterprisePoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, enterpriseID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, enterpriseID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// FindOrganizationPoolByTags provides a mock function with given fields: ctx, orgID, tags -func (_m *Store) FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, orgID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindOrganizationPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, orgID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, orgID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, orgID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // FindPoolsMatchingAllTags provides a mock function with given fields: ctx, entityType, entityID, tags func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params.GithubEntityType, entityID string, tags []string) ([]params.Pool, error) { ret := _m.Called(ctx, entityType, entityID, tags) @@ -596,34 +540,6 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params return r0, r1 } -// FindRepositoryPoolByTags provides a mock function with given fields: ctx, repoID, tags -func (_m *Store) FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, repoID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindRepositoryPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, repoID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, repoID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, repoID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetEnterprise provides a mock function with given fields: ctx, name func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { ret := _m.Called(ctx, name) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index f83dab8c..f8665a45 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -175,7 +175,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -193,7 +193,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -230,14 +230,6 @@ func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, po return s.updatePool(pool, param) } -func (s *sqlDatabase) FindEnterprisePoolByTags(_ context.Context, enterpriseID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(enterpriseID, params.GithubEntityTypeEnterprise, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Tags", "Instances", "Enterprise") if err != nil { diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index fa709b89..86b68872 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -740,29 +740,6 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTags() { - enterprisePool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) - } - - pool, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(enterprisePool.ID, pool.ID) - s.Require().Equal(enterprisePool.Image, pool.Image) - s.Require().Equal(enterprisePool.Flavor, pool.Flavor) -} - -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/instances.go b/database/sql/instances.go index 4c475bf2..552fc39d 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -49,7 +49,7 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error } func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } @@ -108,8 +108,8 @@ func (s *sqlDatabase) getInstanceByID(_ context.Context, instanceID string) (Ins return instance, nil } -func (s *sqlDatabase) getPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return Instance{}, errors.Wrap(err, "fetching pool") } @@ -153,7 +153,7 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, } func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } @@ -171,7 +171,7 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string } func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") } @@ -338,7 +338,7 @@ func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, er } func (s *sqlDatabase) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return 0, errors.Wrap(err, "fetching pool") } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 4d246065..5ee28520 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID return nil } -func (s *sqlDatabase) FindOrganizationPoolByTags(_ context.Context, orgID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(orgID, params.GithubEntityTypeOrganization, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Instances.Job") if err != nil { @@ -290,30 +282,6 @@ func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID return s.updatePool(pool, param) } -func (s *sqlDatabase) getPoolByID(_ context.Context, poolID string, preload ...string) (Pool, error) { - u, err := uuid.Parse(poolID) - if err != nil { - return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") - } - var pool Pool - q := s.conn.Model(&Pool{}) - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } - } - - q = q.Where("id = ?", u).First(&pool) - - if q.Error != nil { - if errors.Is(q.Error, gorm.ErrRecordNotFound) { - return Pool{}, runnerErrors.ErrNotFound - } - return Pool{}, errors.Wrap(q.Error, "fetching org from database") - } - return pool, nil -} - func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { u, err := uuid.Parse(id) if err != nil { diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index db4f8ccd..126c54ab 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -740,29 +740,6 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *OrgTestSuite) TestFindOrganizationPoolByTags() { - orgPool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) - } - - pool, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(orgPool.ID, pool.ID) - s.Require().Equal(orgPool.Image, pool.Image) - s.Require().Equal(orgPool.Flavor, pool.Flavor) -} - -func (s *OrgTestSuite) TestFindOrganizationPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *OrgTestSuite) TestListOrgInstances() { pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/pools.go b/database/sql/pools.go index 65aca8ba..5f7d0d7a 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -20,6 +20,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" + "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -58,7 +59,7 @@ func (s *sqlDatabase) ListAllPools(_ context.Context) ([]params.Pool, error) { } func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { - pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool by ID") } @@ -66,7 +67,7 @@ func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Po } func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") } @@ -231,3 +232,101 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par return pools, nil } + +func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { + if len(param.Tags) == 0 { + return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") + } + + newPool := Pool{ + ProviderName: param.ProviderName, + MaxRunners: param.MaxRunners, + MinIdleRunners: param.MinIdleRunners, + RunnerPrefix: param.GetRunnerPrefix(), + Image: param.Image, + Flavor: param.Flavor, + OSType: param.OSType, + OSArch: param.OSArch, + Enabled: param.Enabled, + RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, + GitHubRunnerGroup: param.GitHubRunnerGroup, + Priority: param.Priority, + } + if len(param.ExtraSpecs) > 0 { + newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) + } + + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return params.Pool{}, fmt.Errorf("parsing entity ID: %w", err) + } + + switch entity.EntityType { + case params.GithubEntityTypeRepository: + newPool.RepoID = &entityID + case params.GithubEntityTypeOrganization: + newPool.OrgID = &entityID + case params.GithubEntityTypeEnterprise: + newPool.EnterpriseID = &entityID + } + err = s.conn.Transaction(func(tx *gorm.DB) error { + if _, err := s.getEntityPoolByUniqueFields(tx, entity, newPool.ProviderName, newPool.Image, newPool.Flavor); err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return fmt.Errorf("checking for existing pool: %w", err) + } + } else { + return runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") + } + + tags := []Tag{} + for _, val := range param.Tags { + t, err := s.getOrCreateTag(tx, val) + if err != nil { + return fmt.Errorf("creating tag: %w", err) + } + tags = append(tags, t) + } + + q := tx.Create(&newPool) + if q.Error != nil { + return fmt.Errorf("creating pool: %w", q.Error) + } + + for i := range tags { + if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { + return fmt.Errorf("associating tags: %w", err) + } + } + return nil + }) + if err != nil { + return params.Pool{}, fmt.Errorf("creating pool: %w", err) + } + + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + if err != nil { + return params.Pool{}, errors.Wrap(err, "fetching pool") + } + + return s.sqlToCommonPool(pool) +} + +func (s *sqlDatabase) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + return params.Pool{}, nil +} + +func (s *sqlDatabase) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { + return nil +} + +func (s *sqlDatabase) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + return params.Pool{}, nil +} + +func (s *sqlDatabase) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) { + return nil, nil +} + +func (s *sqlDatabase) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) { + return nil, nil +} diff --git a/database/sql/repositories.go b/database/sql/repositories.go index f7671840..0da0e794 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID s return nil } -func (s *sqlDatabase) FindRepositoryPoolByTags(_ context.Context, repoID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(repoID, params.GithubEntityTypeRepository, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Instances.Job") if err != nil { @@ -308,6 +300,31 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository return repo, nil } +func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.GithubEntity, provider, image, flavor string) (pool Pool, err error) { + var entityField string + switch entity.EntityType { + case params.GithubEntityTypeRepository: + entityField = entityTypeRepoName + case params.GithubEntityTypeOrganization: + entityField = entityTypeOrgName + case params.GithubEntityTypeEnterprise: + entityField = entityTypeEnterpriseName + } + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return pool, fmt.Errorf("parsing entity ID: %w", err) + } + poolQueryString := fmt.Sprintf("provider_name = ? and image = ? and flavor = ? and %s = ?", entityField) + err = tx.Where(poolQueryString, provider, image, flavor, entityID).First(&pool).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return pool, runnerErrors.ErrNotFound + } + return + } + return Pool{}, nil +} + func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) { repo, err := s.getRepoByID(ctx, repoID) if err != nil { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 796048ea..5a5396b8 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -777,28 +777,6 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *RepoTestSuite) TestFindRepositoryPoolByTags() { - repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) - } - - pool, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.Tags) - s.Require().Nil(err) - s.Require().Equal(repoPool.ID, pool.ID) - s.Require().Equal(repoPool.Image, pool.Image) - s.Require().Equal(repoPool.Flavor, pool.Flavor) -} - -func (s *RepoTestSuite) TestFindRepositoryPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *RepoTestSuite) TestListRepoInstances() { pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/util.go b/database/sql/util.go index 2dd810f5..f0bcc867 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -18,10 +18,12 @@ import ( "encoding/json" "fmt" + "github.com/google/uuid" "github.com/pkg/errors" "gorm.io/datatypes" "gorm.io/gorm" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/params" @@ -275,9 +277,9 @@ func (s *sqlDatabase) sqlToParamsUser(user User) params.User { } } -func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { +func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) { var tag Tag - q := s.conn.Where("name = ?", tagName).First(&tag) + q := tx.Where("name = ?", tagName).First(&tag) if q.Error == nil { return tag, nil } @@ -288,7 +290,7 @@ func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { Name: tagName, } - q = s.conn.Create(&newTag) + q = tx.Create(&newTag) if q.Error != nil { return Tag{}, errors.Wrap(q.Error, "creating tag") } @@ -351,7 +353,7 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para tags := []Tag{} if param.Tags != nil && len(param.Tags) > 0 { for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -365,3 +367,27 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para return s.sqlToCommonPool(pool) } + +func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) (Pool, error) { + u, err := uuid.Parse(poolID) + if err != nil { + return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + var pool Pool + q := tx.Model(&Pool{}) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + + q = q.Where("id = ?", u).First(&pool) + + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return Pool{}, runnerErrors.ErrNotFound + } + return Pool{}, errors.Wrap(q.Error, "fetching org from database") + } + return pool, nil +} diff --git a/runner/enterprises.go b/runner/enterprises.go index c76d3973..ae1e2fc8 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -193,18 +193,11 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID) + _, err := r.store.GetEnterpriseByID(ctx, enterpriseID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching enterprise") } - if _, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -214,7 +207,12 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateEnterprisePool(ctx, enterpriseID, createPoolParams) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } diff --git a/runner/organizations.go b/runner/organizations.go index 3d24dcda..482bd55d 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -222,18 +222,11 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - org, err := r.store.GetOrganizationByID(ctx, orgID) + _, err := r.store.GetOrganizationByID(ctx, orgID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching org") } - if _, err := r.poolManagerCtrl.GetOrgPoolManager(org); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -243,7 +236,12 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateOrganizationPool(ctx, orgID, createPoolParams) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 3fe1eb3e..2226fa13 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -1725,10 +1725,6 @@ func (r *basePoolManager) WebhookSecret() string { return r.entity.WebhookSecret } -func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { - return r.GetGithubRegistrationToken() -} - func (r *basePoolManager) ID() string { return r.entity.ID } @@ -2095,7 +2091,7 @@ func (r *basePoolManager) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (par return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } -func (r *basePoolManager) GetGithubRegistrationToken() (string, error) { +func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { tk, ghResp, err := r.ghcli.CreateEntityRegistrationToken(r.ctx) if err != nil { if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { diff --git a/runner/repositories.go b/runner/repositories.go index c71fab39..b8b25c06 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -221,18 +221,11 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - repo, err := r.store.GetRepositoryByID(ctx, repoID) + _, err := r.store.GetRepositoryByID(ctx, repoID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching repo") } - if _, err := r.poolManagerCtrl.GetRepoPoolManager(repo); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -242,7 +235,12 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params createPoolParams.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateRepositoryPool(ctx, repoID, createPoolParams) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") }