Implement some common logic for pool creation

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-03-28 10:08:19 +00:00
parent 72501aee0f
commit 0152b21529
15 changed files with 201 additions and 254 deletions

View file

@ -29,11 +29,9 @@ type RepoStore interface {
UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error)
CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error)
GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error)
DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error
UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error)
ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error)
ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error)
@ -52,11 +50,20 @@ type OrgStore interface {
DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error
UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error)
ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error)
ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error)
}
type EntityPools interface {
CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error)
GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error)
DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error
UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error)
ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error)
ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error)
}
type EnterpriseStore interface {
CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error)
GetEnterprise(ctx context.Context, name string) (params.Enterprise, error)
@ -70,7 +77,6 @@ type EnterpriseStore interface {
DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error
UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error)
ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error)
ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error)
}
@ -139,6 +145,7 @@ type Store interface {
UserStore
InstanceStore
JobsStore
EntityPools
ControllerInfo() (params.ControllerInfo, error)
InitController() (params.ControllerInfo, error)

View file

@ -510,62 +510,6 @@ func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID
return r0
}
// FindEnterprisePoolByTags provides a mock function with given fields: ctx, enterpriseID, tags
func (_m *Store) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) {
ret := _m.Called(ctx, enterpriseID, tags)
if len(ret) == 0 {
panic("no return value specified for FindEnterprisePoolByTags")
}
var r0 params.Pool
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok {
return rf(ctx, enterpriseID, tags)
}
if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok {
r0 = rf(ctx, enterpriseID, tags)
} else {
r0 = ret.Get(0).(params.Pool)
}
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
r1 = rf(ctx, enterpriseID, tags)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindOrganizationPoolByTags provides a mock function with given fields: ctx, orgID, tags
func (_m *Store) FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) {
ret := _m.Called(ctx, orgID, tags)
if len(ret) == 0 {
panic("no return value specified for FindOrganizationPoolByTags")
}
var r0 params.Pool
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok {
return rf(ctx, orgID, tags)
}
if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok {
r0 = rf(ctx, orgID, tags)
} else {
r0 = ret.Get(0).(params.Pool)
}
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
r1 = rf(ctx, orgID, tags)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindPoolsMatchingAllTags provides a mock function with given fields: ctx, entityType, entityID, tags
func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params.GithubEntityType, entityID string, tags []string) ([]params.Pool, error) {
ret := _m.Called(ctx, entityType, entityID, tags)
@ -596,34 +540,6 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params
return r0, r1
}
// FindRepositoryPoolByTags provides a mock function with given fields: ctx, repoID, tags
func (_m *Store) FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) {
ret := _m.Called(ctx, repoID, tags)
if len(ret) == 0 {
panic("no return value specified for FindRepositoryPoolByTags")
}
var r0 params.Pool
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok {
return rf(ctx, repoID, tags)
}
if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok {
r0 = rf(ctx, repoID, tags)
} else {
r0 = ret.Get(0).(params.Pool)
}
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
r1 = rf(ctx, repoID, tags)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetEnterprise provides a mock function with given fields: ctx, name
func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) {
ret := _m.Called(ctx, name)

View file

@ -175,7 +175,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str
tags := []Tag{}
for _, val := range param.Tags {
t, err := s.getOrCreateTag(val)
t, err := s.getOrCreateTag(s.conn, val)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching tag")
}
@ -193,7 +193,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -230,14 +230,6 @@ func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, po
return s.updatePool(pool, param)
}
func (s *sqlDatabase) FindEnterprisePoolByTags(_ context.Context, enterpriseID string, tags []string) (params.Pool, error) {
pool, err := s.findPoolByTags(enterpriseID, params.GithubEntityTypeEnterprise, tags)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool[0], nil
}
func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) {
pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Tags", "Instances", "Enterprise")
if err != nil {

View file

@ -740,29 +740,6 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
s.Require().Equal("deleting pool: mocked deleting pool error", err.Error())
}
func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTags() {
enterprisePool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
}
pool, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.Tags)
s.Require().Nil(err)
s.Require().Equal(enterprisePool.ID, pool.ID)
s.Require().Equal(enterprisePool.Image, pool.Image)
s.Require().Equal(enterprisePool.Flavor, pool.Flavor)
}
func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTagsMissingTags() {
tags := []string{}
_, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, tags)
s.Require().NotNil(err)
s.Require().Equal("fetching pool: missing tags", err.Error())
}
func (s *EnterpriseTestSuite) TestListEnterpriseInstances() {
pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {

View file

@ -49,7 +49,7 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error
}
func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) {
pool, err := s.getPoolByID(ctx, poolID)
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching pool")
}
@ -108,8 +108,8 @@ func (s *sqlDatabase) getInstanceByID(_ context.Context, instanceID string) (Ins
return instance, nil
}
func (s *sqlDatabase) getPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) {
pool, err := s.getPoolByID(ctx, poolID)
func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) {
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
return Instance{}, errors.Wrap(err, "fetching pool")
}
@ -153,7 +153,7 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string,
}
func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) {
instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName)
instance, err := s.getPoolInstanceByName(poolID, instanceName)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching instance")
}
@ -171,7 +171,7 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string
}
func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanceName string) error {
instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName)
instance, err := s.getPoolInstanceByName(poolID, instanceName)
if err != nil {
return errors.Wrap(err, "deleting instance")
}
@ -338,7 +338,7 @@ func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, er
}
func (s *sqlDatabase) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) {
pool, err := s.getPoolByID(ctx, poolID)
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
return 0, errors.Wrap(err, "fetching pool")
}

