diff --git a/cmd/garm-cli/cmd/enterprise.go b/cmd/garm-cli/cmd/enterprise.go index 98457aef..dc9000d8 100644 --- a/cmd/garm-cli/cmd/enterprise.go +++ b/cmd/garm-cli/cmd/enterprise.go @@ -201,10 +201,10 @@ func init() { func formatEnterprises(enterprises []params.Enterprise) { t := table.NewWriter() - header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range enterprises { - t.AppendRow(table.Row{val.ID, val.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Endpoint.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -217,6 +217,7 @@ func formatOneEnterprise(enterprise params.Enterprise) { t.AppendHeader(header) t.AppendRow(table.Row{"ID", enterprise.ID}) t.AppendRow(table.Row{"Name", enterprise.Name}) + t.AppendRow(table.Row{"Endpoint", enterprise.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", enterprise.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", enterprise.Credentials.Name}) t.AppendRow(table.Row{"Pool manager running", enterprise.PoolManagerStatus.IsRunning}) diff --git a/cmd/garm-cli/cmd/organization.go b/cmd/garm-cli/cmd/organization.go index 020452a1..4e5e5360 100644 --- a/cmd/garm-cli/cmd/organization.go +++ b/cmd/garm-cli/cmd/organization.go @@ -341,10 +341,10 @@ func init() { func formatOrganizations(orgs []params.Organization) { t := table.NewWriter() - header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range orgs { - t.AppendRow(table.Row{val.ID, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Endpoint.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -357,6 +357,7 @@ func formatOneOrganization(org params.Organization) { t.AppendHeader(header) t.AppendRow(table.Row{"ID", org.ID}) t.AppendRow(table.Row{"Name", org.Name}) + t.AppendRow(table.Row{"Endpoint", org.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", org.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", org.CredentialsName}) t.AppendRow(table.Row{"Pool manager running", org.PoolManagerStatus.IsRunning}) diff --git a/cmd/garm-cli/cmd/repository.go b/cmd/garm-cli/cmd/repository.go index 845252a5..3fa75560 100644 --- a/cmd/garm-cli/cmd/repository.go +++ b/cmd/garm-cli/cmd/repository.go @@ -347,10 +347,10 @@ func init() { func formatRepositories(repos []params.Repository) { t := table.NewWriter() - header := table.Row{"ID", "Owner", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Owner", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range repos { - t.AppendRow(table.Row{val.ID, val.Owner, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Owner, val.Name, val.Endpoint.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -364,6 +364,7 @@ func formatOneRepository(repo params.Repository) { t.AppendRow(table.Row{"ID", repo.ID}) t.AppendRow(table.Row{"Owner", repo.Owner}) t.AppendRow(table.Row{"Name", repo.Name}) + t.AppendRow(table.Row{"Endpoint", repo.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", repo.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", repo.CredentialsName}) t.AppendRow(table.Row{"Pool manager running", repo.PoolManagerStatus.IsRunning}) diff --git a/database/common/store.go b/database/common/store.go index 18075c1d..4d91e6cd 100644 --- a/database/common/store.go +++ b/database/common/store.go @@ -39,7 +39,7 @@ type GithubCredentialsStore interface { type RepoStore interface { CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) - GetRepository(ctx context.Context, owner, name string) (params.Repository, error) + GetRepository(ctx context.Context, owner, name, endpointName string) (params.Repository, error) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) ListRepositories(ctx context.Context) ([]params.Repository, error) DeleteRepository(ctx context.Context, repoID string) error @@ -48,7 +48,7 @@ type RepoStore interface { type OrgStore interface { CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) - GetOrganization(ctx context.Context, name string) (params.Organization, error) + GetOrganization(ctx context.Context, name, endpointName string) (params.Organization, error) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) ListOrganizations(ctx context.Context) ([]params.Organization, error) DeleteOrganization(ctx context.Context, orgID string) error @@ -57,7 +57,7 @@ type OrgStore interface { type EnterpriseStore interface { CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) - GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) + GetEnterprise(ctx context.Context, name, endpointName string) (params.Enterprise, error) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) DeleteEnterprise(ctx context.Context, enterpriseID string) error diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 30b42137..dfcb10a2 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -82,8 +82,8 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam return paramEnt, nil } -func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { - enterprise, err := s.getEnterprise(ctx, name) +func (s *sqlDatabase) GetEnterprise(ctx context.Context, name, endpointName string) (params.Enterprise, error) { + enterprise, err := s.getEnterprise(ctx, name, endpointName) if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -223,10 +223,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, return newParams, nil } -func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { +func (s *sqlDatabase) getEnterprise(_ context.Context, name, endpointName string) (Enterprise, error) { var enterprise Enterprise - q := s.conn.Where("name = ? COLLATE NOCASE", name). + q := s.conn.Where("name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index 9e1a86dd..405cfa85 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -245,7 +245,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() { } func (s *EnterpriseTestSuite) TestGetEnterprise() { - enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) + enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Enterprises[0].Name, enterprise.Name) @@ -253,14 +253,14 @@ func (s *EnterpriseTestSuite) TestGetEnterprise() { } func (s *EnterpriseTestSuite) TestGetEnterpriseCaseInsensitive() { - enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1") + enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-enterprise-1", enterprise.Name) } func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { - _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name") + _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching enterprise: not found", err.Error()) @@ -268,15 +268,15 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE name = ? COLLATE NOCASE AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Enterprises[0].Name, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE (name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Enterprises[0].Name)) - _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) + _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching enterprise: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterprises() { diff --git a/database/sql/models.go b/database/sql/models.go index 7c62ea97..ac7a056a 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -119,7 +119,7 @@ type Organization struct { Jobs []WorkflowJob `gorm:"foreignKey:OrgID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` - EndpointName *string `gorm:"index"` + EndpointName *string `gorm:"index:idx_org_name_nocase,collate:nocase"` Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } @@ -137,7 +137,7 @@ type Enterprise struct { Jobs []WorkflowJob `gorm:"foreignKey:EnterpriseID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` - EndpointName *string `gorm:"index"` + EndpointName *string `gorm:"index:idx_ent_name_nocase,collate:nocase"` Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 02ae5e62..c41b9269 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -85,8 +85,8 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN return org, nil } -func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) { - org, err := s.getOrg(ctx, name) +func (s *sqlDatabase) GetOrganization(ctx context.Context, name, endpointName string) (params.Organization, error) { + org, err := s.getOrg(ctx, name, endpointName) if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -252,10 +252,10 @@ func (s *sqlDatabase) getOrgByID(_ context.Context, db *gorm.DB, id string, prel return org, nil } -func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, error) { +func (s *sqlDatabase) getOrg(_ context.Context, name, endpointName string) (Organization, error) { var org Organization - q := s.conn.Where("name = ? COLLATE NOCASE", name). + q := s.conn.Where("name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 11be72d0..f7aa5a84 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -247,7 +247,7 @@ func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() { } func (s *OrgTestSuite) TestGetOrganization() { - org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) + org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Orgs[0].Name, org.Name) @@ -255,14 +255,14 @@ func (s *OrgTestSuite) TestGetOrganization() { } func (s *OrgTestSuite) TestGetOrganizationCaseInsensitive() { - org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1") + org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-org-1", org.Name) } func (s *OrgTestSuite) TestGetOrganizationNotFound() { - _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name") + _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching org: not found", err.Error()) @@ -270,15 +270,15 @@ func (s *OrgTestSuite) TestGetOrganizationNotFound() { func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE name = ? COLLATE NOCASE AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Orgs[0].Name, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE (name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Orgs[0].Name)) - _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) + _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching org: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrganizations() { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index a08e815b..c1eaef3b 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -84,8 +84,8 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent return param, nil } -func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) { - repo, err := s.getRepo(ctx, owner, name) +func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name, endpointName string) (params.Repository, error) { + repo, err := s.getRepo(ctx, owner, name, endpointName) if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -228,10 +228,10 @@ func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (par return param, nil } -func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository, error) { +func (s *sqlDatabase) getRepo(_ context.Context, owner, name, endpointName string) (Repository, error) { var repo Repository - q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE", name, owner). + q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, owner, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 313396f8..826623db 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -271,7 +271,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() { } func (s *RepoTestSuite) TestGetRepository() { - repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Repos[0].Owner, repo.Owner) @@ -280,7 +280,7 @@ func (s *RepoTestSuite) TestGetRepository() { } func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { - repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1") + repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-owner-1", repo.Owner) @@ -288,7 +288,7 @@ func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { } func (s *RepoTestSuite) TestGetRepositoryNotFound() { - _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name") + _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching repo: not found", err.Error()) @@ -296,15 +296,15 @@ func (s *RepoTestSuite) TestGetRepositoryNotFound() { func (s *RepoTestSuite) TestGetRepositoryDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id`,`repositories`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id`,`repositories`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) - _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Endpoint.Name) s.Require().NotNil(err) s.Require().Equal("fetching repo: missing secret", err.Error()) diff --git a/runner/enterprises.go b/runner/enterprises.go index 6fb86f96..3e9e3b8c 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -30,7 +30,7 @@ func (r *Runner) CreateEnterprise(ctx context.Context, param params.CreateEnterp return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetEnterprise(ctx, param.Name) + _, err = r.store.GetEnterprise(ctx, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") @@ -322,11 +322,11 @@ func (r *Runner) ListEnterpriseInstances(ctx context.Context, enterpriseID strin return instances, nil } -func (r *Runner) findEnterprisePoolManager(name string) (common.PoolManager, error) { +func (r *Runner) findEnterprisePoolManager(name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - enterprise, err := r.store.GetEnterprise(r.ctx, name) + enterprise, err := r.store.GetEnterprise(r.ctx, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching enterprise") } diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index f912c8ef..94bc4807 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -544,7 +544,7 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesErrUnauthorized() { func (s *EnterpriseTestSuite) TestFindEnterprisePoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + poolManager, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name, s.Fixtures.StoreEnterprises["test-enterprise-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -555,7 +555,7 @@ func (s *EnterpriseTestSuite) TestFindEnterprisePoolManager() { func (s *EnterpriseTestSuite) TestFindEnterprisePoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + _, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name, s.Fixtures.StoreEnterprises["test-enterprise-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/organizations.go b/runner/organizations.go index ac55de54..39aa788b 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -43,7 +43,7 @@ func (r *Runner) CreateOrganization(ctx context.Context, param params.CreateOrgP return params.Organization{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetOrganization(ctx, param.Name) + _, err = r.store.GetOrganization(ctx, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Organization{}, errors.Wrap(err, "fetching org") @@ -359,11 +359,11 @@ func (r *Runner) ListOrgInstances(ctx context.Context, orgID string) ([]params.I return instances, nil } -func (r *Runner) findOrgPoolManager(name string) (common.PoolManager, error) { +func (r *Runner) findOrgPoolManager(name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - org, err := r.store.GetOrganization(r.ctx, name) + org, err := r.store.GetOrganization(r.ctx, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching org") } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index f7513234..ae0af3cf 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -569,7 +569,7 @@ func (s *OrgTestSuite) TestListOrgInstancesErrUnauthorized() { func (s *OrgTestSuite) TestFindOrgPoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name) + poolManager, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name, s.Fixtures.StoreOrgs["test-org-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -580,7 +580,7 @@ func (s *OrgTestSuite) TestFindOrgPoolManager() { func (s *OrgTestSuite) TestFindOrgPoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name) + _, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name, s.Fixtures.StoreOrgs["test-org-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/pool/pool.go b/runner/pool/pool.go index d3dbd96b..09383e34 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -78,7 +78,8 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, insta return nil, errors.Wrap(err, "getting controller info") } - consumerID := fmt.Sprintf("pool-manager-%s", entity.String()) + consumerID := fmt.Sprintf("pool-manager-%s-%s", entity.String(), entity.Credentials.Endpoint.Name) + slog.InfoContext(ctx, "registering consumer", "consumer_id", consumerID) consumer, err := watcher.RegisterConsumer( ctx, consumerID, composeWatcherFilters(entity), diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go index b50a85b2..b17494d5 100644 --- a/runner/pool/watcher.go +++ b/runner/pool/watcher.go @@ -38,8 +38,24 @@ func (r *basePoolManager) getClientOrStub() runnerCommon.GithubClient { return ghc } -func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity) { - slog.DebugContext(r.ctx, "received entity update", "entity", entity.ID) +func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity, operation common.OperationType) { + slog.DebugContext(r.ctx, "received entity operation", "entity", entity.ID, "operation", operation) + if r.entity.ID != entity.ID { + slog.WarnContext(r.ctx, "entity ID mismatch; stale event? refusing to update", "entity", entity.ID) + return + } + + if operation == common.DeleteOperation { + slog.InfoContext(r.ctx, "entity deleted; closing db consumer", "entity", entity.ID) + r.consumer.Close() + return + } + + if operation != common.UpdateOperation { + slog.DebugContext(r.ctx, "operation not update; ignoring", "entity", entity.ID, "operation", operation) + return + } + credentialsUpdate := r.entity.Credentials.ID != entity.Credentials.ID defer func() { slog.DebugContext(r.ctx, "deferred tools update", "credentials_update", credentialsUpdate) @@ -133,11 +149,12 @@ func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) { slog.ErrorContext(r.ctx, "failed to get entity", "error", err) return } - r.handleEntityUpdate(entityInfo) + r.handleEntityUpdate(entityInfo, event.Operation) } } func (r *basePoolManager) runWatcher() { + defer r.consumer.Close() for { select { case <-r.quit: diff --git a/runner/repositories.go b/runner/repositories.go index ce0bbc73..4a76e570 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -43,7 +43,7 @@ func (r *Runner) CreateRepository(ctx context.Context, param params.CreateRepoPa return params.Repository{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetRepository(ctx, param.Owner, param.Name) + _, err = r.store.GetRepository(ctx, param.Owner, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Repository{}, errors.Wrap(err, "fetching repo") @@ -364,11 +364,11 @@ func (r *Runner) ListRepoInstances(ctx context.Context, repoID string) ([]params return instances, nil } -func (r *Runner) findRepoPoolManager(owner, name string) (common.PoolManager, error) { +func (r *Runner) findRepoPoolManager(owner, name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - repo, err := r.store.GetRepository(r.ctx, owner, name) + repo, err := r.store.GetRepository(r.ctx, owner, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching repo") } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index a13b6112..9e55cbda 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -617,7 +617,7 @@ func (s *RepoTestSuite) TestListRepoInstancesErrUnauthorized() { func (s *RepoTestSuite) TestFindRepoPoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name) + poolManager, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name, s.Fixtures.StoreRepos["test-repo-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -628,7 +628,7 @@ func (s *RepoTestSuite) TestFindRepoPoolManager() { func (s *RepoTestSuite) TestFindRepoPoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name) + _, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name, s.Fixtures.StoreRepos["test-repo-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/runner.go b/runner/runner.go index 532412dd..5c0883aa 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -24,6 +24,7 @@ import ( "fmt" "hash" "log/slog" + "net/url" "os" "strings" "sync" @@ -599,6 +600,31 @@ func (r *Runner) validateHookBody(signature, secret string, body []byte) error { return nil } +func (r *Runner) findEndpointForJob(job params.WorkflowJob) (params.GithubEndpoint, error) { + uri, err := url.ParseRequestURI(job.WorkflowJob.HTMLURL) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "parsing job URL") + } + baseURI := fmt.Sprintf("%s://%s", uri.Scheme, uri.Host) + + // Note(gabriel-samfira): Endpoints should be cached. We don't expect to have a large number + // of endpoints. In most cases there will be just one (github.com). In cases where there is + // a GHES involved, those users will have just one extra endpoint or 2 (if they also have a + // test env). But there should be a relatively small number, regardless. So we don't really care + // that much about the performance of this function. + endpoints, err := r.store.ListGithubEndpoints(r.ctx) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "fetching github endpoints") + } + for _, ep := range endpoints { + if ep.BaseURL == baseURI { + return ep, nil + } + } + + return params.GithubEndpoint{}, runnerErrors.NewNotFoundError("no endpoint found for job") +} + func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData []byte) error { if len(jobData) == 0 { return runnerErrors.NewBadRequestError("missing job data") @@ -609,8 +635,12 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ return errors.Wrapf(runnerErrors.ErrBadRequest, "invalid job data: %s", err) } + endpoint, err := r.findEndpointForJob(job) + if err != nil { + return errors.Wrap(err, "finding endpoint for job") + } + var poolManager common.PoolManager - var err error switch HookTargetType(hookTargetType) { case RepoHook: @@ -618,17 +648,17 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ r.ctx, "got hook for repo", "repo_owner", util.SanitizeLogEntry(job.Repository.Owner.Login), "repo_name", util.SanitizeLogEntry(job.Repository.Name)) - poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name) + poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name, endpoint.Name) case OrganizationHook: slog.DebugContext( r.ctx, "got hook for organization", "organization", util.SanitizeLogEntry(job.Organization.Login)) - poolManager, err = r.findOrgPoolManager(job.Organization.Login) + poolManager, err = r.findOrgPoolManager(job.Organization.Login, endpoint.Name) case EnterpriseHook: slog.DebugContext( r.ctx, "got hook for enterprise", "enterprise", util.SanitizeLogEntry(job.Enterprise.Slug)) - poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug) + poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug, endpoint.Name) default: return runnerErrors.NewBadRequestError("cannot handle hook target type %s", hookTargetType) } @@ -766,7 +796,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching repo") } - poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name) + poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name, repo.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for repo %s", pool.RepoName) } @@ -775,7 +805,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching org") } - poolMgr, err = r.findOrgPoolManager(org.Name) + poolMgr, err = r.findOrgPoolManager(org.Name, org.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for org %s", pool.OrgName) } @@ -784,7 +814,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching enterprise") } - poolMgr, err = r.findEnterprisePoolManager(enterprise.Name) + poolMgr, err = r.findEnterprisePoolManager(enterprise.Name, enterprise.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for enterprise %s", pool.EnterpriseName) }