diff --git a/cmd/garm-cli/cmd/enterprise.go b/cmd/garm-cli/cmd/enterprise.go index 27ca662c..98457aef 100644 --- a/cmd/garm-cli/cmd/enterprise.go +++ b/cmd/garm-cli/cmd/enterprise.go @@ -204,7 +204,7 @@ func formatEnterprises(enterprises []params.Enterprise) { header := table.Row{"ID", "Name", "Credentials name", "Pool Balancer Type", "Pool mgr running"} t.AppendHeader(header) for _, val := range enterprises { - t.AppendRow(table.Row{val.ID, val.Name, val.CredentialsName, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) + t.AppendRow(table.Row{val.ID, val.Name, val.Credentials.Name, val.GetBalancerType(), val.PoolManagerStatus.IsRunning}) t.AppendSeparator() } fmt.Println(t.Render()) @@ -218,7 +218,7 @@ func formatOneEnterprise(enterprise params.Enterprise) { t.AppendRow(table.Row{"ID", enterprise.ID}) t.AppendRow(table.Row{"Name", enterprise.Name}) t.AppendRow(table.Row{"Pool balancer type", enterprise.GetBalancerType()}) - t.AppendRow(table.Row{"Credentials", enterprise.CredentialsName}) + t.AppendRow(table.Row{"Credentials", enterprise.Credentials.Name}) t.AppendRow(table.Row{"Pool manager running", enterprise.PoolManagerStatus.IsRunning}) if !enterprise.PoolManagerStatus.IsRunning { t.AppendRow(table.Row{"Failure reason", enterprise.PoolManagerStatus.FailureReason}) diff --git a/cmd/garm/main.go b/cmd/garm/main.go index ad80c521..454d766f 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -167,6 +167,9 @@ func main() { } setupLogging(ctx, logCfg, hub) + // Migrate credentials to the new format. This field will be read + // by the DB migration logic. + cfg.Database.MigrateCredentials = cfg.Github db, err := database.NewDatabase(ctx, cfg.Database) if err != nil { log.Fatal(err) diff --git a/config/config.go b/config/config.go index d777de7e..baafcb8e 100644 --- a/config/config.go +++ b/config/config.go @@ -241,6 +241,14 @@ type GithubApp struct { InstallationID int64 `toml:"installation_id" json:"installation-id"` } +func (a *GithubApp) PrivateKeyBytes() ([]byte, error) { + keyBytes, err := os.ReadFile(a.PrivateKeyPath) + if err != nil { + return nil, fmt.Errorf("reading private_key_path: %w", err) + } + return keyBytes, nil +} + func (a *GithubApp) Validate() error { if a.AppID == 0 { return fmt.Errorf("missing app_id") @@ -472,6 +480,11 @@ type Database struct { // Don't lose or change this. It will invalidate all encrypted data // in the DB. This field must be set and must be exactly 32 characters. Passphrase string `toml:"passphrase"` + + // MigrateCredentials is a list of github credentials that need to be migrated + // from the config file to the database. This field will be removed once GARM + // reaches version 0.2.x. It's only meant to be used for the migration process. + MigrateCredentials []Github `toml:"-"` } // GormParams returns the database type and connection URI diff --git a/database/common/common.go b/database/common/common.go index c270051f..8f901ab7 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -20,6 +20,23 @@ import ( "github.com/cloudbase/garm/params" ) +type GithubEndpointStore interface { + CreateGithubEndpoint(ctx context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) + GetGithubEndpoint(ctx context.Context, name string) (params.GithubEndpoint, error) + ListGithubEndpoints(ctx context.Context) ([]params.GithubEndpoint, error) + UpdateGithubEndpoint(ctx context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) + DeleteGithubEndpoint(ctx context.Context, name string) error +} + +type GithubCredentialsStore interface { + CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) + GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) + GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) + ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) + UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) + DeleteGithubCredentials(ctx context.Context, id uint) error +} + type RepoStore interface { CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) @@ -65,6 +82,7 @@ type PoolStore interface { type UserStore interface { GetUser(ctx context.Context, user string) (params.User, error) GetUserByID(ctx context.Context, userID string) (params.User, error) + GetAdminUser(ctx context.Context) (params.User, error) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) @@ -121,6 +139,8 @@ type Store interface { InstanceStore JobsStore EntityPools + GithubEndpointStore + GithubCredentialsStore ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 5b24fdc8..f8877ef7 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -134,6 +134,62 @@ func (_m *Store) CreateEntityPool(ctx context.Context, entity params.GithubEntit return r0, r1 } +// CreateGithubCredentials provides a mock function with given fields: ctx, endpointName, param +func (_m *Store) CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) { + ret := _m.Called(ctx, endpointName, param) + + if len(ret) == 0 { + panic("no return value specified for CreateGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, params.CreateGithubCredentialsParams) (params.GithubCredentials, error)); ok { + return rf(ctx, endpointName, param) + } + if rf, ok := ret.Get(0).(func(context.Context, string, params.CreateGithubCredentialsParams) params.GithubCredentials); ok { + r0 = rf(ctx, endpointName, param) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, params.CreateGithubCredentialsParams) error); ok { + r1 = rf(ctx, endpointName, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateGithubEndpoint provides a mock function with given fields: ctx, param +func (_m *Store) CreateGithubEndpoint(ctx context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, param) + + if len(ret) == 0 { + panic("no return value specified for CreateGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, params.CreateGithubEndpointParams) (params.GithubEndpoint, error)); ok { + return rf(ctx, param) + } + if rf, ok := ret.Get(0).(func(context.Context, params.CreateGithubEndpointParams) params.GithubEndpoint); ok { + r0 = rf(ctx, param) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, params.CreateGithubEndpointParams) error); ok { + r1 = rf(ctx, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CreateInstance provides a mock function with given fields: ctx, poolID, param func (_m *Store) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { ret := _m.Called(ctx, poolID, param) @@ -328,6 +384,42 @@ func (_m *Store) DeleteEntityPool(ctx context.Context, entity params.GithubEntit return r0 } +// DeleteGithubCredentials provides a mock function with given fields: ctx, id +func (_m *Store) DeleteGithubCredentials(ctx context.Context, id uint) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteGithubCredentials") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uint) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteGithubEndpoint provides a mock function with given fields: ctx, name +func (_m *Store) DeleteGithubEndpoint(ctx context.Context, name string) error { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for DeleteGithubEndpoint") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, name) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteInstance provides a mock function with given fields: ctx, poolID, instanceName func (_m *Store) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { ret := _m.Called(ctx, poolID, instanceName) @@ -448,6 +540,34 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params return r0, r1 } +// GetAdminUser provides a mock function with given fields: ctx +func (_m *Store) GetAdminUser(ctx context.Context) (params.User, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAdminUser") + } + + var r0 params.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (params.User, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) params.User); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(params.User) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetEnterprise provides a mock function with given fields: ctx, name func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { ret := _m.Called(ctx, name) @@ -532,6 +652,90 @@ func (_m *Store) GetEntityPool(ctx context.Context, entity params.GithubEntity, return r0, r1 } +// GetGithubCredentials provides a mock function with given fields: ctx, id, detailed +func (_m *Store) GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) { + ret := _m.Called(ctx, id, detailed) + + if len(ret) == 0 { + panic("no return value specified for GetGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint, bool) (params.GithubCredentials, error)); ok { + return rf(ctx, id, detailed) + } + if rf, ok := ret.Get(0).(func(context.Context, uint, bool) params.GithubCredentials); ok { + r0 = rf(ctx, id, detailed) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint, bool) error); ok { + r1 = rf(ctx, id, detailed) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetGithubCredentialsByName provides a mock function with given fields: ctx, name, detailed +func (_m *Store) GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) { + ret := _m.Called(ctx, name, detailed) + + if len(ret) == 0 { + panic("no return value specified for GetGithubCredentialsByName") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) (params.GithubCredentials, error)); ok { + return rf(ctx, name, detailed) + } + if rf, ok := ret.Get(0).(func(context.Context, string, bool) params.GithubCredentials); ok { + r0 = rf(ctx, name, detailed) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, name, detailed) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetGithubEndpoint provides a mock function with given fields: ctx, name +func (_m *Store) GetGithubEndpoint(ctx context.Context, name string) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for GetGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (params.GithubEndpoint, error)); ok { + return rf(ctx, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string) params.GithubEndpoint); ok { + r0 = rf(ctx, name) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetInstanceByName provides a mock function with given fields: ctx, instanceName func (_m *Store) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) { ret := _m.Called(ctx, instanceName) @@ -1068,6 +1272,66 @@ func (_m *Store) ListEntityPools(ctx context.Context, entity params.GithubEntity return r0, r1 } +// ListGithubCredentials provides a mock function with given fields: ctx +func (_m *Store) ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListGithubCredentials") + } + + var r0 []params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]params.GithubCredentials, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []params.GithubCredentials); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.GithubCredentials) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListGithubEndpoints provides a mock function with given fields: ctx +func (_m *Store) ListGithubEndpoints(ctx context.Context) ([]params.GithubEndpoint, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListGithubEndpoints") + } + + var r0 []params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]params.GithubEndpoint, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []params.GithubEndpoint); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]params.GithubEndpoint) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListJobsByStatus provides a mock function with given fields: ctx, status func (_m *Store) ListJobsByStatus(ctx context.Context, status params.JobStatus) ([]params.Job, error) { ret := _m.Called(ctx, status) @@ -1308,6 +1572,62 @@ func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntit return r0, r1 } +// UpdateGithubCredentials provides a mock function with given fields: ctx, id, param +func (_m *Store) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) { + ret := _m.Called(ctx, id, param) + + if len(ret) == 0 { + panic("no return value specified for UpdateGithubCredentials") + } + + var r0 params.GithubCredentials + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint, params.UpdateGithubCredentialsParams) (params.GithubCredentials, error)); ok { + return rf(ctx, id, param) + } + if rf, ok := ret.Get(0).(func(context.Context, uint, params.UpdateGithubCredentialsParams) params.GithubCredentials); ok { + r0 = rf(ctx, id, param) + } else { + r0 = ret.Get(0).(params.GithubCredentials) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint, params.UpdateGithubCredentialsParams) error); ok { + r1 = rf(ctx, id, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateGithubEndpoint provides a mock function with given fields: ctx, name, param +func (_m *Store) UpdateGithubEndpoint(ctx context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) { + ret := _m.Called(ctx, name, param) + + if len(ret) == 0 { + panic("no return value specified for UpdateGithubEndpoint") + } + + var r0 params.GithubEndpoint + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateGithubEndpointParams) (params.GithubEndpoint, error)); ok { + return rf(ctx, name, param) + } + if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateGithubEndpointParams) params.GithubEndpoint); ok { + r0 = rf(ctx, name, param) + } else { + r0 = ret.Get(0).(params.GithubEndpoint) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, params.UpdateGithubEndpointParams) error); ok { + r1 = rf(ctx, name, param) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // UpdateInstance provides a mock function with given fields: ctx, instanceName, param func (_m *Store) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) { ret := _m.Called(ctx, instanceName, param) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 3eb53b9e..e7270faf 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -54,7 +54,7 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En } func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools") + enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -68,7 +68,7 @@ func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, error) { var enterprises []Enterprise - q := s.conn.Find(&enterprises) + q := s.conn.Preload("Credentials").Find(&enterprises) if q.Error != nil { return []params.Enterprise{}, errors.Wrap(q.Error, "fetching enterprises") } @@ -100,7 +100,7 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) } func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) + enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Credentials", "Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -136,8 +136,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { var enterprise Enterprise - q := s.conn.Where("name = ? COLLATE NOCASE", name) - q = q.First(&enterprise) + q := s.conn.Where("name = ? COLLATE NOCASE", name). + Preload("Credentials"). + Preload("Endpoint"). + First(&enterprise) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Enterprise{}, runnerErrors.ErrNotFound diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index b20a1f20..7f22956b 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -173,7 +173,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { s.FailNow(fmt.Sprintf("failed to get enterprise by id: %v", err)) } s.Require().Equal(storeEnterprise.Name, enterprise.Name) - s.Require().Equal(storeEnterprise.CredentialsName, enterprise.CredentialsName) + s.Require().Equal(storeEnterprise.Credentials.Name, enterprise.Credentials.Name) s.Require().Equal(storeEnterprise.WebhookSecret, enterprise.WebhookSecret) } @@ -313,7 +313,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() { enterprise, err := s.Store.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, enterprise.WebhookSecret) } diff --git a/database/sql/github.go b/database/sql/github.go new file mode 100644 index 00000000..0087165f --- /dev/null +++ b/database/sql/github.go @@ -0,0 +1,473 @@ +package sql + +import ( + "context" + + "github.com/google/uuid" + "github.com/pkg/errors" + "gorm.io/gorm" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/params" +) + +func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (params.GithubCredentials, error) { + data, err := util.Unseal(creds.Payload, []byte(s.cfg.Passphrase)) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "unsealing credentials") + } + + commonCreds := params.GithubCredentials{ + ID: creds.ID, + Name: creds.Name, + Description: creds.Description, + APIBaseURL: creds.Endpoint.APIBaseURL, + BaseURL: creds.Endpoint.BaseURL, + UploadBaseURL: creds.Endpoint.UploadBaseURL, + CABundle: creds.Endpoint.CACertBundle, + AuthType: creds.AuthType, + Endpoint: creds.Endpoint.Name, + CredentialsPayload: data, + } + + for _, repo := range creds.Repositories { + commonRepo, err := s.sqlToCommonRepository(repo) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github repository") + } + commonCreds.Repositories = append(commonCreds.Repositories, commonRepo) + } + + for _, org := range creds.Organizations { + commonOrg, err := s.sqlToCommonOrganization(org) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github organization") + } + commonCreds.Organizations = append(commonCreds.Organizations, commonOrg) + } + + for _, ent := range creds.Enterprises { + commonEnt, err := s.sqlToCommonEnterprise(ent) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "converting github enterprise") + } + commonCreds.Enterprises = append(commonCreds.Enterprises, commonEnt) + } + + return commonCreds, nil +} + +func (s *sqlDatabase) sqlToCommonGithubEndpoint(ep GithubEndpoint) (params.GithubEndpoint, error) { + return params.GithubEndpoint{ + Name: ep.Name, + Description: ep.Description, + APIBaseURL: ep.APIBaseURL, + BaseURL: ep.BaseURL, + UploadBaseURL: ep.UploadBaseURL, + CACertBundle: ep.CACertBundle, + }, nil +} + +func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { + userID := auth.UserID(ctx) + if userID == "" { + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + } + + asUUID, err := uuid.Parse(userID) + if err != nil { + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + } + return asUUID, nil +} + +func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + err := s.conn.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("name = ?", param.Name).First(&endpoint).Error; err == nil { + return errors.Wrap(runnerErrors.ErrDuplicateEntity, "github endpoint already exists") + } + endpoint = GithubEndpoint{ + Name: param.Name, + Description: param.Description, + APIBaseURL: param.APIBaseURL, + BaseURL: param.BaseURL, + UploadBaseURL: param.UploadBaseURL, + CACertBundle: param.CACertBundle, + } + + if err := tx.Create(&endpoint).Error; err != nil { + return errors.Wrap(err, "creating github endpoint") + } + return nil + }) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "creating github endpoint") + } + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEndpoint, error) { + var endpoints []GithubEndpoint + err := s.conn.Find(&endpoints).Error + if err != nil { + return nil, errors.Wrap(err, "fetching github endpoints") + } + + var ret []params.GithubEndpoint + for _, ep := range endpoints { + commonEp, err := s.sqlToCommonGithubEndpoint(ep) + if err != nil { + return nil, errors.Wrap(err, "converting github endpoint") + } + ret = append(ret, commonEp) + } + return ret, nil +} + +func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + err := s.conn.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found") + } + return errors.Wrap(err, "fetching github endpoint") + } + if param.APIBaseURL != nil { + endpoint.APIBaseURL = *param.APIBaseURL + } + + if param.BaseURL != nil { + endpoint.BaseURL = *param.BaseURL + } + + if param.UploadBaseURL != nil { + endpoint.UploadBaseURL = *param.UploadBaseURL + } + + if param.CACertBundle != nil { + endpoint.CACertBundle = param.CACertBundle + } + + if param.Description != nil { + endpoint.Description = *param.Description + } + + if err := tx.Save(&endpoint).Error; err != nil { + return errors.Wrap(err, "updating github endpoint") + } + + return nil + }) + if err != nil { + return params.GithubEndpoint{}, errors.Wrap(err, "updating github endpoint") + } + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params.GithubEndpoint, error) { + var endpoint GithubEndpoint + + err := s.conn.Where("name = ?", name).First(&endpoint).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return params.GithubEndpoint{}, errors.Wrap(err, "github endpoint not found") + } + return params.GithubEndpoint{}, errors.Wrap(err, "fetching github endpoint") + } + + return s.sqlToCommonGithubEndpoint(endpoint) +} + +func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error { + err := s.conn.Transaction(func(tx *gorm.DB) error { + var endpoint GithubEndpoint + if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return errors.Wrap(err, "fetching github endpoint") + } + + var credsCount int64 + if err := tx.Model(&GithubCredentials{}).Where("endpoint_name = ?", endpoint.Name).Count(&credsCount).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(err, "fetching github credentials") + } + } + + if credsCount > 0 { + return errors.New("cannot delete endpoint with credentials") + } + + if err := tx.Unscoped().Delete(&endpoint).Error; err != nil { + return errors.Wrap(err, "deleting github endpoint") + } + return nil + }) + if err != nil { + return errors.Wrap(err, "deleting github endpoint") + } + return nil +} + +func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, endpointName string, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") + } + var creds GithubCredentials + err = s.conn.Transaction(func(tx *gorm.DB) error { + var endpoint GithubEndpoint + if err := tx.Where("name = ?", endpointName).First(&endpoint).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found") + } + return errors.Wrap(err, "fetching github endpoint") + } + + if err := tx.Where("name = ?", param.Name).First(&creds).Error; err == nil { + return errors.New("github credentials already exists") + } + + var data []byte + var err error + switch param.AuthType { + case params.GithubAuthTypePAT: + data, err = s.marshalAndSeal(param.PAT) + case params.GithubAuthTypeApp: + data, err = s.marshalAndSeal(param.App) + } + if err != nil { + return errors.Wrap(err, "marshaling and sealing credentials") + } + + creds = GithubCredentials{ + Name: param.Name, + Description: param.Description, + EndpointName: &endpoint.Name, + AuthType: param.AuthType, + Payload: data, + UserID: &userID, + } + + if err := tx.Create(&creds).Error; err != nil { + return errors.Wrap(err, "creating github credentials") + } + // Skip making an extra query. + creds.Endpoint = endpoint + + return nil + }) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") + } + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) getGithubCredentialsByName(ctx context.Context, tx *gorm.DB, name string, detailed bool) (GithubCredentials, error) { + var creds GithubCredentials + q := tx.Preload("Endpoint") + + if detailed { + q = q. + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + } + + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + err := q.Where("name = ?", name).First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return GithubCredentials{}, errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return creds, nil +} + +func (s *sqlDatabase) GetGithubCredentialsByName(ctx context.Context, name string, detailed bool) (params.GithubCredentials, error) { + creds, err := s.getGithubCredentialsByName(ctx, s.conn, name, detailed) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) GetGithubCredentials(ctx context.Context, id uint, detailed bool) (params.GithubCredentials, error) { + var creds GithubCredentials + q := s.conn.Preload("Endpoint") + + if detailed { + q = q. + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + } + + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + err := q.Where("id = ?", id).First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return params.GithubCredentials{}, errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return params.GithubCredentials{}, errors.Wrap(err, "fetching github credentials") + } + + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) ListGithubCredentials(ctx context.Context) ([]params.GithubCredentials, error) { + q := s.conn.Preload("Endpoint") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return nil, errors.Wrap(err, "fetching github credentials") + } + q = q.Where("user_id = ?", userID) + } + + var creds []GithubCredentials + err := q.Preload("Endpoint").Find(&creds).Error + if err != nil { + return nil, errors.Wrap(err, "fetching github credentials") + } + + var ret []params.GithubCredentials + for _, c := range creds { + commonCreds, err := s.sqlToCommonGithubCredentials(c) + if err != nil { + return nil, errors.Wrap(err, "converting github credentials") + } + ret = append(ret, commonCreds) + } + return ret, nil +} + +func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) { + var creds GithubCredentials + err := s.conn.Transaction(func(tx *gorm.DB) error { + q := tx.Preload("Endpoint") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return errors.Wrap(err, "updating github credentials") + } + q = q.Where("user_id = ?", userID) + } + + if err := q.Where("id = ?", id).First(&creds).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "github credentials not found") + } + return errors.Wrap(err, "fetching github credentials") + } + + if param.Name != nil { + creds.Name = *param.Name + } + if param.Description != nil { + creds.Description = *param.Description + } + + var data []byte + var err error + switch creds.AuthType { + case params.GithubAuthTypePAT: + if param.PAT != nil { + data, err = s.marshalAndSeal(param.PAT) + } + + if param.App != nil { + return errors.New("cannot update app credentials for PAT") + } + case params.GithubAuthTypeApp: + if param.App != nil { + data, err = s.marshalAndSeal(param.App) + } + + if param.PAT != nil { + return errors.New("cannot update PAT credentials for app") + } + } + + if err != nil { + return errors.Wrap(err, "marshaling and sealing credentials") + } + if len(data) > 0 { + creds.Payload = data + } + + if err := tx.Save(&creds).Error; err != nil { + return errors.Wrap(err, "updating github credentials") + } + return nil + }) + if err != nil { + return params.GithubCredentials{}, errors.Wrap(err, "updating github credentials") + } + return s.sqlToCommonGithubCredentials(creds) +} + +func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) error { + err := s.conn.Transaction(func(tx *gorm.DB) error { + q := tx.Where("id = ?", id). + Preload("Repositories"). + Preload("Organizations"). + Preload("Enterprises") + if !auth.IsAdmin(ctx) { + userID, err := getUIDFromContext(ctx) + if err != nil { + return errors.Wrap(err, "deleting github credentials") + } + q = q.Where("user_id = ?", userID) + } + + var creds GithubCredentials + err := q.First(&creds).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return errors.Wrap(err, "fetching github credentials") + } + if len(creds.Repositories) > 0 { + return errors.New("cannot delete credentials with repositories") + } + if len(creds.Organizations) > 0 { + return errors.New("cannot delete credentials with organizations") + } + if len(creds.Enterprises) > 0 { + return errors.New("cannot delete credentials with enterprises") + } + + if err := tx.Unscoped().Delete(&creds).Error; err != nil { + return errors.Wrap(err, "deleting github credentials") + } + return nil + }) + if err != nil { + return errors.Wrap(err, "deleting github credentials") + } + return nil +} diff --git a/database/sql/models.go b/database/sql/models.go index 874a375d..633a1b51 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -89,35 +89,56 @@ type Pool struct { type Repository struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Owner string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` Name string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:RepoID"` Jobs []WorkflowJob `gorm:"foreignKey:RepoID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index:idx_owner_nocase,unique,collate:nocase"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Organization struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Name string `gorm:"index:idx_org_name_nocase,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:OrgID"` Jobs []WorkflowJob `gorm:"foreignKey:OrgID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Enterprise struct { Base - CredentialsName string + CredentialsName string + + CredentialsID *uint `gorm:"index"` + Credentials GithubCredentials `gorm:"foreignKey:CredentialsID;constraint:OnDelete:SET NULL"` + Name string `gorm:"index:idx_ent_name_nocase,collate:nocase"` WebhookSecret []byte Pools []Pool `gorm:"foreignKey:EnterpriseID"` Jobs []WorkflowJob `gorm:"foreignKey:EnterpriseID;constraint:OnDelete:SET NULL"` PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"` + + EndpointName *string `gorm:"index"` + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"` } type Address struct { @@ -246,3 +267,35 @@ type WorkflowJob struct { UpdatedAt time.Time DeletedAt gorm.DeletedAt `gorm:"index"` } + +type GithubEndpoint struct { + Name string `gorm:"type:varchar(64) collate nocase;primary_key;"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` + + Description string `gorm:"type:text"` + APIBaseURL string `gorm:"type:text collate nocase"` + UploadBaseURL string `gorm:"type:text collate nocase"` + BaseURL string `gorm:"type:text collate nocase"` + CACertBundle []byte `gorm:"type:longblob"` +} + +type GithubCredentials struct { + gorm.Model + + Name string `gorm:"index:idx_github_credentials,unique;type:varchar(64) collate nocase"` + UserID *uuid.UUID `gorm:"index:idx_github_credentials,unique"` + User User `gorm:"foreignKey:UserID"` + + Description string `gorm:"type:text"` + AuthType params.GithubAuthType `gorm:"index"` + Payload []byte `gorm:"type:longblob"` + + Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName"` + EndpointName *string `gorm:"index"` + + Repositories []Repository `gorm:"foreignKey:CredentialsID"` + Organizations []Organization `gorm:"foreignKey:CredentialsID"` + Enterprises []Enterprise `gorm:"foreignKey:CredentialsID"` +} diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 24704fd9..c67e85d4 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -72,7 +72,7 @@ func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params. func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organization, error) { var orgs []Organization - q := s.conn.Find(&orgs) + q := s.conn.Preload("Credentials").Find(&orgs) if q.Error != nil { return []params.Organization{}, errors.Wrap(q.Error, "fetching org from database") } @@ -104,7 +104,7 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro } func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID) + org, err := s.getOrgByID(ctx, orgID, "Credentials", "Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -138,7 +138,7 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para } func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID, "Pools") + org, err := s.getOrgByID(ctx, orgID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -177,8 +177,10 @@ func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, error) { var org Organization - q := s.conn.Where("name = ? COLLATE NOCASE", name) - q = q.First(&org) + q := s.conn.Where("name = ? COLLATE NOCASE", name). + Preload("Credentials"). + Preload("Endpoint"). + First(&org) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Organization{}, runnerErrors.ErrNotFound diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 77ebab90..28a049e5 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -173,7 +173,7 @@ func (s *OrgTestSuite) TestCreateOrganization() { s.FailNow(fmt.Sprintf("failed to get organization by id: %v", err)) } s.Require().Equal(storeOrg.Name, org.Name) - s.Require().Equal(storeOrg.CredentialsName, org.CredentialsName) + s.Require().Equal(storeOrg.Credentials.Name, org.Credentials.Name) s.Require().Equal(storeOrg.WebhookSecret, org.WebhookSecret) } @@ -313,7 +313,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() { org, err := s.Store.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) } diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 164c0197..396b2796 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -27,7 +27,7 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateRepository(_ context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { +func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { if webhookSecret == "" { return params.Repository{}, errors.New("creating repo: missing secret") } @@ -35,17 +35,29 @@ func (s *sqlDatabase) CreateRepository(_ context.Context, owner, name, credentia if err != nil { return params.Repository{}, fmt.Errorf("failed to encrypt string") } - newRepo := Repository{ - Name: name, - Owner: owner, - WebhookSecret: secret, - CredentialsName: credentialsName, - PoolBalancerType: poolBalancerType, - } - q := s.conn.Create(&newRepo) - if q.Error != nil { - return params.Repository{}, errors.Wrap(q.Error, "creating repository") + var newRepo Repository + err = s.conn.Transaction(func(tx *gorm.DB) error { + creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) + if err != nil { + return errors.Wrap(err, "creating repository") + } + + newRepo.Name = name + newRepo.Owner = owner + newRepo.WebhookSecret = secret + newRepo.CredentialsID = &creds.ID + newRepo.PoolBalancerType = poolBalancerType + + q := tx.Create(&newRepo) + if q.Error != nil { + return errors.Wrap(q.Error, "creating repository") + } + + return nil + }) + if err != nil { + return params.Repository{}, errors.Wrap(err, "creating repository") } param, err := s.sqlToCommonRepository(newRepo) @@ -72,7 +84,7 @@ func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (pa func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, error) { var repos []Repository - q := s.conn.Find(&repos) + q := s.conn.Preload("Credentials").Find(&repos) if q.Error != nil { return []params.Repository{}, errors.Wrap(q.Error, "fetching user from database") } @@ -104,7 +116,7 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error } func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID) + repo, err := s.getRepoByID(ctx, repoID, "Credentials", "Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -138,7 +150,7 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param } func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID, "Pools") + repo, err := s.getRepoByID(ctx, repoID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -154,6 +166,8 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository var repo Repository q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE", name, owner). + Preload("Credentials"). + Preload("Endpoint"). First(&repo) q = q.First(&repo) diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 3d335b10..fbd68304 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -188,7 +188,7 @@ func (s *RepoTestSuite) TestCreateRepository() { } s.Require().Equal(storeRepo.Owner, repo.Owner) s.Require().Equal(storeRepo.Name, repo.Name) - s.Require().Equal(storeRepo.CredentialsName, repo.CredentialsName) + s.Require().Equal(storeRepo.Credentials.Name, repo.Credentials.Name) s.Require().Equal(storeRepo.WebhookSecret, repo.WebhookSecret) } @@ -352,7 +352,7 @@ func (s *RepoTestSuite) TestUpdateRepository() { repo, err := s.Store.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, repo.WebhookSecret) } diff --git a/database/sql/sql.go b/database/sql/sql.go index e513be42..1bc16e08 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "log/slog" + "net/url" "strings" "github.com/pkg/errors" @@ -26,8 +27,12 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/auth" "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/util/appdefaults" ) // newDBConn returns a new gorm db connection, given the config @@ -190,6 +195,154 @@ func (s *sqlDatabase) cascadeMigration() error { return nil } +func (s *sqlDatabase) migrateCredentialsToDB() (err error) { + s.conn.Exec("PRAGMA foreign_keys = OFF") + defer s.conn.Exec("PRAGMA foreign_keys = ON") + + adminUser, err := s.GetAdminUser(s.ctx) + if err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + // Admin user doesn't exist. This is a new deploy. Nothing to migrate. + return nil + } + return errors.Wrap(err, "getting admin user") + } + + // Impersonate the admin user. We're migrating from config credentials to + // database credentials. At this point, there is no other user than the admin + // user. GARM is not yet multi-user, so it's safe to assume we only have this + // one user. + adminCtx := context.Background() + adminCtx = auth.PopulateContext(adminCtx, adminUser) + + slog.Info("migrating credentials to DB") + slog.Info("creating github endpoints table") + if err := s.conn.AutoMigrate(&GithubEndpoint{}); err != nil { + return errors.Wrap(err, "migrating github endpoints") + } + + defer func() { + if err != nil { + slog.With(slog.Any("error", err)).Error("rolling back github github endpoints table") + s.conn.Migrator().DropTable(&GithubEndpoint{}) + } + }() + + slog.Info("creating github credentials table") + if err := s.conn.AutoMigrate(&GithubCredentials{}); err != nil { + return errors.Wrap(err, "migrating github credentials") + } + + defer func() { + if err != nil { + slog.With(slog.Any("error", err)).Error("rolling back github github credentials table") + s.conn.Migrator().DropTable(&GithubCredentials{}) + } + }() + + // Create the default Github endpoint. + createEndpointParams := params.CreateGithubEndpointParams{ + Name: "github.com", + Description: "The github.com endpoint", + APIBaseURL: appdefaults.GithubDefaultBaseURL, + BaseURL: appdefaults.DefaultGithubURL, + UploadBaseURL: appdefaults.GithubDefaultUploadBaseURL, + } + + if _, err := s.CreateGithubEndpoint(adminCtx, createEndpointParams); err != nil { + if !errors.Is(err, runnerErrors.ErrDuplicateEntity) { + return errors.Wrap(err, "creating default github endpoint") + } + } + + // Nothing to migrate. + if len(s.cfg.MigrateCredentials) == 0 { + return nil + } + + slog.Info("importing credentials from config") + for _, cred := range s.cfg.MigrateCredentials { + slog.Info("importing credential", "name", cred.Name) + parsed, err := url.Parse(cred.BaseEndpoint()) + if err != nil { + return errors.Wrap(err, "parsing base URL") + } + + certBundle, err := cred.CACertBundle() + if err != nil { + return errors.Wrap(err, "getting CA cert bundle") + } + hostname := parsed.Hostname() + createParams := params.CreateGithubEndpointParams{ + Name: hostname, + Description: fmt.Sprintf("Endpoint for %s", hostname), + APIBaseURL: cred.APIEndpoint(), + BaseURL: cred.BaseEndpoint(), + UploadBaseURL: cred.UploadEndpoint(), + CACertBundle: certBundle, + } + + var endpoint params.GithubEndpoint + endpoint, err = s.GetGithubEndpoint(adminCtx, hostname) + if err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return errors.Wrap(err, "getting github endpoint") + } + endpoint, err = s.CreateGithubEndpoint(adminCtx, createParams) + if err != nil { + return errors.Wrap(err, "creating default github endpoint") + } + } + + credParams := params.CreateGithubCredentialsParams{ + Name: cred.Name, + Description: cred.Description, + AuthType: params.GithubAuthType(cred.AuthType), + } + switch credParams.AuthType { + case params.GithubAuthTypeApp: + keyBytes, err := cred.App.PrivateKeyBytes() + if err != nil { + return errors.Wrap(err, "getting private key bytes") + } + credParams.App = params.GithubApp{ + AppID: cred.App.AppID, + InstallationID: cred.App.InstallationID, + PrivateKeyBytes: keyBytes, + } + + if err := credParams.App.Validate(); err != nil { + return errors.Wrap(err, "validating app credentials") + } + case params.GithubAuthTypePAT: + if cred.PAT.OAuth2Token == "" { + return errors.New("missing OAuth2 token") + } + credParams.PAT = params.GithubPAT{ + OAuth2Token: cred.PAT.OAuth2Token, + } + } + + creds, err := s.CreateGithubCredentials(adminCtx, endpoint.Name, credParams) + if err != nil { + return errors.Wrap(err, "creating github credentials") + } + + if err := s.conn.Exec("update repositories set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating repositories") + } + + if err := s.conn.Exec("update organizations set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating organizations") + } + + if err := s.conn.Exec("update enterprises set credentials_id = ?,endpoint_name = ? where credentials_name = ?", creds.ID, creds.Endpoint, creds.Name).Error; err != nil { + return errors.Wrap(err, "updating enterprises") + } + } + return nil +} + func (s *sqlDatabase) migrateDB() error { if s.conn.Migrator().HasIndex(&Organization{}, "idx_organizations_name") { if err := s.conn.Migrator().DropIndex(&Organization{}, "idx_organizations_name"); err != nil { @@ -234,7 +387,15 @@ func (s *sqlDatabase) migrateDB() error { } } + var needsCredentialMigration bool + if !s.conn.Migrator().HasTable(&GithubCredentials{}) || !s.conn.Migrator().HasTable(&GithubEndpoint{}) { + needsCredentialMigration = true + } + s.conn.Exec("PRAGMA foreign_keys = OFF") if err := s.conn.AutoMigrate( + &User{}, + &GithubEndpoint{}, + &GithubCredentials{}, &Tag{}, &Pool{}, &Repository{}, @@ -244,11 +405,16 @@ func (s *sqlDatabase) migrateDB() error { &InstanceStatusUpdate{}, &Instance{}, &ControllerInfo{}, - &User{}, &WorkflowJob{}, ); err != nil { return errors.Wrap(err, "running auto migrate") } + s.conn.Exec("PRAGMA foreign_keys = ON") + if needsCredentialMigration { + if err := s.migrateCredentialsToDB(); err != nil { + return errors.Wrap(err, "migrating credentials") + } + } return nil } diff --git a/database/sql/users.go b/database/sql/users.go index 039d86fe..5fc47564 100644 --- a/database/sql/users.go +++ b/database/sql/users.go @@ -67,6 +67,10 @@ func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) ( return params.User{}, runnerErrors.NewConflictError("email already exists") } + if s.HasAdminUser(context.Background()) && user.IsAdmin { + return params.User{}, runnerErrors.NewBadRequestError("admin user already exists") + } + newUser := User{ Username: user.Username, Password: user.Password, @@ -129,3 +133,16 @@ func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.Up return s.sqlToParamsUser(dbUser), nil } + +// GetAdminUser returns the system admin user. This is only for internal use. +func (s *sqlDatabase) GetAdminUser(_ context.Context) (params.User, error) { + var user User + q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user) + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return params.User{}, runnerErrors.ErrNotFound + } + return params.User{}, errors.Wrap(q.Error, "fetching admin user") + } + return s.sqlToParamsUser(user), nil +} diff --git a/database/sql/util.go b/database/sql/util.go index aaea31fe..30946863 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -114,10 +114,16 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza return params.Organization{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(org.Credentials) + if err != nil { + return params.Organization{}, errors.Wrap(err, "converting credentials") + } + ret := params.Organization{ ID: org.ID.String(), Name: org.Name, - CredentialsName: org.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(org.Pools)), WebhookSecret: string(secret), PoolBalancerType: org.PoolBalancerType, @@ -146,10 +152,15 @@ func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enter return params.Enterprise{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials) + if err != nil { + return params.Enterprise{}, errors.Wrap(err, "converting credentials") + } ret := params.Enterprise{ ID: enterprise.ID.String(), Name: enterprise.Name, - CredentialsName: enterprise.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(enterprise.Pools)), WebhookSecret: string(secret), PoolBalancerType: enterprise.PoolBalancerType, @@ -239,11 +250,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository, return params.Repository{}, errors.Wrap(err, "decrypting secret") } + creds, err := s.sqlToCommonGithubCredentials(repo.Credentials) + if err != nil { + return params.Repository{}, errors.Wrap(err, "converting credentials") + } ret := params.Repository{ ID: repo.ID.String(), Name: repo.Name, Owner: repo.Owner, - CredentialsName: repo.CredentialsName, + CredentialsName: creds.Name, + Credentials: creds, Pools: make([]params.Pool, len(repo.Pools)), WebhookSecret: string(secret), PoolBalancerType: repo.PoolBalancerType, diff --git a/params/params.go b/params/params.go index a2a44222..e6bf8fb6 100644 --- a/params/params.go +++ b/params/params.go @@ -16,6 +16,8 @@ package params import ( "bytes" + "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -23,8 +25,10 @@ import ( "net/http" "time" + "github.com/bradleyfalzon/ghinstallation/v2" "github.com/google/go-github/v57/github" "github.com/google/uuid" + "golang.org/x/oauth2" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/util/appdefaults" @@ -398,6 +402,7 @@ type Repository struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -439,6 +444,7 @@ type Organization struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -480,6 +486,7 @@ type Enterprise struct { Name string `json:"name"` Pools []Pool `json:"pool,omitempty"` CredentialsName string `json:"credentials_name"` + Credentials GithubCredentials `json:"credentials"` PoolManagerStatus PoolManagerStatus `json:"pool_manager_status,omitempty"` PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` // Do not serialize sensitive info. @@ -545,6 +552,7 @@ type ControllerInfo struct { } type GithubCredentials struct { + ID uint `json:"id"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` APIBaseURL string `json:"api_base_url"` @@ -552,7 +560,68 @@ type GithubCredentials struct { BaseURL string `json:"base_url"` CABundle []byte `json:"ca_bundle,omitempty"` AuthType GithubAuthType `toml:"auth_type" json:"auth-type"` - HTTPClient *http.Client `json:"-"` + + Repositories []Repository `json:"repositories,omitempty"` + Organizations []Organization `json:"organizations,omitempty"` + Enterprises []Enterprise `json:"enterprises,omitempty"` + Endpoint string `json:"endpoint"` + + CredentialsPayload []byte `json:"-"` + HTTPClient *http.Client `json:"-"` +} + +func (g GithubCredentials) GetHTTPClient(ctx context.Context) (*http.Client, error) { + var roots *x509.CertPool + if g.CABundle != nil { + roots = x509.NewCertPool() + ok := roots.AppendCertsFromPEM(g.CABundle) + if !ok { + return nil, fmt.Errorf("failed to parse CA cert") + } + } + // nolint:golangci-lint,gosec,godox + // TODO: set TLS MinVersion + httpTransport := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: roots, + }, + } + + var tc *http.Client + switch g.AuthType { + case GithubAuthTypeApp: + var app GithubApp + if err := json.Unmarshal(g.CredentialsPayload, &app); err != nil { + return nil, fmt.Errorf("failed to unmarshal github app credentials: %w", err) + } + if app.AppID == 0 || app.InstallationID == 0 || len(app.PrivateKeyBytes) == 0 { + return nil, fmt.Errorf("github app credentials are missing required fields") + } + itr, err := ghinstallation.New(httpTransport, app.AppID, app.InstallationID, app.PrivateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to create github app installation transport: %w", err) + } + + tc = &http.Client{Transport: itr} + default: + var pat GithubPAT + if err := json.Unmarshal(g.CredentialsPayload, &pat); err != nil { + return nil, fmt.Errorf("failed to unmarshal github app credentials: %w", err) + } + httpClient := &http.Client{Transport: httpTransport} + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + if pat.OAuth2Token == "" { + return nil, fmt.Errorf("github credentials are missing the OAuth2 token") + } + + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: pat.OAuth2Token}, + ) + tc = oauth2.NewClient(ctx, ts) + } + + return tc, nil } func (g GithubCredentials) RootCertificateBundle() (CertificateBundle, error) { @@ -700,10 +769,11 @@ type UpdateSystemInfoParams struct { } type GithubEntity struct { - Owner string `json:"owner"` - Name string `json:"name"` - ID string `json:"id"` - EntityType GithubEntityType `json:"entity_type"` + Owner string `json:"owner"` + Name string `json:"name"` + ID string `json:"id"` + EntityType GithubEntityType `json:"entity_type"` + Credentials GithubCredentials `json:"credentials"` WebhookSecret string `json:"-"` } @@ -729,3 +799,14 @@ func (g GithubEntity) String() string { } return "" } + +type GithubEndpoint struct { + Name string `json:"name"` + Description string `json:"description"` + APIBaseURL string `json:"api_base_url"` + UploadBaseURL string `json:"upload_base_url"` + BaseURL string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` + + Credentials []GithubCredentials `json:"credentials"` +} diff --git a/params/requests.go b/params/requests.go index 885ed678..5da66d53 100644 --- a/params/requests.go +++ b/params/requests.go @@ -15,7 +15,9 @@ package params import ( + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "github.com/cloudbase/garm-provider-common/errors" @@ -265,3 +267,68 @@ type InstanceUpdateMessage struct { Message string `json:"message"` AgentID *int64 `json:"agent_id,omitempty"` } + +type CreateGithubEndpointParams struct { + Name string `json:"name"` + Description string `json:"description"` + APIBaseURL string `json:"api_base_url"` + UploadBaseURL string `json:"upload_base_url"` + BaseURL string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` +} + +type UpdateGithubEndpointParams struct { + Description *string `json:"description"` + APIBaseURL *string `json:"api_base_url"` + UploadBaseURL *string `json:"upload_base_url"` + BaseURL *string `json:"base_url"` + CACertBundle []byte `json:"ca_cert_bundle"` +} + +type GithubPAT struct { + OAuth2Token string `json:"oauth2_token"` +} + +type GithubApp struct { + AppID int64 `json:"app_id"` + InstallationID int64 `json:"installation_id"` + PrivateKeyBytes []byte `json:"private_key_bytes"` +} + +func (g GithubApp) Validate() error { + if g.AppID == 0 { + return errors.NewBadRequestError("missing app_id") + } + + if g.InstallationID == 0 { + return errors.NewBadRequestError("missing installation_id") + } + + if len(g.PrivateKeyBytes) == 0 { + return errors.NewBadRequestError("missing private_key_bytes") + } + + block, _ := pem.Decode(g.PrivateKeyBytes) + // Parse the private key as PCKS1 + _, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("parsing private_key_path: %w", err) + } + + return nil +} + +type CreateGithubCredentialsParams struct { + Name string `json:"name"` + Description string `json:"description"` + AuthType GithubAuthType `json:"auth_type"` + PAT GithubPAT `json:"pat,omitempty"` + App GithubApp `json:"app,omitempty"` +} + +type UpdateGithubCredentialsParams struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + PAT *GithubPAT `json:"pat,omitempty"` + App *GithubApp `json:"app,omitempty"` +} diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index dc81da5e..2ad54e5d 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -169,7 +169,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, enterprise.PoolBalancerType) } @@ -306,7 +306,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() { s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) s.Require().Equal(params.PoolBalancerTypePack, org.PoolBalancerType) } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index d0113756..30a58882 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -169,7 +169,7 @@ func (s *OrgTestSuite) TestCreateOrganization() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateOrgParams.Name, org.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateOrgParams.CredentialsName].Name, org.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateOrgParams.CredentialsName].Name, org.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, org.PoolBalancerType) } @@ -317,7 +317,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() { s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 20814a86..74bc8a76 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -81,7 +81,7 @@ func (s *RepoTestSuite) SetupTest() { params.PoolBalancerTypeRoundRobin, ) if err != nil { - s.FailNow(fmt.Sprintf("failed to create database object (test-repo-%v)", i)) + s.FailNow(fmt.Sprintf("failed to create database object (test-repo-%v): %q", i, err)) } repos[name] = repo } @@ -170,7 +170,7 @@ func (s *RepoTestSuite) TestCreateRepository() { s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateRepoParams.Owner, repo.Owner) s.Require().Equal(s.Fixtures.CreateRepoParams.Name, repo.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, repo.PoolBalancerType) } @@ -190,7 +190,7 @@ func (s *RepoTestSuite) TestCreareRepositoryPoolBalancerTypePack() { s.Require().Nil(err) s.Require().Equal(param.Owner, repo.Owner) s.Require().Equal(param.Name, repo.Name) - s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.CredentialsName) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateRepoParams.CredentialsName].Name, repo.Credentials.Name) s.Require().Equal(params.PoolBalancerTypePack, repo.PoolBalancerType) } @@ -327,7 +327,7 @@ func (s *RepoTestSuite) TestUpdateRepository() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, repo.WebhookSecret) s.Require().Equal(params.PoolBalancerTypeRoundRobin, repo.PoolBalancerType) } @@ -343,7 +343,7 @@ func (s *RepoTestSuite) TestUpdateRepositoryBalancingType() { s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Require().Nil(err) - s.Require().Equal(updateRepoParams.CredentialsName, repo.CredentialsName) + s.Require().Equal(updateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(updateRepoParams.WebhookSecret, repo.WebhookSecret) s.Require().Equal(params.PoolBalancerTypePack, repo.PoolBalancerType) } diff --git a/runner/runner.go b/runner/runner.go index 1ffa508a..fd10ad78 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -67,7 +67,6 @@ func NewRunner(ctx context.Context, cfg config.Config, db dbCommon.Store) (*Runn poolManagerCtrl := &poolManagerCtrl{ controllerID: ctrlID.ControllerID.String(), config: cfg, - credentials: creds, repositories: map[string]common.PoolManager{}, organizations: map[string]common.PoolManager{}, enterprises: map[string]common.PoolManager{}, @@ -94,7 +93,6 @@ type poolManagerCtrl struct { controllerID string config config.Config - credentials map[string]config.Github repositories map[string]common.PoolManager organizations map[string]common.PoolManager @@ -105,7 +103,7 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, repo.CredentialsName, repo.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, repo.Credentials, repo.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -133,7 +131,7 @@ func (p *poolManagerCtrl) UpdateRepoPoolManager(ctx context.Context, repo params return nil, errors.Wrapf(runnerErrors.ErrNotFound, "repository %s/%s pool manager not loaded", repo.Owner, repo.Name) } - internalCfg, err := p.getInternalConfig(ctx, repo.CredentialsName, repo.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, repo.Credentials, repo.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -178,7 +176,7 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, org.CredentialsName, org.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, org.Credentials, org.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -205,7 +203,7 @@ func (p *poolManagerCtrl) UpdateOrgPoolManager(ctx context.Context, org params.O return nil, errors.Wrapf(runnerErrors.ErrNotFound, "org %s pool manager not loaded", org.Name) } - internalCfg, err := p.getInternalConfig(ctx, org.CredentialsName, org.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, org.Credentials, org.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -250,7 +248,7 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter p.mux.Lock() defer p.mux.Unlock() - cfgInternal, err := p.getInternalConfig(ctx, enterprise.CredentialsName, enterprise.GetBalancerType()) + cfgInternal, err := p.getInternalConfig(ctx, enterprise.Credentials, enterprise.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -278,7 +276,7 @@ func (p *poolManagerCtrl) UpdateEnterprisePoolManager(ctx context.Context, enter return nil, errors.Wrapf(runnerErrors.ErrNotFound, "enterprise %s pool manager not loaded", enterprise.Name) } - internalCfg, err := p.getInternalConfig(ctx, enterprise.CredentialsName, enterprise.GetBalancerType()) + internalCfg, err := p.getInternalConfig(ctx, enterprise.Credentials, enterprise.GetBalancerType()) if err != nil { return nil, errors.Wrap(err, "fetching internal config") } @@ -319,22 +317,12 @@ func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolMan return p.enterprises, nil } -func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, credsName string, poolBalancerType params.PoolBalancerType) (params.Internal, error) { - creds, ok := p.credentials[credsName] - if !ok { - return params.Internal{}, runnerErrors.NewBadRequestError("invalid credential name (%s)", credsName) - } - - caBundle, err := creds.CACertBundle() - if err != nil { - return params.Internal{}, fmt.Errorf("fetching CA bundle for creds: %w", err) - } - +func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) { var controllerWebhookURL string if p.config.Default.WebhookURL != "" { controllerWebhookURL = fmt.Sprintf("%s/%s", p.config.Default.WebhookURL, p.controllerID) } - httpClient, err := creds.HTTPClient(ctx) + httpClient, err := creds.GetHTTPClient(ctx) if err != nil { return params.Internal{}, fmt.Errorf("fetching http client for creds: %w", err) } @@ -349,10 +337,10 @@ func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, credsName strin GithubCredentialsDetails: params.GithubCredentials{ Name: creds.Name, Description: creds.Description, - BaseURL: creds.BaseEndpoint(), - APIBaseURL: creds.APIEndpoint(), - UploadBaseURL: creds.UploadEndpoint(), - CABundle: caBundle, + BaseURL: creds.BaseURL, + APIBaseURL: creds.APIBaseURL, + UploadBaseURL: creds.UploadBaseURL, + CABundle: creds.CABundle, HTTPClient: httpClient, }, }, nil