From 0152b215294320e763eeb368604fc00c727da109 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 28 Mar 2024 10:08:19 +0000 Subject: [PATCH 1/3] Implement some common logic for pool creation Signed-off-by: Gabriel Adrian Samfira --- database/common/common.go | 15 +++-- database/common/mocks/Store.go | 84 ----------------------- database/sql/enterprise.go | 12 +--- database/sql/enterprise_test.go | 23 ------- database/sql/instances.go | 12 ++-- database/sql/organizations.go | 36 +--------- database/sql/organizations_test.go | 23 ------- database/sql/pools.go | 103 ++++++++++++++++++++++++++++- database/sql/repositories.go | 37 ++++++++--- database/sql/repositories_test.go | 22 ------ database/sql/util.go | 34 ++++++++-- runner/enterprises.go | 16 ++--- runner/organizations.go | 16 ++--- runner/pool/pool.go | 6 +- runner/repositories.go | 16 ++--- 15 files changed, 201 insertions(+), 254 deletions(-) diff --git a/database/common/common.go b/database/common/common.go index 8ca57ac2..1bd4c437 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -29,11 +29,9 @@ type RepoStore interface { UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) - GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) @@ -52,11 +50,20 @@ type OrgStore interface { DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) } +type EntityPools interface { + CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) + GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) + DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error + UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) + + ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) + ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) +} + type EnterpriseStore interface { CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) @@ -70,7 +77,6 @@ type EnterpriseStore interface { DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) } @@ -139,6 +145,7 @@ type Store interface { UserStore InstanceStore JobsStore + EntityPools ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 81e47799..e6f6d815 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -510,62 +510,6 @@ func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID return r0 } -// FindEnterprisePoolByTags provides a mock function with given fields: ctx, enterpriseID, tags -func (_m *Store) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindEnterprisePoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, enterpriseID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, enterpriseID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// FindOrganizationPoolByTags provides a mock function with given fields: ctx, orgID, tags -func (_m *Store) FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, orgID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindOrganizationPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, orgID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, orgID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, orgID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // FindPoolsMatchingAllTags provides a mock function with given fields: ctx, entityType, entityID, tags func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params.GithubEntityType, entityID string, tags []string) ([]params.Pool, error) { ret := _m.Called(ctx, entityType, entityID, tags) @@ -596,34 +540,6 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params return r0, r1 } -// FindRepositoryPoolByTags provides a mock function with given fields: ctx, repoID, tags -func (_m *Store) FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, repoID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindRepositoryPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, repoID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, repoID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, repoID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetEnterprise provides a mock function with given fields: ctx, name func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { ret := _m.Called(ctx, name) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index f83dab8c..f8665a45 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -175,7 +175,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -193,7 +193,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -230,14 +230,6 @@ func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, po return s.updatePool(pool, param) } -func (s *sqlDatabase) FindEnterprisePoolByTags(_ context.Context, enterpriseID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(enterpriseID, params.GithubEntityTypeEnterprise, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Tags", "Instances", "Enterprise") if err != nil { diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index fa709b89..86b68872 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -740,29 +740,6 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTags() { - enterprisePool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) - } - - pool, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(enterprisePool.ID, pool.ID) - s.Require().Equal(enterprisePool.Image, pool.Image) - s.Require().Equal(enterprisePool.Flavor, pool.Flavor) -} - -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/instances.go b/database/sql/instances.go index 4c475bf2..552fc39d 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -49,7 +49,7 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error } func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } @@ -108,8 +108,8 @@ func (s *sqlDatabase) getInstanceByID(_ context.Context, instanceID string) (Ins return instance, nil } -func (s *sqlDatabase) getPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return Instance{}, errors.Wrap(err, "fetching pool") } @@ -153,7 +153,7 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, } func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } @@ -171,7 +171,7 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string } func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") } @@ -338,7 +338,7 @@ func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, er } func (s *sqlDatabase) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return 0, errors.Wrap(err, "fetching pool") } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 4d246065..5ee28520 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID return nil } -func (s *sqlDatabase) FindOrganizationPoolByTags(_ context.Context, orgID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(orgID, params.GithubEntityTypeOrganization, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Instances.Job") if err != nil { @@ -290,30 +282,6 @@ func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID return s.updatePool(pool, param) } -func (s *sqlDatabase) getPoolByID(_ context.Context, poolID string, preload ...string) (Pool, error) { - u, err := uuid.Parse(poolID) - if err != nil { - return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") - } - var pool Pool - q := s.conn.Model(&Pool{}) - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } - } - - q = q.Where("id = ?", u).First(&pool) - - if q.Error != nil { - if errors.Is(q.Error, gorm.ErrRecordNotFound) { - return Pool{}, runnerErrors.ErrNotFound - } - return Pool{}, errors.Wrap(q.Error, "fetching org from database") - } - return pool, nil -} - func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { u, err := uuid.Parse(id) if err != nil { diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index db4f8ccd..126c54ab 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -740,29 +740,6 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *OrgTestSuite) TestFindOrganizationPoolByTags() { - orgPool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) - } - - pool, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(orgPool.ID, pool.ID) - s.Require().Equal(orgPool.Image, pool.Image) - s.Require().Equal(orgPool.Flavor, pool.Flavor) -} - -func (s *OrgTestSuite) TestFindOrganizationPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *OrgTestSuite) TestListOrgInstances() { pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/pools.go b/database/sql/pools.go index 65aca8ba..5f7d0d7a 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -20,6 +20,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" + "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -58,7 +59,7 @@ func (s *sqlDatabase) ListAllPools(_ context.Context) ([]params.Pool, error) { } func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { - pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool by ID") } @@ -66,7 +67,7 @@ func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Po } func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") } @@ -231,3 +232,101 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par return pools, nil } + +func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { + if len(param.Tags) == 0 { + return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") + } + + newPool := Pool{ + ProviderName: param.ProviderName, + MaxRunners: param.MaxRunners, + MinIdleRunners: param.MinIdleRunners, + RunnerPrefix: param.GetRunnerPrefix(), + Image: param.Image, + Flavor: param.Flavor, + OSType: param.OSType, + OSArch: param.OSArch, + Enabled: param.Enabled, + RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, + GitHubRunnerGroup: param.GitHubRunnerGroup, + Priority: param.Priority, + } + if len(param.ExtraSpecs) > 0 { + newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) + } + + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return params.Pool{}, fmt.Errorf("parsing entity ID: %w", err) + } + + switch entity.EntityType { + case params.GithubEntityTypeRepository: + newPool.RepoID = &entityID + case params.GithubEntityTypeOrganization: + newPool.OrgID = &entityID + case params.GithubEntityTypeEnterprise: + newPool.EnterpriseID = &entityID + } + err = s.conn.Transaction(func(tx *gorm.DB) error { + if _, err := s.getEntityPoolByUniqueFields(tx, entity, newPool.ProviderName, newPool.Image, newPool.Flavor); err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return fmt.Errorf("checking for existing pool: %w", err) + } + } else { + return runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") + } + + tags := []Tag{} + for _, val := range param.Tags { + t, err := s.getOrCreateTag(tx, val) + if err != nil { + return fmt.Errorf("creating tag: %w", err) + } + tags = append(tags, t) + } + + q := tx.Create(&newPool) + if q.Error != nil { + return fmt.Errorf("creating pool: %w", q.Error) + } + + for i := range tags { + if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { + return fmt.Errorf("associating tags: %w", err) + } + } + return nil + }) + if err != nil { + return params.Pool{}, fmt.Errorf("creating pool: %w", err) + } + + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + if err != nil { + return params.Pool{}, errors.Wrap(err, "fetching pool") + } + + return s.sqlToCommonPool(pool) +} + +func (s *sqlDatabase) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + return params.Pool{}, nil +} + +func (s *sqlDatabase) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { + return nil +} + +func (s *sqlDatabase) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + return params.Pool{}, nil +} + +func (s *sqlDatabase) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) { + return nil, nil +} + +func (s *sqlDatabase) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) { + return nil, nil +} diff --git a/database/sql/repositories.go b/database/sql/repositories.go index f7671840..0da0e794 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p tags := []Tag{} for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p } } - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID s return nil } -func (s *sqlDatabase) FindRepositoryPoolByTags(_ context.Context, repoID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(repoID, params.GithubEntityTypeRepository, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Instances.Job") if err != nil { @@ -308,6 +300,31 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository return repo, nil } +func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.GithubEntity, provider, image, flavor string) (pool Pool, err error) { + var entityField string + switch entity.EntityType { + case params.GithubEntityTypeRepository: + entityField = entityTypeRepoName + case params.GithubEntityTypeOrganization: + entityField = entityTypeOrgName + case params.GithubEntityTypeEnterprise: + entityField = entityTypeEnterpriseName + } + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return pool, fmt.Errorf("parsing entity ID: %w", err) + } + poolQueryString := fmt.Sprintf("provider_name = ? and image = ? and flavor = ? and %s = ?", entityField) + err = tx.Where(poolQueryString, provider, image, flavor, entityID).First(&pool).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return pool, runnerErrors.ErrNotFound + } + return + } + return Pool{}, nil +} + func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) { repo, err := s.getRepoByID(ctx, repoID) if err != nil { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 796048ea..5a5396b8 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -777,28 +777,6 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) } -func (s *RepoTestSuite) TestFindRepositoryPoolByTags() { - repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) - } - - pool, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.Tags) - s.Require().Nil(err) - s.Require().Equal(repoPool.ID, pool.ID) - s.Require().Equal(repoPool.Image, pool.Image) - s.Require().Equal(repoPool.Flavor, pool.Flavor) -} - -func (s *RepoTestSuite) TestFindRepositoryPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) -} - func (s *RepoTestSuite) TestListRepoInstances() { pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) if err != nil { diff --git a/database/sql/util.go b/database/sql/util.go index 2dd810f5..f0bcc867 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -18,10 +18,12 @@ import ( "encoding/json" "fmt" + "github.com/google/uuid" "github.com/pkg/errors" "gorm.io/datatypes" "gorm.io/gorm" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/params" @@ -275,9 +277,9 @@ func (s *sqlDatabase) sqlToParamsUser(user User) params.User { } } -func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { +func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) { var tag Tag - q := s.conn.Where("name = ?", tagName).First(&tag) + q := tx.Where("name = ?", tagName).First(&tag) if q.Error == nil { return tag, nil } @@ -288,7 +290,7 @@ func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { Name: tagName, } - q = s.conn.Create(&newTag) + q = tx.Create(&newTag) if q.Error != nil { return Tag{}, errors.Wrap(q.Error, "creating tag") } @@ -351,7 +353,7 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para tags := []Tag{} if param.Tags != nil && len(param.Tags) > 0 { for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(s.conn, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } @@ -365,3 +367,27 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para return s.sqlToCommonPool(pool) } + +func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) (Pool, error) { + u, err := uuid.Parse(poolID) + if err != nil { + return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + var pool Pool + q := tx.Model(&Pool{}) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + + q = q.Where("id = ?", u).First(&pool) + + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return Pool{}, runnerErrors.ErrNotFound + } + return Pool{}, errors.Wrap(q.Error, "fetching org from database") + } + return pool, nil +} diff --git a/runner/enterprises.go b/runner/enterprises.go index c76d3973..ae1e2fc8 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -193,18 +193,11 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID) + _, err := r.store.GetEnterpriseByID(ctx, enterpriseID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching enterprise") } - if _, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -214,7 +207,12 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateEnterprisePool(ctx, enterpriseID, createPoolParams) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } diff --git a/runner/organizations.go b/runner/organizations.go index 3d24dcda..482bd55d 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -222,18 +222,11 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - org, err := r.store.GetOrganizationByID(ctx, orgID) + _, err := r.store.GetOrganizationByID(ctx, orgID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching org") } - if _, err := r.poolManagerCtrl.GetOrgPoolManager(org); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -243,7 +236,12 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateOrganizationPool(ctx, orgID, createPoolParams) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 3fe1eb3e..2226fa13 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -1725,10 +1725,6 @@ func (r *basePoolManager) WebhookSecret() string { return r.entity.WebhookSecret } -func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { - return r.GetGithubRegistrationToken() -} - func (r *basePoolManager) ID() string { return r.entity.ID } @@ -2095,7 +2091,7 @@ func (r *basePoolManager) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (par return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } -func (r *basePoolManager) GetGithubRegistrationToken() (string, error) { +func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { tk, ghResp, err := r.ghcli.CreateEntityRegistrationToken(r.ctx) if err != nil { if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { diff --git a/runner/repositories.go b/runner/repositories.go index c71fab39..b8b25c06 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -221,18 +221,11 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - repo, err := r.store.GetRepositoryByID(ctx, repoID) + _, err := r.store.GetRepositoryByID(ctx, repoID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching repo") } - if _, err := r.poolManagerCtrl.GetRepoPoolManager(repo); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -242,7 +235,12 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params createPoolParams.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateRepositoryPool(ctx, repoID, createPoolParams) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } From 9384e37bb1eb1fcc2adf94c82cd26724e060308c Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 28 Mar 2024 18:23:49 +0000 Subject: [PATCH 2/3] Fix tests Signed-off-by: Gabriel Adrian Samfira --- database/common/common.go | 38 ++--- database/common/mocks/Store.go | 258 ++++++++++++++--------------- database/sql/enterprise.go | 14 +- database/sql/enterprise_test.go | 77 +++++++-- database/sql/instances.go | 8 +- database/sql/instances_test.go | 6 +- database/sql/organizations.go | 14 +- database/sql/organizations_test.go | 77 +++++++-- database/sql/pools.go | 70 ++++++-- database/sql/pools_test.go | 8 +- database/sql/repositories.go | 14 +- database/sql/repositories_test.go | 77 +++++++-- database/sql/util.go | 35 +++- runner/enterprises.go | 42 ++--- runner/enterprises_test.go | 67 +++++--- runner/organizations.go | 41 +++-- runner/organizations_test.go | 65 +++++--- runner/pool/pool.go | 11 +- runner/pools_test.go | 8 +- runner/repositories.go | 40 +++-- runner/repositories_test.go | 73 +++++--- runner/runner.go | 5 +- 22 files changed, 652 insertions(+), 396 deletions(-) diff --git a/database/common/common.go b/database/common/common.go index 1bd4c437..ab546844 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -28,9 +28,9 @@ type RepoStore interface { DeleteRepository(ctx context.Context, repoID string) error UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) - CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) - GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) - DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error + // CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) + // GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) + // DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) @@ -45,25 +45,15 @@ type OrgStore interface { DeleteOrganization(ctx context.Context, orgID string) error UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) - CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) - GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) - DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error + // CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) + // GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) + // DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) } -type EntityPools interface { - CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) - GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) - DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error - UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) - ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) -} - type EnterpriseStore interface { CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) @@ -72,9 +62,9 @@ type EnterpriseStore interface { DeleteEnterprise(ctx context.Context, enterpriseID string) error UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) - CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) - GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) - DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error + // CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) + // GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) + // DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) @@ -136,6 +126,16 @@ type JobsStore interface { DeleteCompletedJobs(ctx context.Context) error } +type EntityPools interface { + CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) + GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) + DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error + UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) + + ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) + ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) +} + //go:generate mockery --name=Store type Store interface { RepoStore diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index e6f6d815..73eef2c3 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -106,27 +106,27 @@ func (_m *Store) CreateEnterprise(ctx context.Context, name string, credentialsN return r0, r1 } -// CreateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, param -func (_m *Store) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, param) +// CreateEntityPool provides a mock function with given fields: ctx, entity, param +func (_m *Store) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { + ret := _m.Called(ctx, entity, param) if len(ret) == 0 { - panic("no return value specified for CreateEnterprisePool") + panic("no return value specified for CreateEntityPool") } var r0 params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, params.CreatePoolParams) (params.Pool, error)); ok { + return rf(ctx, entity, param) } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, params.CreatePoolParams) params.Pool); ok { + r0 = rf(ctx, entity, param) } else { r0 = ret.Get(0).(params.Pool) } - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, params.CreatePoolParams) error); ok { + r1 = rf(ctx, entity, param) } else { r1 = ret.Error(1) } @@ -218,34 +218,6 @@ func (_m *Store) CreateOrganization(ctx context.Context, name string, credential return r0, r1 } -// CreateOrganizationPool provides a mock function with given fields: ctx, orgID, param -func (_m *Store) CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, orgID, param) - - if len(ret) == 0 { - panic("no return value specified for CreateOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, orgID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, orgID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, orgID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // CreateRepository provides a mock function with given fields: ctx, owner, name, credentialsName, webhookSecret, poolBalancerType func (_m *Store) CreateRepository(ctx context.Context, owner string, name string, credentialsName string, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { ret := _m.Called(ctx, owner, name, credentialsName, webhookSecret, poolBalancerType) @@ -274,34 +246,6 @@ func (_m *Store) CreateRepository(ctx context.Context, owner string, name string return r0, r1 } -// CreateRepositoryPool provides a mock function with given fields: ctx, repoID, param -func (_m *Store) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, repoID, param) - - if len(ret) == 0 { - panic("no return value specified for CreateRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, repoID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, repoID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, repoID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // CreateUser provides a mock function with given fields: ctx, user func (_m *Store) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) { ret := _m.Called(ctx, user) @@ -384,6 +328,24 @@ func (_m *Store) DeleteEnterprisePool(ctx context.Context, enterpriseID string, return r0 } +// DeleteEntityPool provides a mock function with given fields: ctx, entity, poolID +func (_m *Store) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { + ret := _m.Called(ctx, entity, poolID) + + if len(ret) == 0 { + panic("no return value specified for DeleteEntityPool") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) error); ok { + r0 = rf(ctx, entity, poolID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteInstance provides a mock function with given fields: ctx, poolID, instanceName func (_m *Store) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { ret := _m.Called(ctx, poolID, instanceName) @@ -596,27 +558,27 @@ func (_m *Store) GetEnterpriseByID(ctx context.Context, enterpriseID string) (pa return r0, r1 } -// GetEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID -func (_m *Store) GetEnterprisePool(ctx context.Context, enterpriseID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, poolID) +// GetEntityPool provides a mock function with given fields: ctx, entity, poolID +func (_m *Store) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + ret := _m.Called(ctx, entity, poolID) if len(ret) == 0 { - panic("no return value specified for GetEnterprisePool") + panic("no return value specified for GetEntityPool") } var r0 params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) (params.Pool, error)); ok { + return rf(ctx, entity, poolID) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) params.Pool); ok { + r0 = rf(ctx, entity, poolID) } else { r0 = ret.Get(0).(params.Pool) } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, string) error); ok { + r1 = rf(ctx, entity, poolID) } else { r1 = ret.Error(1) } @@ -736,34 +698,6 @@ func (_m *Store) GetOrganizationByID(ctx context.Context, orgID string) (params. return r0, r1 } -// GetOrganizationPool provides a mock function with given fields: ctx, orgID, poolID -func (_m *Store) GetOrganizationPool(ctx context.Context, orgID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, orgID, poolID) - - if len(ret) == 0 { - panic("no return value specified for GetOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, orgID, poolID) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, orgID, poolID) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, orgID, poolID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetPoolByID provides a mock function with given fields: ctx, poolID func (_m *Store) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { ret := _m.Called(ctx, poolID) @@ -876,34 +810,6 @@ func (_m *Store) GetRepositoryByID(ctx context.Context, repoID string) (params.R return r0, r1 } -// GetRepositoryPool provides a mock function with given fields: ctx, repoID, poolID -func (_m *Store) GetRepositoryPool(ctx context.Context, repoID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, repoID, poolID) - - if len(ret) == 0 { - panic("no return value specified for GetRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, repoID, poolID) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, repoID, poolID) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, repoID, poolID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetUser provides a mock function with given fields: ctx, user func (_m *Store) GetUser(ctx context.Context, user string) (params.User, error) { ret := _m.Called(ctx, user) @@ -1186,6 +1092,36 @@ func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, erro return r0, r1 } +// ListEntityInstances provides a mock function with given fields: ctx, entity +func (_m *Store) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) { + ret := _m.Called(ctx, entity) + + if len(ret) == 0 { + panic("no return value specified for ListEntityInstances") + } + + var r0 []params.Instance + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) ([]params.Instance, error)); ok { + return rf(ctx, entity) + } + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) []params.Instance); ok { + r0 = rf(ctx, entity) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.Instance) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity) error); ok { + r1 = rf(ctx, entity) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListEntityJobsByStatus provides a mock function with given fields: ctx, entityType, entityID, status func (_m *Store) ListEntityJobsByStatus(ctx context.Context, entityType params.GithubEntityType, entityID string, status params.JobStatus) ([]params.Job, error) { ret := _m.Called(ctx, entityType, entityID, status) @@ -1216,6 +1152,36 @@ func (_m *Store) ListEntityJobsByStatus(ctx context.Context, entityType params.G return r0, r1 } +// ListEntityPools provides a mock function with given fields: ctx, entity +func (_m *Store) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) { + ret := _m.Called(ctx, entity) + + if len(ret) == 0 { + panic("no return value specified for ListEntityPools") + } + + var r0 []params.Pool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) ([]params.Pool, error)); ok { + return rf(ctx, entity) + } + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) []params.Pool); ok { + r0 = rf(ctx, entity) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.Pool) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity) error); ok { + r1 = rf(ctx, entity) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListInstanceEvents provides a mock function with given fields: ctx, instanceID, eventType, eventLevel func (_m *Store) ListInstanceEvents(ctx context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error) { ret := _m.Called(ctx, instanceID, eventType, eventLevel) @@ -1606,6 +1572,34 @@ func (_m *Store) UpdateEnterprisePool(ctx context.Context, enterpriseID string, return r0, r1 } +// UpdateEntityPool provides a mock function with given fields: ctx, entity, poolID, param +func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + ret := _m.Called(ctx, entity, poolID, param) + + if len(ret) == 0 { + panic("no return value specified for UpdateEntityPool") + } + + var r0 params.Pool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) (params.Pool, error)); ok { + return rf(ctx, entity, poolID, param) + } + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) params.Pool); ok { + r0 = rf(ctx, entity, poolID, param) + } else { + r0 = ret.Get(0).(params.Pool) + } + + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) error); ok { + r1 = rf(ctx, entity, poolID, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // UpdateInstance provides a mock function with given fields: ctx, instanceID, param func (_m *Store) UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error) { ret := _m.Called(ctx, instanceID, param) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index f8665a45..c3bf5d69 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -201,16 +201,16 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances") +func (s *sqlDatabase) GetEnterprisePool(_ context.Context, enterpriseID, poolID string) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID) +func (s *sqlDatabase) DeleteEnterprisePool(_ context.Context, enterpriseID, poolID string) error { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID) if err != nil { return errors.Wrap(err, "looking up enterprise pool") } @@ -221,13 +221,13 @@ func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, po return nil } -func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") +func (s *sqlDatabase) UpdateEnterprisePool(_ context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } - return s.updatePool(pool, param) + return s.updatePool(s.conn, pool, param) } func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index 86b68872..4f4a7da2 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -405,7 +405,11 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() { } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -422,18 +426,25 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.CreateEnterprisePool(context.Background(), "dummy-enterprise-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBCreateErr() { @@ -655,9 +666,13 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { func (s *EnterpriseTestSuite) TestListEnterprisePools() { enterprisePools := []params.Pool{} + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -678,46 +693,66 @@ func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() { } func (s *EnterpriseTestSuite) TestGetEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - enterprisePool, err := s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + enterprisePool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(enterprisePool.ID, pool.ID) } func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.GetEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - err = s.Store.DeleteEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { - err := s.Store.DeleteEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up enterprise pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -741,7 +776,11 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { } func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -769,7 +808,11 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Enterprises[0].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } diff --git a/database/sql/instances.go b/database/sql/instances.go index 552fc39d..d24961da 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -48,7 +48,7 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error return nil } -func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { +func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") @@ -152,7 +152,7 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, return instance, nil } -func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) { +func (s *sqlDatabase) GetPoolInstanceByName(_ context.Context, poolID string, instanceName string) (params.Instance, error) { instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") @@ -170,7 +170,7 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string return s.sqlToParamsInstance(instance) } -func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { +func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) error { instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") @@ -337,7 +337,7 @@ func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, er return ret, nil } -func (s *sqlDatabase) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) { +func (s *sqlDatabase) PoolInstanceCount(_ context.Context, poolID string) (int64, error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return 0, errors.Wrap(err, "fetching pool") diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 0c0eadcf..200c683b 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -92,7 +92,11 @@ func (s *InstancesTestSuite) SetupTest() { OSType: "linux", Tags: []string{"self-hosted", "amd64", "linux"}, } - pool, err := s.Store.CreateOrganizationPool(context.Background(), org.ID, createPoolParams) + entity := params.GithubEntity{ + ID: org.ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, createPoolParams) if err != nil { s.FailNow(fmt.Sprintf("failed to create org pool: %s", err)) } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 5ee28520..a2b14ae5 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -235,16 +235,16 @@ func (s *sqlDatabase) ListOrgPools(ctx context.Context, orgID string) ([]params. return ret, nil } -func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances") +func (s *sqlDatabase) GetOrganizationPool(_ context.Context, orgID, poolID string) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID) +func (s *sqlDatabase) DeleteOrganizationPool(_ context.Context, orgID, poolID string) error { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID) if err != nil { return errors.Wrap(err, "looking up org pool") } @@ -273,13 +273,13 @@ func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]par return ret, nil } -func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") +func (s *sqlDatabase) UpdateOrganizationPool(_ context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } - return s.updatePool(pool, param) + return s.updatePool(s.conn, pool, param) } func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 126c54ab..63b4cebd 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -405,7 +405,11 @@ func (s *OrgTestSuite) TestGetOrganizationByIDDBDecryptingErr() { } func (s *OrgTestSuite) TestCreateOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -422,18 +426,25 @@ func (s *OrgTestSuite) TestCreateOrganizationPool() { func (s *OrgTestSuite) TestCreateOrganizationPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationPoolInvalidOrgID() { - _, err := s.Store.CreateOrganizationPool(context.Background(), "dummy-org-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationPoolDBCreateErr() { @@ -655,9 +666,13 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { func (s *OrgTestSuite) TestListOrgPools() { orgPools := []params.Pool{} + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -678,46 +693,66 @@ func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { } func (s *OrgTestSuite) TestGetOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - orgPool, err := s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + orgPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(orgPool.ID, pool.ID) } func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() { - _, err := s.Store.GetOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - err = s.Store.DeleteOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() { - err := s.Store.DeleteOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up org pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -741,7 +776,11 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { } func (s *OrgTestSuite) TestListOrgInstances() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -769,7 +808,11 @@ func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() { } func (s *OrgTestSuite) TestUpdateOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Orgs[0].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } diff --git a/database/sql/pools.go b/database/sql/pools.go index 5f7d0d7a..ab892eb2 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -58,7 +58,7 @@ func (s *sqlDatabase) ListAllPools(_ context.Context) ([]params.Pool, error) { return ret, nil } -func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { +func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool, error) { pool, err := s.getPoolByID(s.conn, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool by ID") @@ -66,7 +66,7 @@ func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Po return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { +func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) error { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") @@ -79,7 +79,7 @@ func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { return nil } -func (s *sqlDatabase) getEntityPool(_ context.Context, entityType params.GithubEntityType, entityID, poolID string, preload ...string) (Pool, error) { +func (s *sqlDatabase) getEntityPool(tx *gorm.DB, entityType params.GithubEntityType, entityID, poolID string, preload ...string) (Pool, error) { if entityID == "" { return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "missing entity id") } @@ -89,7 +89,7 @@ func (s *sqlDatabase) getEntityPool(_ context.Context, entityType params.GithubE return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } - q := s.conn + q := tx if len(preload) > 0 { for _, item := range preload { q = q.Preload(item) @@ -233,7 +233,7 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par return pools, nil } -func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { +func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { if len(param.Tags) == 0 { return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") } @@ -258,7 +258,7 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.Github entityID, err := uuid.Parse(entity.ID) if err != nil { - return params.Pool{}, fmt.Errorf("parsing entity ID: %w", err) + return params.Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } switch entity.EntityType { @@ -270,9 +270,14 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.Github newPool.EnterpriseID = &entityID } err = s.conn.Transaction(func(tx *gorm.DB) error { + ok, err := s.hasGithubEntity(tx, entity.EntityType, entity.ID) + if err != nil || !ok { + return errors.Wrap(err, "checking entity existence") + } + if _, err := s.getEntityPoolByUniqueFields(tx, entity, newPool.ProviderName, newPool.Image, newPool.Flavor); err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { - return fmt.Errorf("checking for existing pool: %w", err) + return errors.Wrap(err, "checking pool existence") } } else { return runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") @@ -282,25 +287,25 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.Github for _, val := range param.Tags { t, err := s.getOrCreateTag(tx, val) if err != nil { - return fmt.Errorf("creating tag: %w", err) + return errors.Wrap(err, "creating tag") } tags = append(tags, t) } q := tx.Create(&newPool) if q.Error != nil { - return fmt.Errorf("creating pool: %w", q.Error) + return errors.Wrap(q.Error, "creating pool") } for i := range tags { if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return fmt.Errorf("associating tags: %w", err) + return errors.Wrap(err, "associating tags") } } return nil }) if err != nil { - return params.Pool{}, fmt.Errorf("creating pool: %w", err) + return params.Pool{}, errors.Wrap(err, "creating pool") } pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") @@ -311,22 +316,53 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.Github return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { - return params.Pool{}, nil +func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, entity.EntityType, entity.ID, poolID, "Tags", "Instances") + if err != nil { + return params.Pool{}, fmt.Errorf("fetching pool: %w", err) + } + return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { +func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) error { + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + + poolUUID, err := uuid.Parse(poolID) + if err != nil { + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing pool id") + } + var fieldName string + switch entity.EntityType { + case params.GithubEntityTypeRepository: + fieldName = entityTypeRepoName + case params.GithubEntityTypeOrganization: + fieldName = entityTypeOrgName + case params.GithubEntityTypeEnterprise: + fieldName = entityTypeEnterpriseName + default: + return fmt.Errorf("invalid entityType: %v", entity.EntityType) + } + condition := fmt.Sprintf("id = ? and %s = ?", fieldName) + if err := s.conn.Unscoped().Where(condition, poolUUID, entityID).Delete(&Pool{}).Error; err != nil { + return errors.Wrap(err, "removing pool") + } return nil } -func (s *sqlDatabase) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { +func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + fmt.Printf("UpdateEntityPool: %v %v %v\n", entity, poolID, param) return params.Pool{}, nil } -func (s *sqlDatabase) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) { +func (s *sqlDatabase) ListEntityPools(_ context.Context, entity params.GithubEntity) ([]params.Pool, error) { + fmt.Println(entity) return nil, nil } -func (s *sqlDatabase) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) { +func (s *sqlDatabase) ListEntityInstances(_ context.Context, entity params.GithubEntity) ([]params.Instance, error) { + fmt.Println(entity) return nil, nil } diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index 33fe8725..aac01f99 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -66,12 +66,16 @@ func (s *PoolsTestSuite) SetupTest() { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } + entity := params.GithubEntity{ + ID: org.ID, + EntityType: params.GithubEntityTypeOrganization, + } // create some pool objects in the database, for testing purposes orgPools := []params.Pool{} for i := 1; i <= 3; i++ { - pool, err := db.CreateOrganizationPool( + pool, err := db.CreateEntityPool( context.Background(), - org.ID, + entity, params.CreatePoolParams{ ProviderName: "test-provider", MaxRunners: 4, diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 0da0e794..936b796c 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -235,16 +235,16 @@ func (s *sqlDatabase) ListRepoPools(ctx context.Context, repoID string) ([]param return ret, nil } -func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances") +func (s *sqlDatabase) GetRepositoryPool(_ context.Context, repoID, poolID string) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID) +func (s *sqlDatabase) DeleteRepositoryPool(_ context.Context, repoID, poolID string) error { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID) if err != nil { return errors.Wrap(err, "looking up repo pool") } @@ -274,13 +274,13 @@ func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]p return ret, nil } -func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") +func (s *sqlDatabase) UpdateRepositoryPool(_ context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } - return s.updatePool(pool, param) + return s.updatePool(s.conn, pool, param) } func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository, error) { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 5a5396b8..ab1f9da5 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -443,7 +443,11 @@ func (s *RepoTestSuite) TestGetRepositoryByIDDBDecryptingErr() { } func (s *RepoTestSuite) TestCreateRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) @@ -459,18 +463,25 @@ func (s *RepoTestSuite) TestCreateRepositoryPool() { func (s *RepoTestSuite) TestCreateRepositoryPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *RepoTestSuite) TestCreateRepositoryPoolInvalidRepoID() { - _, err := s.Store.CreateRepositoryPool(context.Background(), "dummy-repo-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestCreateRepositoryPoolDBCreateErr() { @@ -691,10 +702,14 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { } func (s *RepoTestSuite) TestListRepoPools() { + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i) - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -715,46 +730,66 @@ func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { } func (s *RepoTestSuite) TestGetRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - repoPool, err := s.Store.GetRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + repoPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(repoPool.ID, pool.ID) } func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() { - _, err := s.Store.GetRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - err = s.Store.DeleteRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() { - err := s.Store.DeleteRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up repo pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -778,7 +813,11 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { } func (s *RepoTestSuite) TestListRepoInstances() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -806,7 +845,11 @@ func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() { } func (s *RepoTestSuite) TestUpdateRepositoryPool() { - repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.Repos[0].ID, + EntityType: params.GithubEntityTypeRepository, + } + repoPool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } diff --git a/database/sql/util.go b/database/sql/util.go index f0bcc867..2a41050c 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -297,7 +297,7 @@ func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) { return newTag, nil } -func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (params.Pool, error) { +func (s *sqlDatabase) updatePool(tx *gorm.DB, pool Pool, param params.UpdatePoolParams) (params.Pool, error) { if param.Enabled != nil && pool.Enabled != *param.Enabled { pool.Enabled = *param.Enabled } @@ -346,21 +346,21 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para pool.Priority = *param.Priority } - if q := s.conn.Save(&pool); q.Error != nil { + if q := tx.Save(&pool); q.Error != nil { return params.Pool{}, errors.Wrap(q.Error, "saving database entry") } tags := []Tag{} if param.Tags != nil && len(param.Tags) > 0 { for _, val := range param.Tags { - t, err := s.getOrCreateTag(s.conn, val) + t, err := s.getOrCreateTag(tx, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } tags = append(tags, t) } - if err := s.conn.Model(&pool).Association("Tags").Replace(&tags); err != nil { + if err := tx.Model(&pool).Association("Tags").Replace(&tags); err != nil { return params.Pool{}, errors.Wrap(err, "replacing tags") } } @@ -391,3 +391,30 @@ func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) } return pool, nil } + +func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntityType, entityID string) (bool, error) { + u, err := uuid.Parse(entityID) + if err != nil { + return false, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + var q *gorm.DB + switch entityType { + case params.GithubEntityTypeRepository: + q = tx.Model(&Repository{}).Where("id = ?", u) + case params.GithubEntityTypeOrganization: + q = tx.Model(&Organization{}).Where("id = ?", u) + case params.GithubEntityTypeEnterprise: + q = tx.Model(&Enterprise{}).Where("id = ?", u) + default: + return false, errors.Wrap(runnerErrors.ErrBadRequest, "invalid entity type") + } + + var entity interface{} + if err := q.First(entity).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, errors.Wrap(runnerErrors.ErrNotFound, "entity not found") + } + return false, errors.Wrap(err, "fetching entity from database") + } + return true, nil +} diff --git a/runner/enterprises.go b/runner/enterprises.go index ae1e2fc8..7d6e5b8e 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -193,14 +193,9 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, return params.Pool{}, runnerErrors.ErrUnauthorized } - _, err := r.store.GetEnterpriseByID(ctx, enterpriseID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching enterprise") - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool params") + return params.Pool{}, fmt.Errorf("failed to append tags to create pool params: %w", err) } if param.RunnerBootstrapTimeout == 0 { @@ -214,7 +209,7 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { - return params.Pool{}, errors.Wrap(err, "creating pool") + return params.Pool{}, fmt.Errorf("failed to create enterprise pool: %w", err) } return pool, nil @@ -224,8 +219,11 @@ func (r *Runner) GetEnterprisePoolByID(ctx context.Context, enterpriseID, poolID if !auth.IsAdmin(ctx) { return params.Pool{}, runnerErrors.ErrUnauthorized } - - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -237,29 +235,27 @@ func (r *Runner) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID return runnerErrors.ErrUnauthorized } - // nolint:golangci-lint,godox - // TODO: dedup instance count verification - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return errors.Wrap(err, "fetching pool") } - instances, err := r.store.ListPoolInstances(ctx, pool.ID) - if err != nil { - return errors.Wrap(err, "fetching instances") - } - // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteEnterprisePool(ctx, enterpriseID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -282,7 +278,11 @@ func (r *Runner) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index 311e743a..dc81da5e 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -272,7 +272,11 @@ func (s *EnterpriseTestSuite) TestDeleteEnterpriseErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store enterprises pool: %v", err)) } @@ -340,8 +344,6 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed( } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -365,30 +367,21 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, runnerErrors.ErrNotFound) - - _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *EnterpriseTestSuite) TestCreateEnterprisePoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Regexp("fetching pool params: no such provider", err.Error()) + s.Require().Regexp("failed to append tags to create pool params: no such provider not-existent-provider-name", err.Error()) } func (s *EnterpriseTestSuite) TestGetEnterprisePoolByID() { - enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + enterprisePool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -406,7 +399,11 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolByIDErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -415,7 +412,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -426,7 +423,11 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -441,10 +442,14 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolRunnersFailed() { } func (s *EnterpriseTestSuite) TestListEnterprisePools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } enterprisePools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-enterprise-%v", i) - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -464,7 +469,11 @@ func (s *EnterpriseTestSuite) TestListOrgPoolsErrUnauthorized() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { - enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + enterprisePool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -483,7 +492,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolErrUnauthorized() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -498,7 +511,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMinIdleGreaterThanMax() { } func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } diff --git a/runner/organizations.go b/runner/organizations.go index 482bd55d..258753f0 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -222,11 +222,6 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C return params.Pool{}, runnerErrors.ErrUnauthorized } - _, err := r.store.GetOrganizationByID(ctx, orgID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching org") - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -254,10 +249,16 @@ func (r *Runner) GetOrgPoolByID(ctx context.Context, orgID, poolID string) (para return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } + return pool, nil } @@ -266,29 +267,30 @@ func (r *Runner) DeleteOrgPool(ctx context.Context, orgID, poolID string) error return runnerErrors.ErrUnauthorized } - // nolint:golangci-lint,godox - // TODO: dedup instance count verification - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) - if err != nil { - return errors.Wrap(err, "fetching pool") + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, } - instances, err := r.store.ListPoolInstances(ctx, pool.ID) + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { - return errors.Wrap(err, "fetching instances") + if !errors.Is(err, runnerErrors.ErrNotFound) { + return errors.Wrap(err, "fetching pool") + } + return nil } // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteOrganizationPool(ctx, orgID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -311,7 +313,12 @@ func (r *Runner) UpdateOrgPool(ctx context.Context, orgID, poolID string, param return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index 7ebfcff8..d0113756 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -285,7 +285,11 @@ func (s *OrgTestSuite) TestDeleteOrganizationErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrganizationPoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store organizations pool: %v", err)) } @@ -365,8 +369,6 @@ func (s *OrgTestSuite) TestUpdateOrganizationCreateOrgPoolMgrFailed() { } func (s *OrgTestSuite) TestCreateOrgPool() { - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -390,21 +392,8 @@ func (s *OrgTestSuite) TestCreateOrgPoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *OrgTestSuite) TestCreateOrgPoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, runnerErrors.ErrNotFound) - - _, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *OrgTestSuite) TestCreateOrgPoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -413,7 +402,11 @@ func (s *OrgTestSuite) TestCreateOrgPoolFetchPoolParamsFailed() { } func (s *OrgTestSuite) TestGetOrgPoolByID() { - orgPool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + orgPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -431,7 +424,11 @@ func (s *OrgTestSuite) TestGetOrgPoolByIDErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrgPool() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -440,7 +437,7 @@ func (s *OrgTestSuite) TestDeleteOrgPool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -451,7 +448,11 @@ func (s *OrgTestSuite) TestDeleteOrgPoolErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrgPoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -466,10 +467,14 @@ func (s *OrgTestSuite) TestDeleteOrgPoolRunnersFailed() { } func (s *OrgTestSuite) TestListOrgPools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } orgPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-org-%v", i) - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -489,7 +494,11 @@ func (s *OrgTestSuite) TestListOrgPoolsErrUnauthorized() { } func (s *OrgTestSuite) TestUpdateOrgPool() { - orgPool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + orgPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -508,7 +517,11 @@ func (s *OrgTestSuite) TestUpdateOrgPoolErrUnauthorized() { } func (s *OrgTestSuite) TestUpdateOrgPoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -523,7 +536,11 @@ func (s *OrgTestSuite) TestUpdateOrgPoolMinIdleGreaterThanMax() { } func (s *OrgTestSuite) TestListOrgInstances() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 2226fa13..5291be6b 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -2198,16 +2198,7 @@ func (r *basePoolManager) ListPools() ([]params.Pool, error) { } func (r *basePoolManager) GetPoolByID(poolID string) (params.Pool, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.GetRepositoryPool(r.ctx, r.entity.ID, poolID) - case params.GithubEntityTypeOrganization: - return r.store.GetOrganizationPool(r.ctx, r.entity.ID, poolID) - case params.GithubEntityTypeEnterprise: - return r.store.GetEnterprisePool(r.ctx, r.entity.ID, poolID) - default: - return params.Pool{}, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) - } + return r.store.GetEntityPool(r.ctx, r.entity, poolID) } func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo, error) { diff --git a/runner/pools_test.go b/runner/pools_test.go index 59d6ff27..e2b269a0 100644 --- a/runner/pools_test.go +++ b/runner/pools_test.go @@ -64,11 +64,15 @@ func (s *PoolTestSuite) SetupTest() { } // create some pool objects in the database, for testing purposes + entity := params.GithubEntity{ + ID: org.ID, + EntityType: params.GithubEntityTypeOrganization, + } orgPools := []params.Pool{} for i := 1; i <= 3; i++ { - pool, err := db.CreateOrganizationPool( + pool, err := db.CreateEntityPool( context.Background(), - org.ID, + entity, params.CreatePoolParams{ ProviderName: "test-provider", MaxRunners: 4, diff --git a/runner/repositories.go b/runner/repositories.go index b8b25c06..68ef38f5 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -221,14 +221,9 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params return params.Pool{}, runnerErrors.ErrUnauthorized } - _, err := r.store.GetRepositoryByID(ctx, repoID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching repo") - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool params") + return params.Pool{}, fmt.Errorf("failed to append tags to create pool params: %w", err) } if createPoolParams.RunnerBootstrapTimeout == 0 { @@ -242,7 +237,7 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { - return params.Pool{}, errors.Wrap(err, "creating pool") + return params.Pool{}, fmt.Errorf("failed to create pool: %w", err) } return pool, nil @@ -253,10 +248,16 @@ func (r *Runner) GetRepoPoolByID(ctx context.Context, repoID, poolID string) (pa return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } + return pool, nil } @@ -265,27 +266,26 @@ func (r *Runner) DeleteRepoPool(ctx context.Context, repoID, poolID string) erro return runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return errors.Wrap(err, "fetching pool") } - instances, err := r.store.ListPoolInstances(ctx, pool.ID) - if err != nil { - return errors.Wrap(err, "fetching instances") - } - // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteRepositoryPool(ctx, repoID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -320,7 +320,11 @@ func (r *Runner) UpdateRepoPool(ctx context.Context, repoID, poolID string, para return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 8a1e8d9c..20814a86 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -295,7 +295,11 @@ func (s *RepoTestSuite) TestDeleteRepositoryErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepositoryPoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store repositories pool: %v", err)) } @@ -376,8 +380,6 @@ func (s *RepoTestSuite) TestUpdateRepositoryCreateRepoPoolMgrFailed() { } func (s *RepoTestSuite) TestCreateRepoPool() { - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -401,30 +403,21 @@ func (s *RepoTestSuite) TestCreateRepoPoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *RepoTestSuite) TestCreateRepoPoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - - _, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *RepoTestSuite) TestCreateRepoPoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Regexp("fetching pool params: no such provider", err.Error()) + s.Require().Regexp("failed to append tags to create pool params: no such provider not-existent-provider-name", err.Error()) } func (s *RepoTestSuite) TestGetRepoPoolByID() { - repoPool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + repoPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -442,7 +435,11 @@ func (s *RepoTestSuite) TestGetRepoPoolByIDErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepoPool() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -451,7 +448,7 @@ func (s *RepoTestSuite) TestDeleteRepoPool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -462,7 +459,11 @@ func (s *RepoTestSuite) TestDeleteRepoPoolErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepoPoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -477,10 +478,14 @@ func (s *RepoTestSuite) TestDeleteRepoPoolRunnersFailed() { } func (s *RepoTestSuite) TestListRepoPools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-repo-%v", i) - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -500,7 +505,11 @@ func (s *RepoTestSuite) TestListRepoPoolsErrUnauthorized() { } func (s *RepoTestSuite) TestListPoolInstances() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -527,7 +536,11 @@ func (s *RepoTestSuite) TestListPoolInstancesErrUnauthorized() { } func (s *RepoTestSuite) TestUpdateRepoPool() { - repoPool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + repoPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store repositories pool: %v", err)) } @@ -546,7 +559,11 @@ func (s *RepoTestSuite) TestUpdateRepoPoolErrUnauthorized() { } func (s *RepoTestSuite) TestUpdateRepoPoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -561,7 +578,11 @@ func (s *RepoTestSuite) TestUpdateRepoPoolMinIdleGreaterThanMax() { } func (s *RepoTestSuite) TestListRepoInstances() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } diff --git a/runner/runner.go b/runner/runner.go index 7eab27f9..a29fda0c 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -785,7 +785,8 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ func (r *Runner) appendTagsToCreatePoolParams(param params.CreatePoolParams) (params.CreatePoolParams, error) { if err := param.Validate(); err != nil { - return params.CreatePoolParams{}, errors.Wrapf(runnerErrors.ErrBadRequest, "validating params: %s", err) + return params.CreatePoolParams{}, fmt.Errorf("failed to validate params (%q): %w", err, runnerErrors.ErrBadRequest) + // errors.Wrapf(runnerErrors.ErrBadRequest, "validating params: %s", err) } if !IsSupportedOSType(param.OSType) { @@ -803,7 +804,7 @@ func (r *Runner) appendTagsToCreatePoolParams(param params.CreatePoolParams) (pa newTags, err := r.processTags(string(param.OSArch), param.OSType, param.Tags) if err != nil { - return params.CreatePoolParams{}, errors.Wrap(err, "processing tags") + return params.CreatePoolParams{}, fmt.Errorf("failed to process tags: %w", err) } param.Tags = newTags From f9f545f060f781492be7a682d613592b6d2d7315 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Fri, 29 Mar 2024 18:18:29 +0000 Subject: [PATCH 3/3] Remove duplicate code Signed-off-by: Gabriel Adrian Samfira --- database/common/common.go | 24 --- database/common/mocks/Store.go | 318 ----------------------------- database/sql/enterprise.go | 151 -------------- database/sql/enterprise_test.go | 211 +++++++++---------- database/sql/instances_test.go | 6 +- database/sql/organizations.go | 151 -------------- database/sql/organizations_test.go | 212 +++++++++---------- database/sql/pools.go | 107 +++++++--- database/sql/pools_test.go | 6 +- database/sql/repositories.go | 152 -------------- database/sql/repositories_test.go | 215 +++++++++---------- database/sql/util.go | 17 +- params/params.go | 57 ++++++ runner/enterprises.go | 22 +- runner/organizations.go | 23 ++- runner/pool/pool.go | 21 +- runner/pools.go | 16 +- runner/repositories.go | 23 ++- 18 files changed, 487 insertions(+), 1245 deletions(-) diff --git a/database/common/common.go b/database/common/common.go index ab546844..023a2057 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -27,14 +27,6 @@ type RepoStore interface { ListRepositories(ctx context.Context) ([]params.Repository, error) DeleteRepository(ctx context.Context, repoID string) error UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) - - // CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) - // GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) - // DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error - UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) - ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) } type OrgStore interface { @@ -44,14 +36,6 @@ type OrgStore interface { ListOrganizations(ctx context.Context) ([]params.Organization, error) DeleteOrganization(ctx context.Context, orgID string) error UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) - - // CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) - // GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) - // DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error - UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) - ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) } type EnterpriseStore interface { @@ -61,14 +45,6 @@ type EnterpriseStore interface { ListEnterprises(ctx context.Context) ([]params.Enterprise, error) DeleteEnterprise(ctx context.Context, enterpriseID string) error UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) - - // CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) - // GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) - // DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error - UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) - ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) } type PoolStore interface { diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 73eef2c3..219057e4 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -310,24 +310,6 @@ func (_m *Store) DeleteEnterprise(ctx context.Context, enterpriseID string) erro return r0 } -// DeleteEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID -func (_m *Store) DeleteEnterprisePool(ctx context.Context, enterpriseID string, poolID string) error { - ret := _m.Called(ctx, enterpriseID, poolID) - - if len(ret) == 0 { - panic("no return value specified for DeleteEnterprisePool") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, enterpriseID, poolID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // DeleteEntityPool provides a mock function with given fields: ctx, entity, poolID func (_m *Store) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { ret := _m.Called(ctx, entity, poolID) @@ -400,24 +382,6 @@ func (_m *Store) DeleteOrganization(ctx context.Context, orgID string) error { return r0 } -// DeleteOrganizationPool provides a mock function with given fields: ctx, orgID, poolID -func (_m *Store) DeleteOrganizationPool(ctx context.Context, orgID string, poolID string) error { - ret := _m.Called(ctx, orgID, poolID) - - if len(ret) == 0 { - panic("no return value specified for DeleteOrganizationPool") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, orgID, poolID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // DeletePoolByID provides a mock function with given fields: ctx, poolID func (_m *Store) DeletePoolByID(ctx context.Context, poolID string) error { ret := _m.Called(ctx, poolID) @@ -454,24 +418,6 @@ func (_m *Store) DeleteRepository(ctx context.Context, repoID string) error { return r0 } -// DeleteRepositoryPool provides a mock function with given fields: ctx, repoID, poolID -func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID string) error { - ret := _m.Called(ctx, repoID, poolID) - - if len(ret) == 0 { - panic("no return value specified for DeleteRepositoryPool") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, repoID, poolID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // FindPoolsMatchingAllTags provides a mock function with given fields: ctx, entityType, entityID, tags func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params.GithubEntityType, entityID string, tags []string) ([]params.Pool, error) { ret := _m.Called(ctx, entityType, entityID, tags) @@ -1002,66 +948,6 @@ func (_m *Store) ListAllPools(ctx context.Context) ([]params.Pool, error) { return r0, r1 } -// ListEnterpriseInstances provides a mock function with given fields: ctx, enterpriseID -func (_m *Store) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) { - ret := _m.Called(ctx, enterpriseID) - - if len(ret) == 0 { - panic("no return value specified for ListEnterpriseInstances") - } - - var r0 []params.Instance - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, enterpriseID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, enterpriseID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, enterpriseID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ListEnterprisePools provides a mock function with given fields: ctx, enterpriseID -func (_m *Store) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { - ret := _m.Called(ctx, enterpriseID) - - if len(ret) == 0 { - panic("no return value specified for ListEnterprisePools") - } - - var r0 []params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, enterpriseID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, enterpriseID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, enterpriseID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // ListEnterprises provides a mock function with given fields: ctx func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) { ret := _m.Called(ctx) @@ -1242,66 +1128,6 @@ func (_m *Store) ListJobsByStatus(ctx context.Context, status params.JobStatus) return r0, r1 } -// ListOrgInstances provides a mock function with given fields: ctx, orgID -func (_m *Store) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { - ret := _m.Called(ctx, orgID) - - if len(ret) == 0 { - panic("no return value specified for ListOrgInstances") - } - - var r0 []params.Instance - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, orgID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, orgID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, orgID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ListOrgPools provides a mock function with given fields: ctx, orgID -func (_m *Store) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) { - ret := _m.Called(ctx, orgID) - - if len(ret) == 0 { - panic("no return value specified for ListOrgPools") - } - - var r0 []params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, orgID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, orgID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, orgID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // ListOrganizations provides a mock function with given fields: ctx func (_m *Store) ListOrganizations(ctx context.Context) ([]params.Organization, error) { ret := _m.Called(ctx) @@ -1362,66 +1188,6 @@ func (_m *Store) ListPoolInstances(ctx context.Context, poolID string) ([]params return r0, r1 } -// ListRepoInstances provides a mock function with given fields: ctx, repoID -func (_m *Store) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { - ret := _m.Called(ctx, repoID) - - if len(ret) == 0 { - panic("no return value specified for ListRepoInstances") - } - - var r0 []params.Instance - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, repoID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, repoID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, repoID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ListRepoPools provides a mock function with given fields: ctx, repoID -func (_m *Store) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) { - ret := _m.Called(ctx, repoID) - - if len(ret) == 0 { - panic("no return value specified for ListRepoPools") - } - - var r0 []params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, repoID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, repoID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, repoID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // ListRepositories provides a mock function with given fields: ctx func (_m *Store) ListRepositories(ctx context.Context) ([]params.Repository, error) { ret := _m.Called(ctx) @@ -1544,34 +1310,6 @@ func (_m *Store) UpdateEnterprise(ctx context.Context, enterpriseID string, para return r0, r1 } -// UpdateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID, param -func (_m *Store) UpdateEnterprisePool(ctx context.Context, enterpriseID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, poolID, param) - - if len(ret) == 0 { - panic("no return value specified for UpdateEnterprisePool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, poolID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, enterpriseID, poolID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, enterpriseID, poolID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // UpdateEntityPool provides a mock function with given fields: ctx, entity, poolID, param func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { ret := _m.Called(ctx, entity, poolID, param) @@ -1656,34 +1394,6 @@ func (_m *Store) UpdateOrganization(ctx context.Context, orgID string, param par return r0, r1 } -// UpdateOrganizationPool provides a mock function with given fields: ctx, orgID, poolID, param -func (_m *Store) UpdateOrganizationPool(ctx context.Context, orgID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, orgID, poolID, param) - - if len(ret) == 0 { - panic("no return value specified for UpdateOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, orgID, poolID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, orgID, poolID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, orgID, poolID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // UpdateRepository provides a mock function with given fields: ctx, repoID, param func (_m *Store) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { ret := _m.Called(ctx, repoID, param) @@ -1712,34 +1422,6 @@ func (_m *Store) UpdateRepository(ctx context.Context, repoID string, param para return r0, r1 } -// UpdateRepositoryPool provides a mock function with given fields: ctx, repoID, poolID, param -func (_m *Store) UpdateRepositoryPool(ctx context.Context, repoID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, repoID, poolID, param) - - if len(ret) == 0 { - panic("no return value specified for UpdateRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, repoID, poolID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, repoID, poolID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, repoID, poolID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // UpdateUser provides a mock function with given fields: ctx, user, param func (_m *Store) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) { ret := _m.Called(ctx, user, param) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index c3bf5d69..3eb53b9e 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -5,7 +5,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -134,137 +133,6 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, return newParams, nil } -func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching enterprise") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - EnterpriseID: &enterprise.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getEnterprisePoolByUniqueFields(ctx, enterpriseID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(s.conn, val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) GetEnterprisePool(_ context.Context, enterpriseID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteEnterprisePool(_ context.Context, enterpriseID, poolID string) error { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID) - if err != nil { - return errors.Wrap(err, "looking up enterprise pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) UpdateEnterprisePool(_ context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(s.conn, pool, param) -} - -func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Tags", "Instances", "Enterprise") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - } - - return ret, nil -} - -func (s *sqlDatabase) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Instances", "Tags", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching enterprise") - } - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { var enterprise Enterprise @@ -302,22 +170,3 @@ func (s *sqlDatabase) getEnterpriseByID(_ context.Context, id string, preload .. } return enterprise, nil } - -func (s *sqlDatabase) getEnterprisePoolByUniqueFields(ctx context.Context, enterpriseID string, provider, image, flavor string) (Pool, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching enterprise") - } - - q := s.conn - var pool []Pool - err = q.Model(&enterprise).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") - } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound - } - - return pool[0], nil -} diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index 4f4a7da2..f77ae3d5 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -405,10 +405,8 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() { } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -426,11 +424,9 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -448,41 +444,37 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolInvalidEnterpriseID() { } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Enterprises[0].ID, @@ -490,159 +482,141 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} - + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -657,19 +631,19 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterprisePools() { enterprisePools := []params.Pool{} - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) @@ -679,24 +653,26 @@ func (s *EnterpriseTestSuite) TestListEnterprisePools() { enterprisePools = append(enterprisePools, pool) } - pools, err := s.Store.ListEnterprisePools(context.Background(), s.Fixtures.Enterprises[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools) } func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() { - _, err := s.Store.ListEnterprisePools(context.Background(), "dummy-enterprise-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestGetEnterprisePool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) @@ -720,10 +696,8 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) @@ -748,38 +722,29 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "id"}).AddRow(s.Fixtures.Enterprises[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and enterprise_id = ?")). + WithArgs(pool.ID, s.Fixtures.Enterprises[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) - - s.assertSQLMockExpectations() + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) @@ -794,30 +759,32 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListEnterpriseInstances(context.Background(), s.Fixtures.Enterprises[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) } func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() { - _, err := s.Store.ListEnterpriseInstances(context.Background(), "dummy-enterprise-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Enterprises[0].ID, - EntityType: params.GithubEntityTypeEnterprise, - } + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - pool, err = s.Store.UpdateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -827,7 +794,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.UpdateEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 200c683b..b136c8ae 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -92,10 +92,8 @@ func (s *InstancesTestSuite) SetupTest() { OSType: "linux", Tags: []string{"self-hosted", "amd64", "linux"}, } - entity := params.GithubEntity{ - ID: org.ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := org.GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, createPoolParams) if err != nil { s.FailNow(fmt.Sprintf("failed to create org pool: %s", err)) diff --git a/database/sql/organizations.go b/database/sql/organizations.go index a2b14ae5..24704fd9 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -20,7 +20,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -151,137 +150,6 @@ func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (pa return param, nil } -func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - org, err := s.getOrgByID(ctx, orgID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching org") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - OrgID: &org.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getOrgPoolByUniqueFields(ctx, orgID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(s.conn, val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Organization") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pool") - } - } - - return ret, nil -} - -func (s *sqlDatabase) GetOrganizationPool(_ context.Context, orgID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteOrganizationPool(_ context.Context, orgID, poolID string) error { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID) - if err != nil { - return errors.Wrap(err, "looking up org pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching org") - } - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - -func (s *sqlDatabase) UpdateOrganizationPool(_ context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(s.conn, pool, param) -} - func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { u, err := uuid.Parse(id) if err != nil { @@ -319,22 +187,3 @@ func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, erro } return org, nil } - -func (s *sqlDatabase) getOrgPoolByUniqueFields(ctx context.Context, orgID string, provider, image, flavor string) (Pool, error) { - org, err := s.getOrgByID(ctx, orgID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching org") - } - - q := s.conn - var pool []Pool - err = q.Model(&org).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") - } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound - } - - return pool[0], nil -} diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 63b4cebd..86d13d72 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -405,10 +405,8 @@ func (s *OrgTestSuite) TestGetOrganizationByIDDBDecryptingErr() { } func (s *OrgTestSuite) TestCreateOrganizationPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -426,11 +424,9 @@ func (s *OrgTestSuite) TestCreateOrganizationPool() { func (s *OrgTestSuite) TestCreateOrganizationPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -448,41 +444,37 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolInvalidOrgID() { } func (s *OrgTestSuite) TestCreateOrganizationPoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Orgs[0].ID, @@ -490,159 +482,142 @@ func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -657,19 +632,20 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrgPools() { orgPools := []params.Pool{} - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) @@ -678,25 +654,26 @@ func (s *OrgTestSuite) TestListOrgPools() { } orgPools = append(orgPools, pool) } - - pools, err := s.Store.ListOrgPools(context.Background(), s.Fixtures.Orgs[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), orgPools, pools) } func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { - _, err := s.Store.ListOrgPools(context.Background(), "dummy-org-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestGetOrganizationPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) @@ -720,10 +697,8 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() { } func (s *OrgTestSuite) TestDeleteOrganizationPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) @@ -748,38 +723,31 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() { } func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"org_id", "id"}).AddRow(s.Fixtures.Orgs[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and org_id = ?")). + WithArgs(pool.ID, s.Fixtures.Orgs[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrgInstances() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) @@ -794,30 +762,32 @@ func (s *OrgTestSuite) TestListOrgInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListOrgInstances(context.Background(), s.Fixtures.Orgs[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) } func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() { - _, err := s.Store.ListOrgInstances(context.Background(), "dummy-org-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestUpdateOrganizationPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Orgs[0].ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - pool, err = s.Store.UpdateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -827,7 +797,11 @@ func (s *OrgTestSuite) TestUpdateOrganizationPool() { } func (s *OrgTestSuite) TestUpdateOrganizationPoolInvalidOrgID() { - _, err := s.Store.UpdateOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/pools.go b/database/sql/pools.go index ab892eb2..1dd7e68d 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -89,25 +89,30 @@ func (s *sqlDatabase) getEntityPool(tx *gorm.DB, entityType params.GithubEntityT return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } + var fieldName string + var entityField string + switch entityType { + case params.GithubEntityTypeRepository: + fieldName = entityTypeRepoName + entityField = "Repository" + case params.GithubEntityTypeOrganization: + fieldName = entityTypeOrgName + entityField = "Organization" + case params.GithubEntityTypeEnterprise: + fieldName = entityTypeEnterpriseName + entityField = "Enterprise" + default: + return Pool{}, fmt.Errorf("invalid entityType: %v", entityType) + } + q := tx + q = q.Preload(entityField) if len(preload) > 0 { for _, item := range preload { q = q.Preload(item) } } - var fieldName string - switch entityType { - case params.GithubEntityTypeRepository: - fieldName = entityTypeRepoName - case params.GithubEntityTypeOrganization: - fieldName = entityTypeOrgName - case params.GithubEntityTypeEnterprise: - fieldName = entityTypeEnterpriseName - default: - return Pool{}, fmt.Errorf("invalid entityType: %v", entityType) - } - var pool Pool condition := fmt.Sprintf("id = ? and %s = ?", fieldName) err = q.Model(&Pool{}). @@ -123,30 +128,39 @@ func (s *sqlDatabase) getEntityPool(tx *gorm.DB, entityType params.GithubEntityT return pool, nil } -func (s *sqlDatabase) listEntityPools(_ context.Context, entityType params.GithubEntityType, entityID string, preload ...string) ([]Pool, error) { +func (s *sqlDatabase) listEntityPools(tx *gorm.DB, entityType params.GithubEntityType, entityID string, preload ...string) ([]Pool, error) { if _, err := uuid.Parse(entityID); err != nil { return nil, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } - q := s.conn - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } + if err := s.hasGithubEntity(tx, entityType, entityID); err != nil { + return nil, errors.Wrap(err, "checking entity existence") } + var preloadEntity string var fieldName string switch entityType { case params.GithubEntityTypeRepository: fieldName = entityTypeRepoName + preloadEntity = "Repository" case params.GithubEntityTypeOrganization: fieldName = entityTypeOrgName + preloadEntity = "Organization" case params.GithubEntityTypeEnterprise: fieldName = entityTypeEnterpriseName + preloadEntity = "Enterprise" default: return nil, fmt.Errorf("invalid entityType: %v", entityType) } + q := tx + q = q.Preload(preloadEntity) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + var pools []Pool condition := fmt.Sprintf("%s = ?", fieldName) err := q.Model(&Pool{}). @@ -270,8 +284,7 @@ func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEn newPool.EnterpriseID = &entityID } err = s.conn.Transaction(func(tx *gorm.DB) error { - ok, err := s.hasGithubEntity(tx, entity.EntityType, entity.ID) - if err != nil || !ok { + if err := s.hasGithubEntity(tx, entity.EntityType, entity.ID); err != nil { return errors.Wrap(err, "checking entity existence") } @@ -305,7 +318,7 @@ func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEn return nil }) if err != nil { - return params.Pool{}, errors.Wrap(err, "creating pool") + return params.Pool{}, err } pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") @@ -353,16 +366,56 @@ func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEn } func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - fmt.Printf("UpdateEntityPool: %v %v %v\n", entity, poolID, param) - return params.Pool{}, nil + var updatedPool params.Pool + err := s.conn.Transaction(func(tx *gorm.DB) error { + pool, err := s.getEntityPool(tx, entity.EntityType, entity.ID, poolID, "Tags", "Instances") + if err != nil { + return errors.Wrap(err, "fetching pool") + } + + updatedPool, err = s.updatePool(tx, pool, param) + if err != nil { + return errors.Wrap(err, "updating pool") + } + return nil + }) + if err != nil { + return params.Pool{}, err + } + return updatedPool, nil } func (s *sqlDatabase) ListEntityPools(_ context.Context, entity params.GithubEntity) ([]params.Pool, error) { - fmt.Println(entity) - return nil, nil + pools, err := s.listEntityPools(s.conn, entity.EntityType, entity.ID, "Tags") + if err != nil { + return nil, errors.Wrap(err, "fetching pools") + } + + ret := make([]params.Pool, len(pools)) + for idx, pool := range pools { + ret[idx], err = s.sqlToCommonPool(pool) + if err != nil { + return nil, errors.Wrap(err, "fetching pool") + } + } + + return ret, nil } func (s *sqlDatabase) ListEntityInstances(_ context.Context, entity params.GithubEntity) ([]params.Instance, error) { - fmt.Println(entity) - return nil, nil + pools, err := s.listEntityPools(s.conn, entity.EntityType, entity.ID, "Instances", "Instances.Job") + if err != nil { + return nil, errors.Wrap(err, "fetching entity") + } + ret := []params.Instance{} + for _, pool := range pools { + for _, instance := range pool.Instances { + paramsInstance, err := s.sqlToParamsInstance(instance) + if err != nil { + return nil, errors.Wrap(err, "fetching instance") + } + ret = append(ret, paramsInstance) + } + } + return ret, nil } diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index aac01f99..c05711cb 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -66,10 +66,8 @@ func (s *PoolsTestSuite) SetupTest() { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } - entity := params.GithubEntity{ - ID: org.ID, - EntityType: params.GithubEntityTypeOrganization, - } + entity, err := org.GetEntity() + s.Require().Nil(err) // create some pool objects in the database, for testing purposes orgPools := []params.Pool{} for i := 1; i <= 3; i++ { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 936b796c..164c0197 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -20,7 +20,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -151,138 +150,6 @@ func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (par return param, nil } -func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - repo, err := s.getRepoByID(ctx, repoID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching repo") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - RepoID: &repo.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getRepoPoolByUniqueFields(ctx, repoID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(s.conn, val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Repository") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pool") - } - } - - return ret, nil -} - -func (s *sqlDatabase) GetRepositoryPool(_ context.Context, repoID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteRepositoryPool(_ context.Context, repoID, poolID string) error { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID) - if err != nil { - return errors.Wrap(err, "looking up repo pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching repo") - } - - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - -func (s *sqlDatabase) UpdateRepositoryPool(_ context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(s.conn, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(s.conn, pool, param) -} - func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository, error) { var repo Repository @@ -325,25 +192,6 @@ func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.Git return Pool{}, nil } -func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) { - repo, err := s.getRepoByID(ctx, repoID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching repo") - } - - q := s.conn - var pool []Pool - err = q.Model(&repo).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") - } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound - } - - return pool[0], nil -} - func (s *sqlDatabase) getRepoByID(_ context.Context, id string, preload ...string) (Repository, error) { u, err := uuid.Parse(id) if err != nil { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index ab1f9da5..18126197 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -443,10 +443,8 @@ func (s *RepoTestSuite) TestGetRepositoryByIDDBDecryptingErr() { } func (s *RepoTestSuite) TestCreateRepositoryPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -463,11 +461,9 @@ func (s *RepoTestSuite) TestCreateRepositoryPool() { func (s *RepoTestSuite) TestCreateRepositoryPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -485,41 +481,37 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolInvalidRepoID() { } func (s *RepoTestSuite) TestCreateRepositoryPoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Repos[0].ID, @@ -527,159 +519,145 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("pool with the same image and flavor already exists on this provider", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) - s.assertSQLMockExpectations() + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -694,18 +672,19 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepoPools() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i) @@ -716,24 +695,26 @@ func (s *RepoTestSuite) TestListRepoPools() { repoPools = append(repoPools, pool) } - pools, err := s.Store.ListRepoPools(context.Background(), s.Fixtures.Repos[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), repoPools, pools) } func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { - _, err := s.Store.ListRepoPools(context.Background(), "dummy-repo-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestGetRepositoryPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) @@ -757,10 +738,8 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() { } func (s *RepoTestSuite) TestDeleteRepositoryPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) @@ -785,38 +764,30 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() { } func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"repo_id", "id"}).AddRow(s.Fixtures.Repos[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and repo_id = ?")). + WithArgs(pool.ID, s.Fixtures.Repos[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) - - s.assertSQLMockExpectations() + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepoInstances() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) @@ -831,30 +802,32 @@ func (s *RepoTestSuite) TestListRepoInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListRepoInstances(context.Background(), s.Fixtures.Repos[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByID(poolInstances, instances) } func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() { - _, err := s.Store.ListRepoInstances(context.Background(), "dummy-repo-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestUpdateRepositoryPool() { - entity := params.GithubEntity{ - ID: s.Fixtures.Repos[0].ID, - EntityType: params.GithubEntityTypeRepository, - } + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) repoPool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - pool, err := s.Store.UpdateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, repoPool.ID, s.Fixtures.UpdatePoolParams) + pool, err := s.Store.UpdateEntityPool(context.Background(), entity, repoPool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -864,7 +837,11 @@ func (s *RepoTestSuite) TestUpdateRepositoryPool() { } func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() { - _, err := s.Store.UpdateRepositoryPool(context.Background(), "dummy-org-id", "dummy-repo-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/util.go b/database/sql/util.go index 2a41050c..aaea31fe 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -290,9 +290,8 @@ func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) { Name: tagName, } - q = tx.Create(&newTag) - if q.Error != nil { - return Tag{}, errors.Wrap(q.Error, "creating tag") + if err := tx.Create(&newTag).Error; err != nil { + return Tag{}, errors.Wrap(err, "creating tag") } return newTag, nil } @@ -392,10 +391,10 @@ func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) return pool, nil } -func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntityType, entityID string) (bool, error) { +func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntityType, entityID string) error { u, err := uuid.Parse(entityID) if err != nil { - return false, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var q *gorm.DB switch entityType { @@ -406,15 +405,15 @@ func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntit case params.GithubEntityTypeEnterprise: q = tx.Model(&Enterprise{}).Where("id = ?", u) default: - return false, errors.Wrap(runnerErrors.ErrBadRequest, "invalid entity type") + return errors.Wrap(runnerErrors.ErrBadRequest, "invalid entity type") } var entity interface{} if err := q.First(entity).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return false, errors.Wrap(runnerErrors.ErrNotFound, "entity not found") + return errors.Wrap(runnerErrors.ErrNotFound, "entity not found") } - return false, errors.Wrap(err, "fetching entity from database") + return errors.Wrap(err, "fetching entity from database") } - return true, nil + return nil } diff --git a/params/params.go b/params/params.go index 2b87a4c5..a2a44222 100644 --- a/params/params.go +++ b/params/params.go @@ -317,6 +317,27 @@ type Pool struct { Priority uint `json:"priority"` } +func (p Pool) GithubEntity() (GithubEntity, error) { + switch p.PoolType() { + case GithubEntityTypeRepository: + return GithubEntity{ + ID: p.RepoID, + EntityType: GithubEntityTypeRepository, + }, nil + case GithubEntityTypeOrganization: + return GithubEntity{ + ID: p.OrgID, + EntityType: GithubEntityTypeOrganization, + }, nil + case GithubEntityTypeEnterprise: + return GithubEntity{ + ID: p.EnterpriseID, + EntityType: GithubEntityTypeEnterprise, + }, nil + } + return GithubEntity{}, fmt.Errorf("pool has no associated entity") +} + func (p Pool) GetID() string { return p.ID } @@ -383,6 +404,18 @@ type Repository struct { WebhookSecret string `json:"-"` } +func (r Repository) GetEntity() (GithubEntity, error) { + if r.ID == "" { + return GithubEntity{}, fmt.Errorf("repository has no ID") + } + return GithubEntity{ + ID: r.ID, + EntityType: GithubEntityTypeRepository, + Owner: r.Owner, + Name: r.Name, + }, nil +} + func (r Repository) GetName() string { return r.Name } @@ -412,6 +445,18 @@ type Organization struct { WebhookSecret string `json:"-"` } +func (o Organization) GetEntity() (GithubEntity, error) { + if o.ID == "" { + return GithubEntity{}, fmt.Errorf("organization has no ID") + } + return GithubEntity{ + ID: o.ID, + EntityType: GithubEntityTypeOrganization, + Owner: o.Name, + WebhookSecret: o.WebhookSecret, + }, nil +} + func (o Organization) GetName() string { return o.Name } @@ -441,6 +486,18 @@ type Enterprise struct { WebhookSecret string `json:"-"` } +func (e Enterprise) GetEntity() (GithubEntity, error) { + if e.ID == "" { + return GithubEntity{}, fmt.Errorf("enterprise has no ID") + } + return GithubEntity{ + ID: e.ID, + EntityType: GithubEntityTypeEnterprise, + Owner: e.Name, + WebhookSecret: e.WebhookSecret, + }, nil +} + func (e Enterprise) GetName() string { return e.Name } diff --git a/runner/enterprises.go b/runner/enterprises.go index 7d6e5b8e..c5274e09 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -124,7 +124,12 @@ func (r *Runner) DeleteEnterprise(ctx context.Context, enterpriseID string) erro return errors.Wrap(err, "fetching enterprise") } - pools, err := r.store.ListEnterprisePools(ctx, enterpriseID) + entity, err := enterprise.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching enterprise pools") } @@ -266,7 +271,11 @@ func (r *Runner) ListEnterprisePools(ctx context.Context, enterpriseID string) ( return []params.Pool{}, runnerErrors.ErrUnauthorized } - pools, err := r.store.ListEnterprisePools(ctx, enterpriseID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -301,7 +310,7 @@ func (r *Runner) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateEnterprisePool(ctx, enterpriseID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -312,8 +321,11 @@ func (r *Runner) ListEnterpriseInstances(ctx context.Context, enterpriseID strin if !auth.IsAdmin(ctx) { return nil, runnerErrors.ErrUnauthorized } - - instances, err := r.store.ListEnterpriseInstances(ctx, enterpriseID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") } diff --git a/runner/organizations.go b/runner/organizations.go index 258753f0..40847ccf 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -138,7 +138,12 @@ func (r *Runner) DeleteOrganization(ctx context.Context, orgID string, keepWebho return errors.Wrap(err, "fetching org") } - pools, err := r.store.ListOrgPools(ctx, orgID) + entity, err := org.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching org pools") } @@ -300,8 +305,11 @@ func (r *Runner) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, if !auth.IsAdmin(ctx) { return []params.Pool{}, runnerErrors.ErrUnauthorized } - - pools, err := r.store.ListOrgPools(ctx, orgID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -337,7 +345,7 @@ func (r *Runner) UpdateOrgPool(ctx context.Context, orgID, poolID string, param return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateOrganizationPool(ctx, orgID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -349,7 +357,12 @@ func (r *Runner) ListOrgInstances(ctx context.Context, orgID string) ([]params.I return nil, runnerErrors.ErrUnauthorized } - instances, err := r.store.ListOrgInstances(ctx, orgID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 5291be6b..bdfc0d3b 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -2173,28 +2173,11 @@ func (r *basePoolManager) GithubURL() string { } func (r *basePoolManager) FetchDbInstances() ([]params.Instance, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.ListRepoInstances(r.ctx, r.entity.ID) - case params.GithubEntityTypeOrganization: - return r.store.ListOrgInstances(r.ctx, r.entity.ID) - case params.GithubEntityTypeEnterprise: - return r.store.ListEnterpriseInstances(r.ctx, r.entity.ID) - } - return nil, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) + return r.store.ListEntityInstances(r.ctx, r.entity) } func (r *basePoolManager) ListPools() ([]params.Pool, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.ListRepoPools(r.ctx, r.entity.ID) - case params.GithubEntityTypeOrganization: - return r.store.ListOrgPools(r.ctx, r.entity.ID) - case params.GithubEntityTypeEnterprise: - return r.store.ListEnterprisePools(r.ctx, r.entity.ID) - default: - return nil, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) - } + return r.store.ListEntityPools(r.ctx, r.entity) } func (r *basePoolManager) GetPoolByID(poolID string) (params.Pool, error) { diff --git a/runner/pools.go b/runner/pools.go index 16194f65..aab423ff 100644 --- a/runner/pools.go +++ b/runner/pools.go @@ -16,7 +16,6 @@ package runner import ( "context" - "fmt" "github.com/pkg/errors" @@ -108,19 +107,12 @@ func (r *Runner) UpdatePoolByID(ctx context.Context, poolID string, param params param.Tags = newTags } - var newPool params.Pool - - switch { - case pool.RepoID != "": - newPool, err = r.store.UpdateRepositoryPool(ctx, pool.RepoID, poolID, param) - case pool.OrgID != "": - newPool, err = r.store.UpdateOrganizationPool(ctx, pool.OrgID, poolID, param) - case pool.EnterpriseID != "": - newPool, err = r.store.UpdateEnterprisePool(ctx, pool.EnterpriseID, poolID, param) - default: - return params.Pool{}, fmt.Errorf("pool not found to a repo, org or enterprise") + entity, err := pool.GithubEntity() + if err != nil { + return params.Pool{}, errors.Wrap(err, "getting entity") } + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } diff --git a/runner/repositories.go b/runner/repositories.go index 68ef38f5..f7692b69 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -137,7 +137,12 @@ func (r *Runner) DeleteRepository(ctx context.Context, repoID string, keepWebhoo return errors.Wrap(err, "fetching repo") } - pools, err := r.store.ListRepoPools(ctx, repoID) + entity, err := repo.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching repo pools") } @@ -295,8 +300,11 @@ func (r *Runner) ListRepoPools(ctx context.Context, repoID string) ([]params.Poo if !auth.IsAdmin(ctx) { return []params.Pool{}, runnerErrors.ErrUnauthorized } - - pools, err := r.store.ListRepoPools(ctx, repoID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -343,7 +351,7 @@ func (r *Runner) UpdateRepoPool(ctx context.Context, repoID, poolID string, para return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateRepositoryPool(ctx, repoID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -354,8 +362,11 @@ func (r *Runner) ListRepoInstances(ctx context.Context, repoID string) ([]params if !auth.IsAdmin(ctx) { return nil, runnerErrors.ErrUnauthorized } - - instances, err := r.store.ListRepoInstances(ctx, repoID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") }