From 90870c11bebcc34c8bc8900e2b2e7b3fd30a080d Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 15 Apr 2024 08:32:19 +0000 Subject: [PATCH] Use database for github creds Add database models that deal with github credentials. This change adds models for github endpoints (github.com, GHES, etc). This change also adds code to migrate config credntials to the DB. Tests need to be fixed and new tests need to be written. This will come in a later commit. Signed-off-by: Gabriel Adrian Samfira --- cmd/garm-cli/cmd/enterprise.go | 4 +- cmd/garm/main.go | 3 + config/config.go | 13 + database/common/common.go | 20 ++ database/common/mocks/Store.go | 320 +++++++++++++++++++ database/sql/enterprise.go | 12 +- database/sql/enterprise_test.go | 4 +- database/sql/github.go | 473 +++++++++++++++++++++++++++++ database/sql/models.go | 59 +++- database/sql/organizations.go | 12 +- database/sql/organizations_test.go | 4 +- database/sql/repositories.go | 42 ++- database/sql/repositories_test.go | 4 +- database/sql/sql.go | 168 +++++++++- database/sql/users.go | 17 ++ database/sql/util.go | 22 +- params/params.go | 91 +++++- params/requests.go | 67 ++++ runner/enterprises_test.go | 4 +- runner/organizations_test.go | 4 +- runner/repositories_test.go | 10 +- runner/runner.go | 36 +-- 22 files changed, 1312 insertions(+), 77 deletions(-) create mode 100644 database/sql/github.go diff --git a/cmd/garm-cli/cmd/enterprise.go b/cmd/garm-cli/cmd/enterprise.go index 27ca662c..98457aef 100644 --- a/cmd/garm-cli/cmd/enterprise.go +++ b/cmd/garm-cli/cmd/enterprise.go @@ -204,7 +204,7 @@ func formatEnterprises(enterprises []params.Enterprise) { header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range enterprises { - t.AppendRow(table.Row{val.ID, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -218,7 +218,7 @@ func formatOneEnterprise(enterprise params.Enterprise) { t.AppendRow(table.Row{"ID", enterprise.ID}) t.AppendRow(table.Row{"Name", enterprise.Name}) t.AppendRow(table.Row{"Pool balancer type", enterprise.GetBalancerType()}) - t.AppendRow(table.Row{"Credentials", enterprise.CredentialsName}) + t.AppendRow(table.Row{"Credentials", enterprise.Credentials.Name}) t.AppendRow(table.Row{"Pool manager running", enterprise.PoolManagerStatus.IsRunning}) if !enterprise.PoolManagerStatus.IsRunning { t.AppendRow(table.Row{"Failure reason", enterprise.PoolManagerStatus.FailureReason}) diff --git a/cmd/garm/main.go b/cmd/garm/main.go index ad80c521..454d766f 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -167,6 +167,9 @@ func main() { } setupLogging(ctx, logCfg, hub) + // Migrate credentials to the new format. This field will be read + // by the DB migration logic. + cfg.Database.MigrateCredentials = cfg.Github db, err := database.NewDatabase(ctx, cfg.Database) if err != nil { log.Fatal(err) diff --git a/config/config.go b/config/config.go index d777de7e..baafcb8e 100644 --- a/config/config.go +++ b/config/config.go @@ -241,6 +241,14 @@ type GithubApp struct { InstallationID int64 `toml:"installation_id" json:"installation-id"` } +func (a *GithubApp) PrivateKeyBytes() ([]byte, error) { + keyBytes, err := os.ReadFile(a.PrivateKeyPath) + if err != nil { + return nil, fmt.Errorf("reading private_key_path: %w", err) + } + return keyBytes, nil +} + func (a *GithubApp) Validate() error { if a.AppID == 0 { return fmt.Errorf("missing app_id") @@ -472,6 +480,11 @@ type Database struct { // Don't lose or change this. It will invalidate all encrypted data // in the DB. This field must be set and must be exactly 32 characters. Passphrase string `toml:"passphrase"` + + // MigrateCredentials is a list of github credentials that need to be migrated + // from the config file to the database. This field will be removed once GARM + // reaches version 0.2.x. It's only meant to be used for the migration process. + MigrateCredentials []Github `toml:"-"` } // GormParams returns the database type and connection URI diff --git a/database/common/common.go b/database/common/common.go index c270051f..8f901ab7 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -20,6 +20,23 @@ import ( "github.com/cloudbase/garm/params" ) +type GithubEndpointStore interface { + CreateGithubEndpoint(ctx context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) + GetGithubEndpoint(ctx context.Context, name string) (params.GithubEndpoint, error) + ListGithubEndpoints(ctx context.Context) ([]params.GithubEndpoint, error) + UpdateGithubEndpoint(ctx context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) + DeleteGithubEndpoint(ctx context.Context, name string) error +} + +type GithubCredentialsStore interface { + CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) + GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) + GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) + ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) + UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) + DeleteGithubCredentials(ctx context.Context, id uint) error +} + type RepoStore interface { CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) @@ -65,6 +82,7 @@ type PoolStore interface { type UserStore interface { GetUser(ctx context.Context, user string) (params.User, error) GetUserByID(ctx context.Context, userID string) (params.User, error) + GetAdminUser(ctx context.Context) (params.User, error) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) @@ -121,6 +139,8 @@ type Store interface { InstanceStore JobsStore EntityPools + GithubEndpointStore + GithubCredentialsStore ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 5b24fdc8..f8877ef7 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -134,6 +134,62 @@ func (_m *Store) CreateEntityPool(ctx context.Context, entity params.GithubEntit return r0, r1 } +// CreateGithubCredentials provides a mock function with given fields: ctx, endpointName, param +func (_m *Store) CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) { + ret := _m.Called(ctx, endpointName, param) + + if len(ret) == 0 { + panic("no return value specified for CreateGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, params.CreateGithubCredentialsParams) (params.GithubCredentials, error)); ok { + return rf(ctx, endpointName, param) + } + if rf, ok := ret.Get(0).(func(context.Context, string, params.CreateGithubCredentialsParams) params.GithubCredentials); ok { + r0 = rf(ctx, endpointName, param) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, params.CreateGithubCredentialsParams) error); ok { + r1 = rf(ctx, endpointName, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateGithubEndpoint provides a mock function with given fields: ctx, param +func (_m *Store) CreateGithubEndpoint(ctx context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, param) + + if len(ret) == 0 { + panic("no return value specified for CreateGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, params.CreateGithubEndpointParams) (params.GithubEndpoint, error)); ok { + return rf(ctx, param) + } + if rf, ok := ret.Get(0).(func(context.Context, params.CreateGithubEndpointParams) params.GithubEndpoint); ok { + r0 = rf(ctx, param) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, params.CreateGithubEndpointParams) error); ok { + r1 = rf(ctx, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CreateInstance provides a mock function with given fields: ctx, poolID, param func (_m *Store) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { ret := _m.Called(ctx, poolID, param) @@ -328,6 +384,42 @@ func (_m *Store) DeleteEntityPool(ctx context.Context, entity params.GithubEntit return r0 } +// DeleteGithubCredentials provides a mock function with given fields: ctx, id +func (_m *Store) DeleteGithubCredentials(ctx context.Context, id uint) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteGithubCredentials") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uint) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteGithubEndpoint provides a mock function with given fields: ctx, name +func (_m *Store) DeleteGithubEndpoint(ctx context.Context, name string) error { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for DeleteGithubEndpoint") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, name) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteInstance provides a mock function with given fields: ctx, poolID, instanceName func (_m *Store) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { ret := _m.Called(ctx, poolID, instanceName) @@ -448,6 +540,34 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params return r0, r1 } +// GetAdminUser provides a mock function with given fields: ctx +func (_m *Store) GetAdminUser(ctx context.Context) (params.User, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAdminUser") + } + + var r0 params.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (params.User, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) params.User); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(params.User) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetEnterprise provides a mock function with given fields: ctx, name func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { ret := _m.Called(ctx, name) @@ -532,6 +652,90 @@ func (_m *Store) GetEntityPool(ctx context.Context, entity params.GithubEntity, return r0, r1 } +// GetGithubCredentials provides a mock function with given fields: ctx, id, detailed +func (_m *Store) GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) { + ret := _m.Called(ctx, id, detailed) + + if len(ret) == 0 { + panic("no return value specified for GetGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint, bool) (params.GithubCredentials, error)); ok { + return rf(ctx, id, detailed) + } + if rf, ok := ret.Get(0).(func(context.Context, uint, bool) params.GithubCredentials); ok { + r0 = rf(ctx, id, detailed) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint, bool) error); ok { + r1 = rf(ctx, id, detailed) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetGithubCredentialsByName provides a mock function with given fields: ctx, name, detailed +func (_m *Store) GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) { + ret := _m.Called(ctx, name, detailed) + + if len(ret) == 0 { + panic("no return value specified for GetGithubCredentialsByName") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) (params.GithubCredentials, error)); ok { + return rf(ctx, name, detailed) + } + if rf, ok := ret.Get(0).(func(context.Context, string, bool) params.GithubCredentials); ok { + r0 = rf(ctx, name, detailed) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, name, detailed) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetGithubEndpoint provides a mock function with given fields: ctx, name +func (_m *Store) GetGithubEndpoint(ctx context.Context, name string) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for GetGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (params.GithubEndpoint, error)); ok { + return rf(ctx, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string) params.GithubEndpoint); ok { + r0 = rf(ctx, name) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetInstanceByName provides a mock function with given fields: ctx, instanceName func (_m *Store) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) { ret := _m.Called(ctx, instanceName) @@ -1068,6 +1272,66 @@ func (_m *Store) ListEntityPools(ctx context.Context, entity params.GithubEntity return r0, r1 } +// ListGithubCredentials provides a mock function with given fields: ctx +func (_m *Store) ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListGithubCredentials") + } + + var r0 []params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]params.GithubCredentials, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []params.GithubCredentials); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.GithubCredentials) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListGithubEndpoints provides a mock function with given fields: ctx +func (_m *Store) ListGithubEndpoints(ctx context.Context) ([]params.GithubEndpoint, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListGithubEndpoints") + } + + var r0 []params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]params.GithubEndpoint, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []params.GithubEndpoint); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.GithubEndpoint) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListJobsByStatus provides a mock function with given fields: ctx, status func (_m *Store) ListJobsByStatus(ctx context.Context, status params.JobStatus) ([]params.Job, error) { ret := _m.Called(ctx, status) @@ -1308,6 +1572,62 @@ func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntit return r0, r1 } +// UpdateGithubCredentials provides a mock function with given fields: ctx, id, param +func (_m *Store) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) { + ret := _m.Called(ctx, id, param) + + if len(ret) == 0 { + panic("no return value specified for UpdateGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint, params.UpdateGithubCredentialsParams) (params.GithubCredentials, error)); ok { + return rf(ctx, id, param) + } + if rf, ok := ret.Get(0).(func(context.Context, uint, params.UpdateGithubCredentialsParams) params.GithubCredentials); ok { + r0 = rf(ctx, id, param) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint, params.UpdateGithubCredentialsParams) error); ok { + r1 = rf(ctx, id, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateGithubEndpoint provides a mock function with given fields: ctx, name, param +func (_m *Store) UpdateGithubEndpoint(ctx context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, name, param) + + if len(ret) == 0 { + panic("no return value specified for UpdateGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateGithubEndpointParams) (params.GithubEndpoint, error)); ok { + return rf(ctx, name, param) + } + if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateGithubEndpointParams) params.GithubEndpoint); ok { + r0 = rf(ctx, name, param) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, params.UpdateGithubEndpointParams) error); ok { + r1 = rf(ctx, name, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // UpdateInstance provides a mock function with given fields: ctx, instanceName, param func (_m *Store) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) { ret := _m.Called(ctx, instanceName, param) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 3eb53b9e..e7270faf 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -54,7 +54,7 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En } func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools") + enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -68,7 +68,7 @@ func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, error) { var enterprises []Enterprise - q := s.conn.Find(&enterprises) + q := s.conn.Preload("Credentials").Find(&enterprises) if q.Error != nil { return []params.Enterprise{}, errors.Wrap(q.Error, "fetching enterprises") } @@ -100,7 +100,7 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) } func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) + enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Credentials", "Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -136,8 +136,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { var enterprise Enterprise - q := s.conn.Where("name = ? COLLATE NOCASE", name) - q = q.First(&enterprise) + q := s.conn.Where("name = ? COLLATE NOCASE", name). + Preload("Credentials"). + Preload("Endpoint"). + First(&enterprise) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Enterprise{}, runnerErrors.ErrNotFound diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index b20a1f20..7f22956b 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -173,7 +173,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { s.FailNow(fmt.Sprintf("failed to get enterprise by id: %v", err)) } s.Require().Equal(storeEnterprise.Name, enterprise.Name) - s.Require().Equal(storeEnterprise.CredentialsName, enterprise.CredentialsName) + s.Require().Equal(storeEnterprise.Credentials.Name, enterprise.Credentials.Name) s.Require().Equal(storeEnterprise.WebhookSecret, enterprise.WebhookSecret) } @@ -313,7 +313,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() { enterprise, err := s.Store.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, enterprise.WebhookSecret) } diff --git a/database/sql/github.go b/database/sql/github.go new file mode 100644 index 00000000..0087165f --- /dev/null +++ b/database/sql/github.go @@ -0,0 +1,473 @@ +package sql + +import ( + "context" + + "github.com/google/uuid" + "github.com/pkg/errors" + "gorm.io/gorm" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/params" +) + +func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (params.GithubCredentials, error) { + data, err := util.Unseal(creds.Payload, []byte(s.cfg.Passphrase)) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "unsealing credentials") + } + + commonCreds := params.GithubCredentials{ + ID: creds.ID, + Name: creds.Name, + Description: creds.Description, + APIBaseURL: creds.Endpoint.APIBaseURL, + BaseURL: creds.Endpoint.BaseURL, + UploadBaseURL: creds.Endpoint.UploadBaseURL, + CABundle: creds.Endpoint.CACertBundle, + AuthType: creds.AuthType, + Endpoint: creds.Endpoint.Name, + CredentialsPayload: data, + } + + for _, repo := range creds.Repositories { + commonRepo, err := s.sqlToCommonRepository(repo) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github repository") + } + commonCreds.Repositories = append(commonCreds.Repositories, commonRepo) + } + + for _, org := range creds.Organizations { + commonOrg, err := s.sqlToCommonOrganization(org) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github organization") + } + commonCreds.Organizations = append(commonCreds.Organizations, commonOrg) + } + + for _, ent := range creds.Enterprises { + commonEnt, err := s.sqlToCommonEnterprise(ent) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github enterprise") + } + commonCreds.Enterprises = append(commonCreds.Enterprises, commonEnt) + } + + return commonCreds, nil +} + +func (s *sqlDatabase) sqlToCommonGithubEndpoint(ep GithubEndpoint) (params.GithubEndpoint, error) { + return params.GithubEndpoint{ + Name: ep.Name, + Description: ep.Description, + APIBaseURL: ep.APIBaseURL, + BaseURL: ep.BaseURL, + UploadBaseURL: ep.UploadBaseURL, + CACertBundle: ep.CACertBundle, + }, nil +} + +func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { + userID := auth.UserID(ctx) + if userID == "" { + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + } + + asUUID, err := uuid.Parse(userID) + if err != nil { + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + } + return asUUID, nil +} + +func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + err := s.conn.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("name = ?", param.Name).First(&endpoint).Error; err == nil { + return errors.Wrap(runnerErrors.ErrDuplicateEntity, "github endpoint already exists") + } + endpoint = GithubEndpoint{ + Name: param.Name, + Description: param.Description, + APIBaseURL: param.APIBaseURL, + BaseURL: param.BaseURL, + UploadBaseURL: param.UploadBaseURL, + CACertBundle: param.CACertBundle, + } + + if err := tx.Create(&endpoint).Error; err != nil { + return errors.Wrap(err, "creating github endpoint") + } + return nil + }) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "creating github endpoint") + } + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEndpoint, error) { + var endpoints []GithubEndpoint + err := s.conn.Find(&endpoints).Error + if err != nil { + return nil, errors.Wrap(err, "fetching github endpoints") + } + + var ret []params.GithubEndpoint + for _, ep := range endpoints { + commonEp, err := s.sqlToCommonGithubEndpoint(ep) + if err != nil { + return nil, errors.Wrap(err, "converting github endpoint") + } + ret = append(ret, commonEp) + } + return ret, nil +} + +func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + err := s.conn.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found") + } + return errors.Wrap(err, "fetching github endpoint") + } + if param.APIBaseURL != nil { + endpoint.APIBaseURL = *param.APIBaseURL + } + + if param.BaseURL != nil { + endpoint.BaseURL = *param.BaseURL + } + + if param.UploadBaseURL != nil { + endpoint.UploadBaseURL = *param.UploadBaseURL + } + + if param.CACertBundle != nil { + endpoint.CACertBundle = param.CACertBundle + } + + if param.Description != nil { + endpoint.Description = *param.Description + } + + if err := tx.Save(&endpoint).Error; err != nil { + return errors.Wrap(err, "updating github endpoint") + } + + return nil + }) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "updating github endpoint") + } + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + + err := s.conn.Where("name = ?", name).First(&endpoint).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return params.GithubEndpoint{}, errors.Wrap(err, "github endpoint not found") + } + return params.GithubEndpoint{}, errors.Wrap(err, "fetching github endpoint") + } + + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error { + err := s.conn.Transaction(func(tx *gorm.DB) error { + var endpoint GithubEndpoint + if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return errors.Wrap(err, "fetching github endpoint") + } + + var credsCount int64 + if err := tx.Model(&GithubCredentials{}).Where("endpoint_name = ?", endpoint.Name).Count(&credsCount).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(err, "fetching github credentials") + } + } + + if credsCount > 0 { + return errors.New("cannot delete endpoint with credentials") + } + + if err := tx.Unscoped().Delete(&endpoint).Error; err != nil { + return errors.Wrap(err, "deleting github endpoint") + } + return nil + }) + if err != nil { + return errors.Wrap(err, "deleting github endpoint") + } + return nil +} + +func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") + } + var creds GithubCredentials + err = s.conn.Transaction(func(tx *gorm.DB) error { + var endpoint GithubEndpoint + if err := tx.Where("name = ?", endpointName).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found") + } + return errors.Wrap(err, "fetching github endpoint") + } + + if err := tx.Where("name = ?", param.Name).First(&creds).Error; err == nil { + return errors.New("github credentials already exists") + } + + var data []byte + var err error + switch param.AuthType { + case params.GithubAuthTypePAT: + data, err = s.marshalAndSeal(param.PAT) + case params.GithubAuthTypeApp: + data, err = s.marshalAndSeal(param.App) + } + if err != nil { + return errors.Wrap(err, "marshaling and sealing credentials") + } + + creds = GithubCredentials{ + Name: param.Name, + Description: param.Description, + EndpointName: &endpoint.Name, + AuthType: param.AuthType, + Payload: data, + UserID: &userID, + } + + if err := tx.Create(&creds).Error; err != nil { + return errors.Wrap(err, "creating github credentials") + } + // Skip making an extra query. + creds.Endpoint = endpoint + + return nil + }) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") + } + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) getGithubCredentialsByName(ctx context.Context, tx *gorm.DB, name string, detailed bool) (GithubCredentials, error) { + var creds GithubCredentials + q := tx.Preload("Endpoint") + + if detailed { + q = q. + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + } + + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + err := q.Where("name = ?", name).First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return GithubCredentials{}, errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return creds, nil +} + +func (s *sqlDatabase) GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) { + creds, err := s.getGithubCredentialsByName(ctx, s.conn, name, detailed) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) { + var creds GithubCredentials + q := s.conn.Preload("Endpoint") + + if detailed { + q = q. + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + } + + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + err := q.Where("id = ?", id).First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return params.GithubCredentials{}, errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) { + q := s.conn.Preload("Endpoint") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return nil, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + var creds []GithubCredentials + err := q.Preload("Endpoint").Find(&creds).Error + if err != nil { + return nil, errors.Wrap(err, "fetching github credentials") + } + + var ret []params.GithubCredentials + for _, c := range creds { + commonCreds, err := s.sqlToCommonGithubCredentials(c) + if err != nil { + return nil, errors.Wrap(err, "converting github credentials") + } + ret = append(ret, commonCreds) + } + return ret, nil +} + +func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) { + var creds GithubCredentials + err := s.conn.Transaction(func(tx *gorm.DB) error { + q := tx.Preload("Endpoint") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return errors.Wrap(err, "updating github credentials") + } + q = q.Where("user_id = ?", userID) + } + + if err := q.Where("id = ?", id).First(&creds).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return errors.Wrap(err, "fetching github credentials") + } + + if param.Name != nil { + creds.Name = *param.Name + } + if param.Description != nil { + creds.Description = *param.Description + } + + var data []byte + var err error + switch creds.AuthType { + case params.GithubAuthTypePAT: + if param.PAT != nil { + data, err = s.marshalAndSeal(param.PAT) + } + + if param.App != nil { + return errors.New("cannot update app credentials for PAT") + } + case params.GithubAuthTypeApp: + if param.App != nil { + data, err = s.marshalAndSeal(param.App) + } + + if param.PAT != nil { + return errors.New("cannot update PAT credentials for app") + } + } + + if err != nil { + return errors.Wrap(err, "marshaling and sealing credentials") + } + if len(data) > 0 { + creds.Payload = data + } + + if err := tx.Save(&creds).Error; err != nil { + return errors.Wrap(err, "updating github credentials") + } + return nil + }) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "updating github credentials") + } + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) error { + err := s.conn.Transaction(func(tx *gorm.DB) error { + q := tx.Where("id = ?", id). + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return errors.Wrap(err, "deleting github credentials") + } + q = q.Where("user_id = ?", userID) + } + + var creds GithubCredentials + err := q.First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return errors.Wrap(err, "fetching github credentials") + } + if len(creds.Repositories) > 0 { + return errors.New("cannot delete credentials with repositories") + } + if len(creds.Organizations) > 0 { + return errors.New("cannot delete credentials with organizations") + } + if len(creds.Enterprises) > 0 { + return errors.New("cannot delete credentials with enterprises") + } + + if err := tx.Unscoped().Delete(&creds).Error; err != nil { + return errors.Wrap(err, "deleting github credentials") + } + return nil + }) + if err != nil { + return errors.Wrap(err, "deleting github credentials") + } + return nil +} diff --git a/database/sql/models.go b/database/sql/models.go index 874a375d..633a1b51 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -89,35 +89,56 @@ type Pool struct { type Repository struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Owner string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` Name string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:RepoID"` Jobs []WorkflowJob `gorm:"foreignKey:RepoID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Organization struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Name string `gorm:"index:idx_org_name_nocase,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:OrgID"` Jobs []WorkflowJob `gorm:"foreignKey:OrgID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Enterprise struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Name string `gorm:"index:idx_ent_name_nocase,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:EnterpriseID"` Jobs []WorkflowJob `gorm:"foreignKey:EnterpriseID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Address struct { @@ -246,3 +267,35 @@ type WorkflowJob struct { UpdatedAt time.Time DeletedAt gorm.DeletedAt `gorm:"index"` } + +type GithubEndpoint struct { + Name string `gorm:"type:varchar(64) collate nocase;primary_key;"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` + + Description string `gorm:"type:text"` + APIBaseURL string `gorm:"type:text collate nocase"` + UploadBaseURL string `gorm:"type:text collate nocase"` + BaseURL string `gorm:"type:text collate nocase"` + CACertBundle []byte `gorm:"type:longblob"` +} + +type GithubCredentials struct { + gorm.Model + + Name string `gorm:"index:idx_github_credentials,unique;type:varchar(64) collate nocase"` + UserID *uuid.UUID `gorm:"index:idx_github_credentials,unique"` + User User `gorm:"foreignKey:UserID"` + + Description string `gorm:"type:text"` + AuthType params.GithubAuthType `gorm:"index"` + Payload []byte `gorm:"type:longblob"` + + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName"` + EndpointName *string `gorm:"index"` + + Repositories []Repository `gorm:"foreignKey:CredentialsID"` + Organizations []Organization `gorm:"foreignKey:CredentialsID"` + Enterprises []Enterprise `gorm:"foreignKey:CredentialsID"` +} diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 24704fd9..c67e85d4 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -72,7 +72,7 @@ func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params. func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organization, error) { var orgs []Organization - q := s.conn.Find(&orgs) + q := s.conn.Preload("Credentials").Find(&orgs) if q.Error != nil { return []params.Organization{}, errors.Wrap(q.Error, "fetching org from database") } @@ -104,7 +104,7 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro } func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID) + org, err := s.getOrgByID(ctx, orgID, "Credentials", "Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -138,7 +138,7 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para } func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID, "Pools") + org, err := s.getOrgByID(ctx, orgID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -177,8 +177,10 @@ func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, error) { var org Organization - q := s.conn.Where("name = ? COLLATE NOCASE", name) - q = q.First(&org) + q := s.conn.Where("name = ? COLLATE NOCASE", name). + Preload("Credentials"). + Preload("Endpoint"). + First(&org) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Organization{}, runnerErrors.ErrNotFound diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 77ebab90..28a049e5 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -173,7 +173,7 @@ func (s *OrgTestSuite) TestCreateOrganization() { s.FailNow(fmt.Sprintf("failed to get organization by id: %v", err)) } s.Require().Equal(storeOrg.Name, org.Name) - s.Require().Equal(storeOrg.CredentialsName, org.CredentialsName) + s.Require().Equal(storeOrg.Credentials.Name, org.Credentials.Name) s.Require().Equal(storeOrg.WebhookSecret, org.WebhookSecret) } @@ -313,7 +313,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() { org, err := s.Store.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) } diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 164c0197..396b2796 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -27,7 +27,7 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateRepository(_ context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { +func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { if webhookSecret == "" { return params.Repository{}, errors.New("creating repo: missing secret") } @@ -35,17 +35,29 @@ func (s *sqlDatabase) CreateRepository(_ context.Context, owner, name, credentia if err != nil { return params.Repository{}, fmt.Errorf("failed to encrypt string") } - newRepo := Repository{ - Name: name, - Owner: owner, - WebhookSecret: secret, - CredentialsName: credentialsName, - PoolBalancerType: poolBalancerType, - } - q := s.conn.Create(&newRepo) - if q.Error != nil { - return params.Repository{}, errors.Wrap(q.Error, "creating repository") + var newRepo Repository + err = s.conn.Transaction(func(tx *gorm.DB) error { + creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) + if err != nil { + return errors.Wrap(err, "creating repository") + } + + newRepo.Name = name + newRepo.Owner = owner + newRepo.WebhookSecret = secret + newRepo.CredentialsID = &creds.ID + newRepo.PoolBalancerType = poolBalancerType + + q := tx.Create(&newRepo) + if q.Error != nil { + return errors.Wrap(q.Error, "creating repository") + } + + return nil + }) + if err != nil { + return params.Repository{}, errors.Wrap(err, "creating repository") } param, err := s.sqlToCommonRepository(newRepo) @@ -72,7 +84,7 @@ func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (pa func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, error) { var repos []Repository - q := s.conn.Find(&repos) + q := s.conn.Preload("Credentials").Find(&repos) if q.Error != nil { return []params.Repository{}, errors.Wrap(q.Error, "fetching user from database") } @@ -104,7 +116,7 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error } func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID) + repo, err := s.getRepoByID(ctx, repoID, "Credentials", "Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -138,7 +150,7 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param } func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID, "Pools") + repo, err := s.getRepoByID(ctx, repoID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -154,6 +166,8 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository var repo Repository q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE", name, owner). + Preload("Credentials"). + Preload("Endpoint"). First(&repo) q = q.First(&repo) diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 3d335b10..fbd68304 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -188,7 +188,7 @@ func (s *RepoTestSuite) TestCreateRepository() { } s.Require().Equal(storeRepo.Owner, repo.Owner) s.Require().Equal(storeRepo.Name, repo.Name) - s.Require().Equal(storeRepo.CredentialsName, repo.CredentialsName) + s.Require().Equal(storeRepo.Credentials.Name, repo.Credentials.Name) s.Require().Equal(storeRepo.WebhookSecret, repo.WebhookSecret) } @@ -352,7 +352,7 @@ func (s *RepoTestSuite) TestUpdateRepository() { repo, err := s.Store.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, repo.WebhookSecret) } diff --git a/database/sql/sql.go b/database/sql/sql.go index e513be42..1bc16e08 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "log/slog" + "net/url" "strings" "github.com/pkg/errors" @@ -26,8 +27,12 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/auth" "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/util/appdefaults" ) // newDBConn returns a new gorm db connection, given the config @@ -190,6 +195,154 @@ func (s *sqlDatabase) cascadeMigration() error { return nil } +func (s *sqlDatabase) migrateCredentialsToDB() (err error) { + s.conn.Exec("PRAGMA foreign_keys = OFF") + defer s.conn.Exec("PRAGMA foreign_keys = ON") + + adminUser, err := s.GetAdminUser(s.ctx) + if err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + // Admin user doesn't exist. This is a new deploy. Nothing to migrate. + return nil + } + return errors.Wrap(err, "getting admin user") + } + + // Impersonate the admin user. We're migrating from config credentials to + // database credentials. At this point, there is no other user than the admin + // user. GARM is not yet multi-user, so it's safe to assume we only have this + // one user. + adminCtx := context.Background() + adminCtx = auth.PopulateContext(adminCtx, adminUser) + + slog.Info("migrating credentials to DB") + slog.Info("creating github endpoints table") + if err := s.conn.AutoMigrate(&GithubEndpoint{}); err != nil { + return errors.Wrap(err, "migrating github endpoints") + } + + defer func() { + if err != nil { + slog.With(slog.Any("error", err)).Error("rolling back github github endpoints table") + s.conn.Migrator().DropTable(&GithubEndpoint{}) + } + }() + + slog.Info("creating github credentials table") + if err := s.conn.AutoMigrate(&GithubCredentials{}); err != nil { + return errors.Wrap(err, "migrating github credentials") + } + + defer func() { + if err != nil { + slog.With(slog.Any("error", err)).Error("rolling back github github credentials table") + s.conn.Migrator().DropTable(&GithubCredentials{}) + } + }() + + // Create the default Github endpoint. + createEndpointParams := params.CreateGithubEndpointParams{ + Name: "github.com", + Description: "The github.com endpoint", + APIBaseURL: appdefaults.GithubDefaultBaseURL, + BaseURL: appdefaults.DefaultGithubURL, + UploadBaseURL: appdefaults.GithubDefaultUploadBaseURL, + } + + if _, err := s.CreateGithubEndpoint(adminCtx, createEndpointParams); err != nil { + if !errors.Is(err, runnerErrors.ErrDuplicateEntity) { + return errors.Wrap(err, "creating default github endpoint") + } + } + + // Nothing to migrate. + if len(s.cfg.MigrateCredentials) == 0 { + return nil + } + + slog.Info("importing credentials from config") + for _, cred := range s.cfg.MigrateCredentials { + slog.Info("importing credential", "name", cred.Name) + parsed, err := url.Parse(cred.BaseEndpoint()) + if err != nil { + return errors.Wrap(err, "parsing base URL") + } + + certBundle, err := cred.CACertBundle() + if err != nil { + return errors.Wrap(err, "getting CA cert bundle") + } + hostname := parsed.Hostname() + createParams := params.CreateGithubEndpointParams{ + Name: hostname, + Description: fmt.Sprintf("Endpoint for %s", hostname), + APIBaseURL: cred.APIEndpoint(), + BaseURL: cred.BaseEndpoint(), + UploadBaseURL: cred.UploadEndpoint(), + CACertBundle: certBundle, + } + + var endpoint params.GithubEndpoint + endpoint, err = s.GetGithubEndpoint(adminCtx, hostname) + if err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return errors.Wrap(err, "getting github endpoint") + } + endpoint, err = s.CreateGithubEndpoint(adminCtx, createParams) + if err != nil { + return errors.Wrap(err, "creating default github endpoint") + } + } + + credParams := params.CreateGithubCredentialsParams{ + Name: cred.Name, + Description: cred.Description, + AuthType: params.GithubAuthType(cred.AuthType), + } + switch credParams.AuthType { + case params.GithubAuthTypeApp: + keyBytes, err := cred.App.PrivateKeyBytes() + if err != nil { + return errors.Wrap(err, "getting private key bytes") + } + credParams.App = params.GithubApp{ + AppID: cred.App.AppID, + InstallationID: cred.App.InstallationID, + PrivateKeyBytes: keyBytes, + } + + if err := credParams.App.Validate(); err != nil { + return errors.Wrap(err, "validating app credentials") + } + case params.GithubAuthTypePAT: + if cred.PAT.OAuth2Token == "" { + return errors.New("missing OAuth2 token") + } + credParams.PAT = params.GithubPAT{ + OAuth2Token: cred.PAT.OAuth2Token, + } + } + + creds, err := s.CreateGithubCredentials(adminCtx, endpoint.Name, credParams) + if err != nil { + return errors.Wrap(err, "creating github credentials") + } + + if err := s.conn.Exec("update repositories set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating repositories") + } + + if err := s.conn.Exec("update organizations set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating organizations") + } + + if err := s.conn.Exec("update enterprises set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating enterprises") + } + } + return nil +} + func (s *sqlDatabase) migrateDB() error { if s.conn.Migrator().HasIndex(&Organization{}, "idx_organizations_name") { if err := s.conn.Migrator().DropIndex(&Organization{}, "idx_organizations_name"); err != nil { @@ -234,7 +387,15 @@ func (s *sqlDatabase) migrateDB() error { } } + var needsCredentialMigration bool + if !s.conn.Migrator().HasTable(&GithubCredentials{}) || !s.conn.Migrator().HasTable(&GithubEndpoint{}) { + needsCredentialMigration = true + } + s.conn.Exec("PRAGMA foreign_keys = OFF") if err := s.conn.AutoMigrate( + &User{}, + &GithubEndpoint{}, + &GithubCredentials{}, &Tag{}, &Pool{}, &Repository{}, @@ -244,11 +405,16 @@ func (s *sqlDatabase) migrateDB() error { &InstanceStatusUpdate{}, &Instance{}, &ControllerInfo{}, - &User{}, &WorkflowJob{}, ); err != nil { return errors.Wrap(err, "running auto migrate") } + s.conn.Exec("PRAGMA foreign_keys = ON") + if needsCredentialMigration { + if err := s.migrateCredentialsToDB(); err != nil { + return errors.Wrap(err, "migrating credentials") + } + } return nil } diff --git a/database/sql/users.go b/database/sql/users.go index 039d86fe..5fc47564 100644 --- a/database/sql/users.go +++ b/database/sql/users.go @@ -67,6 +67,10 @@ func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) ( return params.User{}, runnerErrors.NewConflictError("email already exists") } + if s.HasAdminUser(context.Background()) && user.IsAdmin { + return params.User{}, runnerErrors.NewBadRequestError("admin user already exists") + } + newUser := User{ Username: user.Username, Password: user.Password, @@ -129,3 +133,16 @@ func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.Up return s.sqlToParamsUser(dbUser), nil } + +// GetAdminUser returns the system admin user. This is only for internal use. +func (s *sqlDatabase) GetAdminUser(_ context.Context) (params.User, error) { + var user User + q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user) + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return params.User{}, runnerErrors.ErrNotFound + } + return params.User{}, errors.Wrap(q.Error, "fetching admin user") + } + return s.sqlToParamsUser(user), nil +} diff --git a/database/sql/util.go b/database/sql/util.go index aaea31fe..30946863 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -114,10 +114,16 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza return params.Organization{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(org.Credentials) + if err != nil { + return params.Organization{}, errors.Wrap(err, "converting credentials") + } + ret := params.Organization{ ID: org.ID.String(), Name: org.Name, - CredentialsName: org.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(org.Pools)), WebhookSecret: string(secret), PoolBalancerType: org.PoolBalancerType, @@ -146,10 +152,15 @@ func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enter return params.Enterprise{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials) + if err != nil { + return params.Enterprise{}, errors.Wrap(err, "converting credentials") + } ret := params.Enterprise{ ID: enterprise.ID.String(), Name: enterprise.Name, - CredentialsName: enterprise.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(enterprise.Pools)), WebhookSecret: string(secret), PoolBalancerType: enterprise.PoolBalancerType, @@ -239,11 +250,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository, return params.Repository{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(repo.Credentials) + if err != nil { + return params.Repository{}, errors.Wrap(err, "converting credentials") + } ret := params.Repository{ ID: repo.ID.String(), Name: repo.Name, Owner: repo.Owner, - CredentialsName: repo.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(repo.Pools)), WebhookSecret: string(secret), PoolBalancerType: repo.PoolBalancerType, diff --git a/params/params.go b/params/params.go index a2a44222..e6bf8fb6 100644 --- a/params/params.go +++ b/params/params.go @@ -16,6 +16,8 @@ package params import ( "bytes" + "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -23,8 +25,10 @@ import ( "net/http" "time" + "github.com/bradleyfalzon/ghinstallation/v2" "github.com/google/go-github/v57/github" "github.com/google/uuid" + "golang.org/x/oauth2" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/util/appdefaults" @@ -398,6 +402,7 @@ type Repository struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -439,6 +444,7 @@ type Organization struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -480,6 +486,7 @@ type Enterprise struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -545,6 +552,7 @@ type ControllerInfo struct { } type GithubCredentials struct { + ID uint `json:"id"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` APIBaseURL string `json:"api_base_url"` @@ -552,7 +560,68 @@ type GithubCredentials struct { BaseURL string `json:"base_url"` CABundle []byte `json:"ca_bundle,omitempty"` AuthType GithubAuthType `toml:"auth_type" json:"auth-type"` - HTTPClient *http.Client `json:"-"` + + Repositories []Repository `json:"repositories,omitempty"` + Organizations []Organization `json:"organizations,omitempty"` + Enterprises []Enterprise `json:"enterprises,omitempty"` + Endpoint string `json:"endpoint"` + + CredentialsPayload []byte `json:"-"` + HTTPClient *http.Client `json:"-"` +} + +func (g GithubCredentials) GetHTTPClient(ctx context.Context) (*http.Client, error) { + var roots *x509.CertPool + if g.CABundle != nil { + roots = x509.NewCertPool() + ok := roots.AppendCertsFromPEM(g.CABundle) + if !ok { + return nil, fmt.Errorf("failed to parse CA cert") + } + } + // nolint:golangci-lint,gosec,godox + // TODO: set TLS MinVersion + httpTransport := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: roots, + }, + } + + var tc *http.Client + switch g.AuthType { + case GithubAuthTypeApp: + var app GithubApp + if err := json.Unmarshal(g.CredentialsPayload, &app); err != nil { + return nil, fmt.Errorf("failed to unmarshal github app credentials: %w", err) + } + if app.AppID == 0 || app.InstallationID == 0 || len(app.PrivateKeyBytes) == 0 { + return nil, fmt.Errorf("github app credentials are missing required fields") + } + itr, err := ghinstallation.New(httpTransport, app.AppID, app.InstallationID, app.PrivateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to create github app installation transport: %w", err) + } + + tc = &http.Client{Transport: itr} + default: + var pat GithubPAT + if err := json.Unmarshal(g.CredentialsPayload, &pat); err != nil { + return nil, fmt.Errorf("failed to unmarshal github app credentials: %w", err) + } + httpClient := &http.Client{Transport: httpTransport} + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + if pat.OAuth2Token == "" { + return nil, fmt.Errorf("github credentials are missing the OAuth2 token") + } + + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: pat.OAuth2Token}, + ) + tc = oauth2.NewClient(ctx, ts) + } + + return tc, nil } func (g GithubCredentials) RootCertificateBundle() (CertificateBundle, error) { @@ -700,10 +769,11 @@ type UpdateSystemInfoParams struct { } type GithubEntity struct { - Owner string `json:"owner"` - Name string `json:"name"` - ID string `json:"id"` - EntityType GithubEntityType `json:"entity_type"` + Owner string `json:"owner"` + Name string `json:"name"` + ID string `json:"id"` + EntityType GithubEntityType `json:"entity_type"` + Credentials GithubCredentials `json:"credentials"` WebhookSecret string `json:"-"` } @@ -729,3 +799,14 @@ func (g GithubEntity) String() string { } return "" } + +type GithubEndpoint struct { + Name string `json:"name"` + Description string `json:"description"` + APIBaseURL string `json:"api_base_url"` + UploadBaseURL string `json:"upload_base_url"` + BaseURL string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` + + Credentials []GithubCredentials `json:"credentials"` +} diff --git a/params/requests.go b/params/requests.go index 885ed678..5da66d53 100644 --- a/params/requests.go +++ b/params/requests.go @@ -15,7 +15,9 @@ package params import ( + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "github.com/cloudbase/garm-provider-common/errors" @@ -265,3 +267,68 @@ type InstanceUpdateMessage struct { Message string `json:"message"` AgentID *int64 `json:"agent_id,omitempty"` } + +type CreateGithubEndpointParams struct { + Name string `json:"name"` + Description string `json:"description"` + APIBaseURL string `json:"api_base_url"` + UploadBaseURL string `json:"upload_base_url"` + BaseURL string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` +} + +type UpdateGithubEndpointParams struct { + Description *string `json:"description"` + APIBaseURL *string `json:"api_base_url"` + UploadBaseURL *string `json:"upload_base_url"` + BaseURL *string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` +} + +type GithubPAT struct { + OAuth2Token string `json:"oauth2_token"` +} + +type GithubApp struct { + AppID int64 `json:"app_id"` + InstallationID int64 `json:"installation_id"` + PrivateKeyBytes []byte `json:"private_key_bytes"` +} + +func (g GithubApp) Validate() error { + if g.AppID == 0 { + return errors.NewBadRequestError("missing app_id") + } + + if g.InstallationID == 0 { + return errors.NewBadRequestError("missing installation_id") + } + + if len(g.PrivateKeyBytes) == 0 { + return errors.NewBadRequestError("missing private_key_bytes") + } + + block, _ := pem.Decode(g.PrivateKeyBytes) + // Parse the private key as PCKS1 + _, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("parsing private_key_path: %w", err) + } + + return nil +} + +type CreateGithubCredentialsParams struct { + Name string `json:"name"` + Description string `json:"description"` + AuthType GithubAuthType `json:"auth_type"` + PAT GithubPAT `json:"pat,omitempty"` + App GithubApp `json:"app,omitempty"` +} + +type UpdateGithubCredentialsParams struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + PAT *GithubPAT `json:"pat,omitempty"` + App *GithubApp `json:"app,omitempty"` +} diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index dc81da5e..2ad54e5d 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -169,7 +169,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, enterprise.PoolBalancerType) } @@ -306,7 +306,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() { s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) s.Require().Equal(params.PoolBalancerTypePack, org.PoolBalancerType) } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index d0113756..30a58882 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -169,7 +169,7 @@ func (s *OrgTestSuite) TestCreateOrganization() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateOrgParams.Name, org.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateOrgParams.CredentialsName].Name, org.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateOrgParams.CredentialsName].Name, org.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, org.PoolBalancerType) } @@ -317,7 +317,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() { s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 20814a86..74bc8a76 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -81,7 +81,7 @@ func (s *RepoTestSuite) SetupTest() { params.PoolBalancerTypeRoundRobin, ) if err != nil { - s.FailNow(fmt.Sprintf("failed to create database object (test-repo-%v)", i)) + s.FailNow(fmt.Sprintf("failed to create database object (test-repo-%v): %q", i, err)) } repos[name] = repo } @@ -170,7 +170,7 @@ func (s *RepoTestSuite) TestCreateRepository() { s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateRepoParams.Owner, repo.Owner) s.Require().Equal(s.Fixtures.CreateRepoParams.Name, repo.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, repo.PoolBalancerType) } @@ -190,7 +190,7 @@ func (s *RepoTestSuite) TestCreareRepositoryPoolBalancerTypePack() { s.Require().Nil(err) s.Require().Equal(param.Owner, repo.Owner) s.Require().Equal(param.Name, repo.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.Credentials.Name) s.Require().Equal(params.PoolBalancerTypePack, repo.PoolBalancerType) } @@ -327,7 +327,7 @@ func (s *RepoTestSuite) TestUpdateRepository() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, repo.WebhookSecret) s.Require().Equal(params.PoolBalancerTypeRoundRobin, repo.PoolBalancerType) } @@ -343,7 +343,7 @@ func (s *RepoTestSuite) TestUpdateRepositoryBalancingType() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(updateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(updateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(updateRepoParams.WebhookSecret, repo.WebhookSecret) s.Require().Equal(params.PoolBalancerTypePack, repo.PoolBalancerType) } diff --git a/runner/runner.go b/runner/runner.go index 1ffa508a..fd10ad78 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -67,7 +67,6 @@ func NewRunner(ctx context.Context, cfg config.Config, db dbCommon.Store) (*Runn poolManagerCtrl := &poolManagerCtrl{ controllerID: ctrlID.ControllerID.String(), config: cfg, - credentials: creds, repositories: map[string]common.PoolManager{}, organizations: map[string]common.PoolManager{}, enterprises: map[string]common.PoolManager{}, @@ -94,7 +93,6 @@ type poolManagerCtrl struct { controllerID string config config.Config - credentials map[string]config.Github repositories map[string]common.PoolManager organizations map[string]common.PoolManager @@ -105,7 +103,7 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, repo.CredentialsName, repo.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, repo.Credentials, repo.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -133,7 +131,7 @@ func (p *poolManagerCtrl) UpdateRepoPoolManager(ctx context.Context, repo params return nil, errors.Wrapf(runnerErrors.ErrNotFound, "repository %s/%s pool manager not loaded", repo.Owner, repo.Name) } - internalCfg, err := p.getInternalConfig(ctx, repo.CredentialsName, repo.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, repo.Credentials, repo.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -178,7 +176,7 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, org.CredentialsName, org.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, org.Credentials, org.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -205,7 +203,7 @@ func (p *poolManagerCtrl) UpdateOrgPoolManager(ctx context.Context, org params.O return nil, errors.Wrapf(runnerErrors.ErrNotFound, "org %s pool manager not loaded", org.Name) } - internalCfg, err := p.getInternalConfig(ctx, org.CredentialsName, org.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, org.Credentials, org.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -250,7 +248,7 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, enterprise.CredentialsName, enterprise.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, enterprise.Credentials, enterprise.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -278,7 +276,7 @@ func (p *poolManagerCtrl) UpdateEnterprisePoolManager(ctx context.Context, enter return nil, errors.Wrapf(runnerErrors.ErrNotFound, "enterprise %s pool manager not loaded", enterprise.Name) } - internalCfg, err := p.getInternalConfig(ctx, enterprise.CredentialsName, enterprise.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, enterprise.Credentials, enterprise.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -319,22 +317,12 @@ func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolMan return p.enterprises, nil } -func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, credsName string, poolBalancerType params.PoolBalancerType) (params.Internal, error) { - creds, ok := p.credentials[credsName] - if !ok { - return params.Internal{}, runnerErrors.NewBadRequestError("invalid credential name (%s)", credsName) - } - - caBundle, err := creds.CACertBundle() - if err != nil { - return params.Internal{}, fmt.Errorf("fetching CA bundle for creds: %w", err) - } - +func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) { var controllerWebhookURL string if p.config.Default.WebhookURL != "" { controllerWebhookURL = fmt.Sprintf("%s/%s", p.config.Default.WebhookURL, p.controllerID) } - httpClient, err := creds.HTTPClient(ctx) + httpClient, err := creds.GetHTTPClient(ctx) if err != nil { return params.Internal{}, fmt.Errorf("fetching http client for creds: %w", err) } @@ -349,10 +337,10 @@ func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, credsName strin GithubCredentialsDetails: params.GithubCredentials{ Name: creds.Name, Description: creds.Description, - BaseURL: creds.BaseEndpoint(), - APIBaseURL: creds.APIEndpoint(), - UploadBaseURL: creds.UploadEndpoint(), - CABundle: caBundle, + BaseURL: creds.BaseURL, + APIBaseURL: creds.APIBaseURL, + UploadBaseURL: creds.UploadBaseURL, + CABundle: creds.CABundle, HTTPClient: httpClient, }, }, nil