View file

@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string,
tags := []Tag{}
for _, val := range param.Tags {
t, err := s.getOrCreateTag(val)
t, err := s.getOrCreateTag(s.conn, val)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching tag")
}
@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string,
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID
return nil
}
func (s *sqlDatabase) FindOrganizationPoolByTags(_ context.Context, orgID string, tags []string) (params.Pool, error) {
pool, err := s.findPoolByTags(orgID, params.GithubEntityTypeOrganization, tags)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool[0], nil
}
func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) {
pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Instances.Job")
if err != nil {
@ -290,30 +282,6 @@ func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID
return s.updatePool(pool, param)
}
func (s *sqlDatabase) getPoolByID(_ context.Context, poolID string, preload ...string) (Pool, error) {
u, err := uuid.Parse(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool Pool
q := s.conn.Model(&Pool{})
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
q = q.Where("id = ?", u).First(&pool)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Pool{}, runnerErrors.ErrNotFound
}
return Pool{}, errors.Wrap(q.Error, "fetching org from database")
}
return pool, nil
}
func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) {
u, err := uuid.Parse(id)
if err != nil {

View file

@ -740,29 +740,6 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
s.Require().Equal("deleting pool: mocked deleting pool error", err.Error())
}
func (s *OrgTestSuite) TestFindOrganizationPoolByTags() {
orgPool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
}
pool, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.Tags)
s.Require().Nil(err)
s.Require().Equal(orgPool.ID, pool.ID)
s.Require().Equal(orgPool.Image, pool.Image)
s.Require().Equal(orgPool.Flavor, pool.Flavor)
}
func (s *OrgTestSuite) TestFindOrganizationPoolByTagsMissingTags() {
tags := []string{}
_, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, tags)
s.Require().NotNil(err)
s.Require().Equal("fetching pool: missing tags", err.Error())
}
func (s *OrgTestSuite) TestListOrgInstances() {
pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {

View file

@ -20,6 +20,7 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"gorm.io/datatypes"
"gorm.io/gorm"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
@ -58,7 +59,7 @@ func (s *sqlDatabase) ListAllPools(_ context.Context) ([]params.Pool, error) {
}
func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) {
pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
pool, err := s.getPoolByID(s.conn, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool by ID")
}
@ -66,7 +67,7 @@ func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Po
}
func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error {
pool, err := s.getPoolByID(ctx, poolID)
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
return errors.Wrap(err, "fetching pool by ID")
}
@ -231,3 +232,101 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par
return pools, nil
}
func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) {
if len(param.Tags) == 0 {
return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified")
}
newPool := Pool{
ProviderName: param.ProviderName,
MaxRunners: param.MaxRunners,
MinIdleRunners: param.MinIdleRunners,
RunnerPrefix: param.GetRunnerPrefix(),
Image: param.Image,
Flavor: param.Flavor,
OSType: param.OSType,
OSArch: param.OSArch,
Enabled: param.Enabled,
RunnerBootstrapTimeout: param.RunnerBootstrapTimeout,
GitHubRunnerGroup: param.GitHubRunnerGroup,
Priority: param.Priority,
}
if len(param.ExtraSpecs) > 0 {
newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
}
entityID, err := uuid.Parse(entity.ID)
if err != nil {
return params.Pool{}, fmt.Errorf("parsing entity ID: %w", err)
}
switch entity.EntityType {
case params.GithubEntityTypeRepository:
newPool.RepoID = &entityID
case params.GithubEntityTypeOrganization:
newPool.OrgID = &entityID
case params.GithubEntityTypeEnterprise:
newPool.EnterpriseID = &entityID
}
err = s.conn.Transaction(func(tx *gorm.DB) error {
if _, err := s.getEntityPoolByUniqueFields(tx, entity, newPool.ProviderName, newPool.Image, newPool.Flavor); err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return fmt.Errorf("checking for existing pool: %w", err)
}
} else {
return runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider")
}
tags := []Tag{}
for _, val := range param.Tags {
t, err := s.getOrCreateTag(tx, val)
if err != nil {
return fmt.Errorf("creating tag: %w", err)
}
tags = append(tags, t)
}
q := tx.Create(&newPool)
if q.Error != nil {
return fmt.Errorf("creating pool: %w", q.Error)
}
for i := range tags {
if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil {
return fmt.Errorf("associating tags: %w", err)
}
}
return nil
})
if err != nil {
return params.Pool{}, fmt.Errorf("creating pool: %w", err)
}
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.sqlToCommonPool(pool)
}
func (s *sqlDatabase) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) {
return params.Pool{}, nil
}
func (s *sqlDatabase) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error {
return nil
}
func (s *sqlDatabase) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
return params.Pool{}, nil
}
func (s *sqlDatabase) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) {
return nil, nil
}
func (s *sqlDatabase) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) {
return nil, nil
}

