From cc6e98562935cda2eeff43b1ab70f04d17f32577 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 29 Jul 2024 17:35:57 +0000 Subject: [PATCH] Fix: Scope entities to endpoint This change scopes all github entities to a github endpoint, allowing users to have the same repo/org/enterprise created for each endpoint. This way, if your username is the same on github.com and on your GHES server, and you have the same repository name or org in both places, GARM can now handle that situation. This change also fixes a leaky watcher in the pool manager. Signed-off-by: Gabriel Adrian Samfira --- cmd/garm-cli/cmd/enterprise.go | 5 ++-- cmd/garm-cli/cmd/organization.go | 5 ++-- cmd/garm-cli/cmd/repository.go | 5 ++-- database/common/store.go | 6 ++-- database/sql/enterprise.go | 8 +++--- database/sql/enterprise_test.go | 14 +++++----- database/sql/models.go | 4 +-- database/sql/organizations.go | 8 +++--- database/sql/organizations_test.go | 14 +++++----- database/sql/repositories.go | 8 +++--- database/sql/repositories_test.go | 16 +++++------ runner/enterprises.go | 6 ++-- runner/enterprises_test.go | 4 +-- runner/organizations.go | 6 ++-- runner/organizations_test.go | 4 +-- runner/pool/pool.go | 3 +- runner/pool/watcher.go | 23 ++++++++++++++-- runner/repositories.go | 6 ++-- runner/repositories_test.go | 4 +-- runner/runner.go | 44 +++++++++++++++++++++++++----- 20 files changed, 122 insertions(+), 71 deletions(-) diff --git a/cmd/garm-cli/cmd/enterprise.go b/cmd/garm-cli/cmd/enterprise.go index 98457aef..dc9000d8 100644 --- a/cmd/garm-cli/cmd/enterprise.go +++ b/cmd/garm-cli/cmd/enterprise.go @@ -201,10 +201,10 @@ func init() { func formatEnterprises(enterprises []params.Enterprise) { t := table.NewWriter() - header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range enterprises { - t.AppendRow(table.Row{val.ID, val.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Endpoint.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -217,6 +217,7 @@ func formatOneEnterprise(enterprise params.Enterprise) { t.AppendHeader(header) t.AppendRow(table.Row{"ID", enterprise.ID}) t.AppendRow(table.Row{"Name", enterprise.Name}) + t.AppendRow(table.Row{"Endpoint", enterprise.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", enterprise.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", enterprise.Credentials.Name}) t.AppendRow(table.Row{"Pool manager running", enterprise.PoolManagerStatus.IsRunning}) diff --git a/cmd/garm-cli/cmd/organization.go b/cmd/garm-cli/cmd/organization.go index 020452a1..4e5e5360 100644 --- a/cmd/garm-cli/cmd/organization.go +++ b/cmd/garm-cli/cmd/organization.go @@ -341,10 +341,10 @@ func init() { func formatOrganizations(orgs []params.Organization) { t := table.NewWriter() - header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range orgs { - t.AppendRow(table.Row{val.ID, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Endpoint.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -357,6 +357,7 @@ func formatOneOrganization(org params.Organization) { t.AppendHeader(header) t.AppendRow(table.Row{"ID", org.ID}) t.AppendRow(table.Row{"Name", org.Name}) + t.AppendRow(table.Row{"Endpoint", org.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", org.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", org.CredentialsName}) t.AppendRow(table.Row{"Pool manager running", org.PoolManagerStatus.IsRunning}) diff --git a/cmd/garm-cli/cmd/repository.go b/cmd/garm-cli/cmd/repository.go index 845252a5..3fa75560 100644 --- a/cmd/garm-cli/cmd/repository.go +++ b/cmd/garm-cli/cmd/repository.go @@ -347,10 +347,10 @@ func init() { func formatRepositories(repos []params.Repository) { t := table.NewWriter() - header := table.Row{"ID", "Owner", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} + header := table.Row{"ID", "Owner", "Name", "Endpoint", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range repos { - t.AppendRow(table.Row{val.ID, val.Owner, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Owner, val.Name, val.Endpoint.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -364,6 +364,7 @@ func formatOneRepository(repo params.Repository) { t.AppendRow(table.Row{"ID", repo.ID}) t.AppendRow(table.Row{"Owner", repo.Owner}) t.AppendRow(table.Row{"Name", repo.Name}) + t.AppendRow(table.Row{"Endpoint", repo.Endpoint.Name}) t.AppendRow(table.Row{"Pool balancer type", repo.GetBalancerType()}) t.AppendRow(table.Row{"Credentials", repo.CredentialsName}) t.AppendRow(table.Row{"Pool manager running", repo.PoolManagerStatus.IsRunning}) diff --git a/database/common/store.go b/database/common/store.go index 18075c1d..4d91e6cd 100644 --- a/database/common/store.go +++ b/database/common/store.go @@ -39,7 +39,7 @@ type GithubCredentialsStore interface { type RepoStore interface { CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) - GetRepository(ctx context.Context, owner, name string) (params.Repository, error) + GetRepository(ctx context.Context, owner, name, endpointName string) (params.Repository, error) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) ListRepositories(ctx context.Context) ([]params.Repository, error) DeleteRepository(ctx context.Context, repoID string) error @@ -48,7 +48,7 @@ type RepoStore interface { type OrgStore interface { CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) - GetOrganization(ctx context.Context, name string) (params.Organization, error) + GetOrganization(ctx context.Context, name, endpointName string) (params.Organization, error) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) ListOrganizations(ctx context.Context) ([]params.Organization, error) DeleteOrganization(ctx context.Context, orgID string) error @@ -57,7 +57,7 @@ type OrgStore interface { type EnterpriseStore interface { CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) - GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) + GetEnterprise(ctx context.Context, name, endpointName string) (params.Enterprise, error) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) DeleteEnterprise(ctx context.Context, enterpriseID string) error diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 30b42137..dfcb10a2 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -82,8 +82,8 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam return paramEnt, nil } -func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { - enterprise, err := s.getEnterprise(ctx, name) +func (s *sqlDatabase) GetEnterprise(ctx context.Context, name, endpointName string) (params.Enterprise, error) { + enterprise, err := s.getEnterprise(ctx, name, endpointName) if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -223,10 +223,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, return newParams, nil } -func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { +func (s *sqlDatabase) getEnterprise(_ context.Context, name, endpointName string) (Enterprise, error) { var enterprise Enterprise - q := s.conn.Where("name = ? COLLATE NOCASE", name). + q := s.conn.Where("name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index 9e1a86dd..405cfa85 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -245,7 +245,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() { } func (s *EnterpriseTestSuite) TestGetEnterprise() { - enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) + enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Enterprises[0].Name, enterprise.Name) @@ -253,14 +253,14 @@ func (s *EnterpriseTestSuite) TestGetEnterprise() { } func (s *EnterpriseTestSuite) TestGetEnterpriseCaseInsensitive() { - enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1") + enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-enterprise-1", enterprise.Name) } func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { - _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name") + _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching enterprise: not found", err.Error()) @@ -268,15 +268,15 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE name = ? COLLATE NOCASE AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Enterprises[0].Name, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE (name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Enterprises[0].Name)) - _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) + _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name, s.Fixtures.Enterprises[0].Endpoint.Name) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching enterprise: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterprises() { diff --git a/database/sql/models.go b/database/sql/models.go index 7c62ea97..ac7a056a 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -119,7 +119,7 @@ type Organization struct { Jobs []WorkflowJob `gorm:"foreignKey:OrgID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` - EndpointName *string `gorm:"index"` + EndpointName *string `gorm:"index:idx_org_name_nocase,collate:nocase"` Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } @@ -137,7 +137,7 @@ type Enterprise struct { Jobs []WorkflowJob `gorm:"foreignKey:EnterpriseID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` - EndpointName *string `gorm:"index"` + EndpointName *string `gorm:"index:idx_ent_name_nocase,collate:nocase"` Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 02ae5e62..c41b9269 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -85,8 +85,8 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN return org, nil } -func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) { - org, err := s.getOrg(ctx, name) +func (s *sqlDatabase) GetOrganization(ctx context.Context, name, endpointName string) (params.Organization, error) { + org, err := s.getOrg(ctx, name, endpointName) if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -252,10 +252,10 @@ func (s *sqlDatabase) getOrgByID(_ context.Context, db *gorm.DB, id string, prel return org, nil } -func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, error) { +func (s *sqlDatabase) getOrg(_ context.Context, name, endpointName string) (Organization, error) { var org Organization - q := s.conn.Where("name = ? COLLATE NOCASE", name). + q := s.conn.Where("name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 11be72d0..f7aa5a84 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -247,7 +247,7 @@ func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() { } func (s *OrgTestSuite) TestGetOrganization() { - org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) + org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Orgs[0].Name, org.Name) @@ -255,14 +255,14 @@ func (s *OrgTestSuite) TestGetOrganization() { } func (s *OrgTestSuite) TestGetOrganizationCaseInsensitive() { - org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1") + org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-org-1", org.Name) } func (s *OrgTestSuite) TestGetOrganizationNotFound() { - _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name") + _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching org: not found", err.Error()) @@ -270,15 +270,15 @@ func (s *OrgTestSuite) TestGetOrganizationNotFound() { func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE name = ? COLLATE NOCASE AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Orgs[0].Name, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE (name = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Orgs[0].Name)) - _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) + _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name, s.Fixtures.Orgs[0].Endpoint.Name) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching org: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrganizations() { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index a08e815b..c1eaef3b 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -84,8 +84,8 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent return param, nil } -func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) { - repo, err := s.getRepo(ctx, owner, name) +func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name, endpointName string) (params.Repository, error) { + repo, err := s.getRepo(ctx, owner, name, endpointName) if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -228,10 +228,10 @@ func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (par return param, nil } -func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository, error) { +func (s *sqlDatabase) getRepo(_ context.Context, owner, name, endpointName string) (Repository, error) { var repo Repository - q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE", name, owner). + q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE", name, owner, endpointName). Preload("Credentials"). Preload("Credentials.Endpoint"). Preload("Endpoint"). diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 313396f8..826623db 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -271,7 +271,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() { } func (s *RepoTestSuite) TestGetRepository() { - repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Endpoint.Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Repos[0].Owner, repo.Owner) @@ -280,7 +280,7 @@ func (s *RepoTestSuite) TestGetRepository() { } func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { - repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1") + repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1", "github.com") s.Require().Nil(err) s.Require().Equal("test-owner-1", repo.Owner) @@ -288,7 +288,7 @@ func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { } func (s *RepoTestSuite) TestGetRepositoryNotFound() { - _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name") + _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name", "github.com") s.Require().NotNil(err) s.Require().Equal("fetching repo: not found", err.Error()) @@ -296,15 +296,15 @@ func (s *RepoTestSuite) TestGetRepositoryNotFound() { func (s *RepoTestSuite) TestGetRepositoryDBDecryptingErr() { s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id`,`repositories`.`id` LIMIT ?")). - WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE (name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE and endpoint_name = ? COLLATE NOCASE) AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id`,`repositories`.`id` LIMIT ?")). + WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Endpoint.Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) - _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Endpoint.Name) s.Require().NotNil(err) s.Require().Equal("fetching repo: missing secret", err.Error()) diff --git a/runner/enterprises.go b/runner/enterprises.go index 6fb86f96..3e9e3b8c 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -30,7 +30,7 @@ func (r *Runner) CreateEnterprise(ctx context.Context, param params.CreateEnterp return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetEnterprise(ctx, param.Name) + _, err = r.store.GetEnterprise(ctx, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") @@ -322,11 +322,11 @@ func (r *Runner) ListEnterpriseInstances(ctx context.Context, enterpriseID strin return instances, nil } -func (r *Runner) findEnterprisePoolManager(name string) (common.PoolManager, error) { +func (r *Runner) findEnterprisePoolManager(name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - enterprise, err := r.store.GetEnterprise(r.ctx, name) + enterprise, err := r.store.GetEnterprise(r.ctx, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching enterprise") } diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index f912c8ef..94bc4807 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -544,7 +544,7 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesErrUnauthorized() { func (s *EnterpriseTestSuite) TestFindEnterprisePoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + poolManager, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name, s.Fixtures.StoreEnterprises["test-enterprise-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -555,7 +555,7 @@ func (s *EnterpriseTestSuite) TestFindEnterprisePoolManager() { func (s *EnterpriseTestSuite) TestFindEnterprisePoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + _, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name, s.Fixtures.StoreEnterprises["test-enterprise-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/organizations.go b/runner/organizations.go index ac55de54..39aa788b 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -43,7 +43,7 @@ func (r *Runner) CreateOrganization(ctx context.Context, param params.CreateOrgP return params.Organization{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetOrganization(ctx, param.Name) + _, err = r.store.GetOrganization(ctx, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Organization{}, errors.Wrap(err, "fetching org") @@ -359,11 +359,11 @@ func (r *Runner) ListOrgInstances(ctx context.Context, orgID string) ([]params.I return instances, nil } -func (r *Runner) findOrgPoolManager(name string) (common.PoolManager, error) { +func (r *Runner) findOrgPoolManager(name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - org, err := r.store.GetOrganization(r.ctx, name) + org, err := r.store.GetOrganization(r.ctx, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching org") } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index f7513234..ae0af3cf 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -569,7 +569,7 @@ func (s *OrgTestSuite) TestListOrgInstancesErrUnauthorized() { func (s *OrgTestSuite) TestFindOrgPoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name) + poolManager, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name, s.Fixtures.StoreOrgs["test-org-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -580,7 +580,7 @@ func (s *OrgTestSuite) TestFindOrgPoolManager() { func (s *OrgTestSuite) TestFindOrgPoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name) + _, err := s.Runner.findOrgPoolManager(s.Fixtures.StoreOrgs["test-org-1"].Name, s.Fixtures.StoreOrgs["test-org-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/pool/pool.go b/runner/pool/pool.go index d3dbd96b..09383e34 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -78,7 +78,8 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, insta return nil, errors.Wrap(err, "getting controller info") } - consumerID := fmt.Sprintf("pool-manager-%s", entity.String()) + consumerID := fmt.Sprintf("pool-manager-%s-%s", entity.String(), entity.Credentials.Endpoint.Name) + slog.InfoContext(ctx, "registering consumer", "consumer_id", consumerID) consumer, err := watcher.RegisterConsumer( ctx, consumerID, composeWatcherFilters(entity), diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go index b50a85b2..b17494d5 100644 --- a/runner/pool/watcher.go +++ b/runner/pool/watcher.go @@ -38,8 +38,24 @@ func (r *basePoolManager) getClientOrStub() runnerCommon.GithubClient { return ghc } -func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity) { - slog.DebugContext(r.ctx, "received entity update", "entity", entity.ID) +func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity, operation common.OperationType) { + slog.DebugContext(r.ctx, "received entity operation", "entity", entity.ID, "operation", operation) + if r.entity.ID != entity.ID { + slog.WarnContext(r.ctx, "entity ID mismatch; stale event? refusing to update", "entity", entity.ID) + return + } + + if operation == common.DeleteOperation { + slog.InfoContext(r.ctx, "entity deleted; closing db consumer", "entity", entity.ID) + r.consumer.Close() + return + } + + if operation != common.UpdateOperation { + slog.DebugContext(r.ctx, "operation not update; ignoring", "entity", entity.ID, "operation", operation) + return + } + credentialsUpdate := r.entity.Credentials.ID != entity.Credentials.ID defer func() { slog.DebugContext(r.ctx, "deferred tools update", "credentials_update", credentialsUpdate) @@ -133,11 +149,12 @@ func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) { slog.ErrorContext(r.ctx, "failed to get entity", "error", err) return } - r.handleEntityUpdate(entityInfo) + r.handleEntityUpdate(entityInfo, event.Operation) } } func (r *basePoolManager) runWatcher() { + defer r.consumer.Close() for { select { case <-r.quit: diff --git a/runner/repositories.go b/runner/repositories.go index ce0bbc73..4a76e570 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -43,7 +43,7 @@ func (r *Runner) CreateRepository(ctx context.Context, param params.CreateRepoPa return params.Repository{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } - _, err = r.store.GetRepository(ctx, param.Owner, param.Name) + _, err = r.store.GetRepository(ctx, param.Owner, param.Name, creds.Endpoint.Name) if err != nil { if !errors.Is(err, runnerErrors.ErrNotFound) { return params.Repository{}, errors.Wrap(err, "fetching repo") @@ -364,11 +364,11 @@ func (r *Runner) ListRepoInstances(ctx context.Context, repoID string) ([]params return instances, nil } -func (r *Runner) findRepoPoolManager(owner, name string) (common.PoolManager, error) { +func (r *Runner) findRepoPoolManager(owner, name, endpointName string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - repo, err := r.store.GetRepository(r.ctx, owner, name) + repo, err := r.store.GetRepository(r.ctx, owner, name, endpointName) if err != nil { return nil, errors.Wrap(err, "fetching repo") } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index a13b6112..9e55cbda 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -617,7 +617,7 @@ func (s *RepoTestSuite) TestListRepoInstancesErrUnauthorized() { func (s *RepoTestSuite) TestFindRepoPoolManager() { s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - poolManager, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name) + poolManager, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name, s.Fixtures.StoreRepos["test-repo-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) @@ -628,7 +628,7 @@ func (s *RepoTestSuite) TestFindRepoPoolManager() { func (s *RepoTestSuite) TestFindRepoPoolManagerFetchPoolMgrFailed() { s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - _, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name) + _, err := s.Runner.findRepoPoolManager(s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name, s.Fixtures.StoreRepos["test-repo-1"].Endpoint.Name) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) diff --git a/runner/runner.go b/runner/runner.go index 532412dd..5c0883aa 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -24,6 +24,7 @@ import ( "fmt" "hash" "log/slog" + "net/url" "os" "strings" "sync" @@ -599,6 +600,31 @@ func (r *Runner) validateHookBody(signature, secret string, body []byte) error { return nil } +func (r *Runner) findEndpointForJob(job params.WorkflowJob) (params.GithubEndpoint, error) { + uri, err := url.ParseRequestURI(job.WorkflowJob.HTMLURL) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "parsing job URL") + } + baseURI := fmt.Sprintf("%s://%s", uri.Scheme, uri.Host) + + // Note(gabriel-samfira): Endpoints should be cached. We don't expect to have a large number + // of endpoints. In most cases there will be just one (github.com). In cases where there is + // a GHES involved, those users will have just one extra endpoint or 2 (if they also have a + // test env). But there should be a relatively small number, regardless. So we don't really care + // that much about the performance of this function. + endpoints, err := r.store.ListGithubEndpoints(r.ctx) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "fetching github endpoints") + } + for _, ep := range endpoints { + if ep.BaseURL == baseURI { + return ep, nil + } + } + + return params.GithubEndpoint{}, runnerErrors.NewNotFoundError("no endpoint found for job") +} + func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData []byte) error { if len(jobData) == 0 { return runnerErrors.NewBadRequestError("missing job data") @@ -609,8 +635,12 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ return errors.Wrapf(runnerErrors.ErrBadRequest, "invalid job data: %s", err) } + endpoint, err := r.findEndpointForJob(job) + if err != nil { + return errors.Wrap(err, "finding endpoint for job") + } + var poolManager common.PoolManager - var err error switch HookTargetType(hookTargetType) { case RepoHook: @@ -618,17 +648,17 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ r.ctx, "got hook for repo", "repo_owner", util.SanitizeLogEntry(job.Repository.Owner.Login), "repo_name", util.SanitizeLogEntry(job.Repository.Name)) - poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name) + poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name, endpoint.Name) case OrganizationHook: slog.DebugContext( r.ctx, "got hook for organization", "organization", util.SanitizeLogEntry(job.Organization.Login)) - poolManager, err = r.findOrgPoolManager(job.Organization.Login) + poolManager, err = r.findOrgPoolManager(job.Organization.Login, endpoint.Name) case EnterpriseHook: slog.DebugContext( r.ctx, "got hook for enterprise", "enterprise", util.SanitizeLogEntry(job.Enterprise.Slug)) - poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug) + poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug, endpoint.Name) default: return runnerErrors.NewBadRequestError("cannot handle hook target type %s", hookTargetType) } @@ -766,7 +796,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching repo") } - poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name) + poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name, repo.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for repo %s", pool.RepoName) } @@ -775,7 +805,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching org") } - poolMgr, err = r.findOrgPoolManager(org.Name) + poolMgr, err = r.findOrgPoolManager(org.Name, org.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for org %s", pool.OrgName) } @@ -784,7 +814,7 @@ func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params if err != nil { return nil, errors.Wrap(err, "fetching enterprise") } - poolMgr, err = r.findEnterprisePoolManager(enterprise.Name) + poolMgr, err = r.findEnterprisePoolManager(enterprise.Name, enterprise.Endpoint.Name) if err != nil { return nil, errors.Wrapf(err, "fetching pool manager for enterprise %s", pool.EnterpriseName) }