View file

@ -192,7 +192,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p
tags := []Tag{}
for _, val := range param.Tags {
t, err := s.getOrCreateTag(val)
t, err := s.getOrCreateTag(s.conn, val)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching tag")
}
@ -210,7 +210,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, p
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -255,14 +255,6 @@ func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID s
return nil
}
func (s *sqlDatabase) FindRepositoryPoolByTags(_ context.Context, repoID string, tags []string) (params.Pool, error) {
pool, err := s.findPoolByTags(repoID, params.GithubEntityTypeRepository, tags)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool[0], nil
}
func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) {
pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Instances.Job")
if err != nil {
@ -308,6 +300,31 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository
return repo, nil
}
func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.GithubEntity, provider, image, flavor string) (pool Pool, err error) {
var entityField string
switch entity.EntityType {
case params.GithubEntityTypeRepository:
entityField = entityTypeRepoName
case params.GithubEntityTypeOrganization:
entityField = entityTypeOrgName
case params.GithubEntityTypeEnterprise:
entityField = entityTypeEnterpriseName
}
entityID, err := uuid.Parse(entity.ID)
if err != nil {
return pool, fmt.Errorf("parsing entity ID: %w", err)
}
poolQueryString := fmt.Sprintf("provider_name = ? and image = ? and flavor = ? and %s = ?", entityField)
err = tx.Where(poolQueryString, provider, image, flavor, entityID).First(&pool).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return pool, runnerErrors.ErrNotFound
}
return
}
return Pool{}, nil
}
func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) {
repo, err := s.getRepoByID(ctx, repoID)
if err != nil {

View file

@ -777,28 +777,6 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
s.Require().Equal("deleting pool: mocked deleting pool error", err.Error())
}
func (s *RepoTestSuite) TestFindRepositoryPoolByTags() {
repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
}
pool, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.Tags)
s.Require().Nil(err)
s.Require().Equal(repoPool.ID, pool.ID)
s.Require().Equal(repoPool.Image, pool.Image)
s.Require().Equal(repoPool.Flavor, pool.Flavor)
}
func (s *RepoTestSuite) TestFindRepositoryPoolByTagsMissingTags() {
tags := []string{}
_, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, tags)
s.Require().NotNil(err)
s.Require().Equal("fetching pool: missing tags", err.Error())
}
func (s *RepoTestSuite) TestListRepoInstances() {
pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams)
if err != nil {

View file

@ -18,10 +18,12 @@ import (
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/pkg/errors"
"gorm.io/datatypes"
"gorm.io/gorm"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonParams "github.com/cloudbase/garm-provider-common/params"
"github.com/cloudbase/garm-provider-common/util"
"github.com/cloudbase/garm/params"
@ -275,9 +277,9 @@ func (s *sqlDatabase) sqlToParamsUser(user User) params.User {
}
}
func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) {
func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) {
var tag Tag
q := s.conn.Where("name = ?", tagName).First(&tag)
q := tx.Where("name = ?", tagName).First(&tag)
if q.Error == nil {
return tag, nil
}
@ -288,7 +290,7 @@ func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) {
Name: tagName,
}
q = s.conn.Create(&newTag)
q = tx.Create(&newTag)
if q.Error != nil {
return Tag{}, errors.Wrap(q.Error, "creating tag")
}
@ -351,7 +353,7 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para
tags := []Tag{}
if param.Tags != nil && len(param.Tags) > 0 {
for _, val := range param.Tags {
t, err := s.getOrCreateTag(val)
t, err := s.getOrCreateTag(s.conn, val)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching tag")
}
@ -365,3 +367,27 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para
return s.sqlToCommonPool(pool)
}
func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) (Pool, error) {
u, err := uuid.Parse(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool Pool
q := tx.Model(&Pool{})
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
q = q.Where("id = ?", u).First(&pool)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Pool{}, runnerErrors.ErrNotFound
}
return Pool{}, errors.Wrap(q.Error, "fetching org from database")
}
return pool, nil
}

View file

@ -193,18 +193,11 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string,
return params.Pool{}, runnerErrors.ErrUnauthorized
}
r.mux.Lock()
defer r.mux.Unlock()
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
_, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching enterprise")
}
if _, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise); err != nil {
return params.Pool{}, runnerErrors.ErrNotFound
}
createPoolParams, err := r.appendTagsToCreatePoolParams(param)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool params")
@ -214,7 +207,12 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string,
param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout
}
pool, err := r.store.CreateEnterprisePool(ctx, enterpriseID, createPoolParams)
entity := params.GithubEntity{
ID: enterpriseID,
EntityType: params.GithubEntityTypeEnterprise,
}
pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams)
if err != nil {
return params.Pool{}, errors.Wrap(err, "creating pool")
}

View file

@ -222,18 +222,11 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C
return params.Pool{}, runnerErrors.ErrUnauthorized
}
r.mux.Lock()
defer r.mux.Unlock()
org, err := r.store.GetOrganizationByID(ctx, orgID)
_, err := r.store.GetOrganizationByID(ctx, orgID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching org")
}
if _, err := r.poolManagerCtrl.GetOrgPoolManager(org); err != nil {
return params.Pool{}, runnerErrors.ErrNotFound
}
createPoolParams, err := r.appendTagsToCreatePoolParams(param)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool params")
@ -243,7 +236,12 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C
param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout
}
pool, err := r.store.CreateOrganizationPool(ctx, orgID, createPoolParams)
entity := params.GithubEntity{
ID: orgID,
EntityType: params.GithubEntityTypeOrganization,
}
pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams)
if err != nil {
return params.Pool{}, errors.Wrap(err, "creating pool")
}

View file

@ -1725,10 +1725,6 @@ func (r *basePoolManager) WebhookSecret() string {
return r.entity.WebhookSecret
}
func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) {
return r.GetGithubRegistrationToken()
}
func (r *basePoolManager) ID() string {
return r.entity.ID
}
@ -2095,7 +2091,7 @@ func (r *basePoolManager) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (par
return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow")
}
func (r *basePoolManager) GetGithubRegistrationToken() (string, error) {
func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) {
tk, ghResp, err := r.ghcli.CreateEntityRegistrationToken(r.ctx)
if err != nil {
if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized {

View file

@ -221,18 +221,11 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params
return params.Pool{}, runnerErrors.ErrUnauthorized
}
r.mux.Lock()
defer r.mux.Unlock()
repo, err := r.store.GetRepositoryByID(ctx, repoID)
_, err := r.store.GetRepositoryByID(ctx, repoID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching repo")
}
if _, err := r.poolManagerCtrl.GetRepoPoolManager(repo); err != nil {
return params.Pool{}, runnerErrors.ErrNotFound
}
createPoolParams, err := r.appendTagsToCreatePoolParams(param)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool params")
@ -242,7 +235,12 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params
createPoolParams.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout
}
pool, err := r.store.CreateRepositoryPool(ctx, repoID, createPoolParams)
entity := params.GithubEntity{
ID: repoID,
EntityType: params.GithubEntityTypeRepository,
}
pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams)
if err != nil {
return params.Pool{}, errors.Wrap(err, "creating pool")
}