From 032d40f5f91c691a0f4407ccbab4d095e25206e6 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 16 Apr 2024 17:05:18 +0000 Subject: [PATCH] Fix tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/enterprise.go | 115 +++++++++++++------ database/sql/enterprise_test.go | 167 ++++++++++++++++----------- database/sql/github.go | 15 ++- database/sql/instances_test.go | 65 ++++++----- database/sql/organizations.go | 108 ++++++++++++------ database/sql/organizations_test.go | 176 ++++++++++++++++------------ database/sql/pools_test.go | 27 +++-- database/sql/repositories.go | 93 +++++++++------ database/sql/repositories_test.go | 177 +++++++++++++++++------------ database/sql/util.go | 51 +++++---- internal/testing/testing.go | 60 ++++++++++ runner/enterprises.go | 18 +-- runner/enterprises_test.go | 46 ++++---- runner/organizations.go | 18 +-- runner/organizations_test.go | 40 ++++--- runner/pools_test.go | 22 +++- runner/repositories.go | 18 +-- runner/repositories_test.go | 38 ++++--- runner/runner.go | 4 +- 19 files changed, 760 insertions(+), 498 deletions(-) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index e7270faf..e8efcf8b 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -1,3 +1,17 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + package sql import ( @@ -12,7 +26,7 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateEnterprise(_ context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) { +func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) { if webhookSecret == "" { return params.Enterprise{}, errors.New("creating enterprise: missing secret") } @@ -26,13 +40,27 @@ func (s *sqlDatabase) CreateEnterprise(_ context.Context, name, credentialsName, CredentialsName: credentialsName, PoolBalancerType: poolBalancerType, } + err = s.conn.Transaction(func(tx *gorm.DB) error { + creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) + if err != nil { + return errors.Wrap(err, "creating enterprise") + } + newEnterprise.CredentialsID = &creds.ID - q := s.conn.Create(&newEnterprise) - if q.Error != nil { - return params.Enterprise{}, errors.Wrap(q.Error, "creating enterprise") + q := tx.Create(&newEnterprise) + if q.Error != nil { + return errors.Wrap(q.Error, "creating enterprise") + } + + newEnterprise.Credentials = creds + + return nil + }) + if err != nil { + return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } - param, err := s.sqlToCommonEnterprise(newEnterprise) + param, err := s.sqlToCommonEnterprise(newEnterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } @@ -46,7 +74,7 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } - param, err := s.sqlToCommonEnterprise(enterprise) + param, err := s.sqlToCommonEnterprise(enterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -54,12 +82,12 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En } func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools", "Credentials", "Endpoint") + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } - param, err := s.sqlToCommonEnterprise(enterprise) + param, err := s.sqlToCommonEnterprise(enterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") } @@ -76,7 +104,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e ret := make([]params.Enterprise, len(enterprises)) for idx, val := range enterprises { var err error - ret[idx], err = s.sqlToCommonEnterprise(val) + ret[idx], err = s.sqlToCommonEnterprise(val, true) if err != nil { return nil, errors.Wrap(err, "fetching enterprises") } @@ -86,7 +114,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e } func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID) if err != nil { return errors.Wrap(err, "fetching enterprise") } @@ -100,33 +128,50 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) } func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Credentials", "Endpoint") - if err != nil { - return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") - } - - if param.CredentialsName != "" { - enterprise.CredentialsName = param.CredentialsName - } - - if param.WebhookSecret != "" { - secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + var enterprise Enterprise + var creds GithubCredentials + err := s.conn.Transaction(func(tx *gorm.DB) error { + var err error + enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID, "Credentials", "Endpoint") if err != nil { - return params.Enterprise{}, errors.Wrap(err, "encoding secret") + return errors.Wrap(err, "fetching enterprise") } - enterprise.WebhookSecret = secret + + if param.CredentialsName != "" { + creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false) + if err != nil { + return errors.Wrap(err, "fetching credentials") + } + enterprise.CredentialsID = &creds.ID + } + if param.WebhookSecret != "" { + secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + if err != nil { + return errors.Wrap(err, "encoding secret") + } + enterprise.WebhookSecret = secret + } + + if param.PoolBalancerType != "" { + enterprise.PoolBalancerType = param.PoolBalancerType + } + + q := tx.Save(&enterprise) + if q.Error != nil { + return errors.Wrap(q.Error, "saving enterprise") + } + + if creds.ID != 0 { + enterprise.Credentials = creds + } + + return nil + }) + if err != nil { + return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } - if param.PoolBalancerType != "" { - enterprise.PoolBalancerType = param.PoolBalancerType - } - - q := s.conn.Save(&enterprise) - if q.Error != nil { - return params.Enterprise{}, errors.Wrap(q.Error, "saving enterprise") - } - - newParams, err := s.sqlToCommonEnterprise(enterprise) + newParams, err := s.sqlToCommonEnterprise(enterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } @@ -149,14 +194,14 @@ func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, return enterprise, nil } -func (s *sqlDatabase) getEnterpriseByID(_ context.Context, id string, preload ...string) (Enterprise, error) { +func (s *sqlDatabase) getEnterpriseByID(_ context.Context, tx *gorm.DB, id string, preload ...string) (Enterprise, error) { u, err := uuid.Parse(id) if err != nil { return Enterprise{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var enterprise Enterprise - q := s.conn + q := tx if len(preload) > 0 { for _, field := range preload { q = q.Preload(field) diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index 7f22956b..3509f946 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -49,6 +49,11 @@ type EnterpriseTestSuite struct { Store dbCommon.Store StoreSQLMocked *sqlDatabase Fixtures *EnterpriseTestFixtures + + adminCtx context.Context + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *EnterpriseTestSuite) equalInstancesByName(expected, actual []params.Instance) { @@ -77,18 +82,25 @@ func (s *EnterpriseTestSuite) SetupTest() { } s.Store = db + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.adminCtx = adminCtx + + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some enterprise objects in the database, for testing purposes enterprises := []params.Enterprise{} for i := 1; i <= 3; i++ { enterprise, err := db.CreateEnterprise( - context.Background(), + s.adminCtx, fmt.Sprintf("test-enterprise-%d", i), - fmt.Sprintf("test-creds-%d", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%d", i), params.PoolBalancerTypeRoundRobin, ) if err != nil { - s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%d)", i)) + s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%d): %q", i, err)) } enterprises = append(enterprises, enterprise) @@ -124,7 +136,7 @@ func (s *EnterpriseTestSuite) SetupTest() { Enterprises: enterprises, CreateEnterpriseParams: params.CreateEnterpriseParams{ Name: "new-test-enterprise", - CredentialsName: "new-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "new-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -143,7 +155,7 @@ func (s *EnterpriseTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-update-creds", + CredentialsName: s.secondaryTestCreds.Name, WebhookSecret: "test-update-repo-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -160,7 +172,7 @@ func (s *EnterpriseTestSuite) SetupTest() { func (s *EnterpriseTestSuite) TestCreateEnterprise() { // call tested function enterprise, err := s.Store.CreateEnterprise( - context.Background(), + s.adminCtx, s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.WebhookSecret, @@ -168,7 +180,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { // assertions s.Require().Nil(err) - storeEnterprise, err := s.Store.GetEnterpriseByID(context.Background(), enterprise.ID) + storeEnterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, enterprise.ID) if err != nil { s.FailNow(fmt.Sprintf("failed to get enterprise by id: %v", err)) } @@ -191,7 +203,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseInvalidDBPassphrase() { } _, err = sqlDB.CreateEnterprise( - context.Background(), + s.adminCtx, s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.WebhookSecret, @@ -203,25 +215,29 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseInvalidDBPassphrase() { func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() { s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.Fixtures.Enterprises[0].CredentialsName). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID)) s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `enterprises`")). WillReturnError(fmt.Errorf("creating enterprise mock error")) s.Fixtures.SQLMock.ExpectRollback() _, err := s.StoreSQLMocked.CreateEnterprise( - context.Background(), + s.adminCtx, s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.WebhookSecret, params.PoolBalancerTypeRoundRobin) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating enterprise: creating enterprise mock error", err.Error()) + s.Require().Equal("creating enterprise: creating enterprise: creating enterprise mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestGetEnterprise() { - enterprise, err := s.Store.GetEnterprise(context.Background(), s.Fixtures.Enterprises[0].Name) + enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Enterprises[0].Name, enterprise.Name) @@ -229,14 +245,14 @@ func (s *EnterpriseTestSuite) TestGetEnterprise() { } func (s *EnterpriseTestSuite) TestGetEnterpriseCaseInsensitive() { - enterprise, err := s.Store.GetEnterprise(context.Background(), "TeSt-eNtErPriSe-1") + enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1") s.Require().Nil(err) s.Require().Equal("test-enterprise-1", enterprise.Name) } func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { - _, err := s.Store.GetEnterprise(context.Background(), "dummy-name") + _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name") s.Require().NotNil(err) s.Require().Equal("fetching enterprise: not found", err.Error()) @@ -248,7 +264,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() { WithArgs(s.Fixtures.Enterprises[0].Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Enterprises[0].Name)) - _, err := s.StoreSQLMocked.GetEnterprise(context.Background(), s.Fixtures.Enterprises[0].Name) + _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -256,7 +272,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() { } func (s *EnterpriseTestSuite) TestListEnterprises() { - enterprises, err := s.Store.ListEnterprises(context.Background()) + enterprises, err := s.Store.ListEnterprises(s.adminCtx) s.Require().Nil(err) garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Enterprises, enterprises) @@ -267,7 +283,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisesDBFetchErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE `enterprises`.`deleted_at` IS NULL")). WillReturnError(fmt.Errorf("fetching user from database mock error")) - _, err := s.StoreSQLMocked.ListEnterprises(context.Background()) + _, err := s.StoreSQLMocked.ListEnterprises(s.adminCtx) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -275,16 +291,16 @@ func (s *EnterpriseTestSuite) TestListEnterprisesDBFetchErr() { } func (s *EnterpriseTestSuite) TestDeleteEnterprise() { - err := s.Store.DeleteEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID) + err := s.Store.DeleteEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID) s.Require().Nil(err) - _, err = s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) + _, err = s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID) s.Require().NotNil(err) s.Require().Equal("fetching enterprise: not found", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterpriseInvalidEnterpriseID() { - err := s.Store.DeleteEnterprise(context.Background(), "dummy-enterprise-id") + err := s.Store.DeleteEnterprise(s.adminCtx, "dummy-enterprise-id") s.Require().NotNil(err) s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) @@ -302,15 +318,15 @@ func (s *EnterpriseTestSuite) TestDeleteEnterpriseDBDeleteErr() { WillReturnError(fmt.Errorf("mocked delete enterprise error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeleteEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID) + err := s.StoreSQLMocked.DeleteEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("deleting enterprise: mocked delete enterprise error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestUpdateEnterprise() { - enterprise, err := s.Store.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) + enterprise, err := s.Store.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.Credentials.Name) @@ -318,70 +334,85 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() { } func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidEnterpriseID() { - _, err := s.Store.UpdateEnterprise(context.Background(), "dummy-enterprise-id", s.Fixtures.UpdateRepoParams) + _, err := s.Store.UpdateEnterprise(s.adminCtx, "dummy-enterprise-id", s.Fixtures.UpdateRepoParams) s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) + s.Require().Equal("updating enterprise: fetching enterprise: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBEncryptErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase - + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). WithArgs(s.Fixtures.Enterprises[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("updating enterprise: encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBSaveErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). WithArgs(s.Fixtures.Enterprises[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) s.Fixtures.SQLMock. ExpectExec(("UPDATE `enterprises` SET")). WillReturnError(fmt.Errorf("saving enterprise mock error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving enterprise: saving enterprise mock error", err.Error()) + s.Require().Equal("updating enterprise: saving enterprise: saving enterprise mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBDecryptingErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). WithArgs(s.Fixtures.Enterprises[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("updating enterprise: encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestGetEnterpriseByID() { - enterprise, err := s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) + enterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Enterprises[0].ID, enterprise.ID) } func (s *EnterpriseTestSuite) TestGetEnterpriseByIDInvalidEnterpriseID() { - _, err := s.Store.GetEnterpriseByID(context.Background(), "dummy-enterprise-id") + _, err := s.Store.GetEnterpriseByID(s.adminCtx, "dummy-enterprise-id") s.Require().NotNil(err) s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) @@ -397,7 +428,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() { WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - _, err := s.StoreSQLMocked.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) + _, err := s.StoreSQLMocked.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -407,11 +438,11 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() { func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) - enterprise, err := s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) + enterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID) if err != nil { s.FailNow(fmt.Sprintf("cannot get enterprise by ID: %v", err)) } @@ -426,7 +457,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -437,7 +468,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -455,7 +486,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBCreateErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) @@ -484,7 +515,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) @@ -511,7 +542,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchTagErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) @@ -546,7 +577,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBAddingPoolErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating pool: mocked adding pool error", err.Error()) @@ -585,7 +616,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBSaveTagErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("associating tags: mocked saving tag error", err.Error()) @@ -633,7 +664,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) @@ -646,14 +677,14 @@ func (s *EnterpriseTestSuite) TestListEnterprisePools() { s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } enterprisePools = append(enterprisePools, pool) } - pools, err := s.Store.ListEntityPools(context.Background(), entity) + pools, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools) @@ -664,7 +695,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - _, err := s.Store.ListEntityPools(context.Background(), entity) + _, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) @@ -673,12 +704,12 @@ func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestGetEnterprisePool() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - enterprisePool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) + enterprisePool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) s.Require().Equal(enterprisePool.ID, pool.ID) @@ -689,7 +720,7 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") + _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) @@ -698,15 +729,15 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) + _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -715,7 +746,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") + err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -724,7 +755,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -736,7 +767,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().NotNil(err) s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) s.assertSQLMockExpectations() @@ -745,21 +776,21 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } poolInstances := []params.Instance{} for i := 1; i <= 3; i++ { s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-enterprise-%v", i) - instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) + instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) } poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListEntityInstances(context.Background(), entity) + instances, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) @@ -770,7 +801,7 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - _, err := s.Store.ListEntityInstances(context.Background(), entity) + _, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) @@ -779,12 +810,12 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { entity, err := s.Fixtures.Enterprises[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(s.adminCtx, entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -798,7 +829,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolInvalidEnterpriseID() { ID: "dummy-enterprise-id", EntityType: params.GithubEntityTypeEnterprise, } - _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) + _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/github.go b/database/sql/github.go index 0087165f..300c0ef4 100644 --- a/database/sql/github.go +++ b/database/sql/github.go @@ -14,6 +14,9 @@ import ( ) func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (params.GithubCredentials, error) { + if len(creds.Payload) == 0 { + return params.GithubCredentials{}, errors.New("empty credentials payload") + } data, err := util.Unseal(creds.Payload, []byte(s.cfg.Passphrase)) if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "unsealing credentials") @@ -33,7 +36,7 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par } for _, repo := range creds.Repositories { - commonRepo, err := s.sqlToCommonRepository(repo) + commonRepo, err := s.sqlToCommonRepository(repo, false) if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "converting github repository") } @@ -41,7 +44,7 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par } for _, org := range creds.Organizations { - commonOrg, err := s.sqlToCommonOrganization(org) + commonOrg, err := s.sqlToCommonOrganization(org, false) if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "converting github organization") } @@ -49,9 +52,9 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par } for _, ent := range creds.Enterprises { - commonEnt, err := s.sqlToCommonEnterprise(ent) + commonEnt, err := s.sqlToCommonEnterprise(ent, false) if err != nil { - return params.GithubCredentials{}, errors.Wrap(err, "converting github enterprise") + return params.GithubCredentials{}, errors.Wrapf(err, "converting github enterprise: %s", ent.Name) } commonCreds.Enterprises = append(commonCreds.Enterprises, commonEnt) } @@ -73,12 +76,12 @@ func (s *sqlDatabase) sqlToCommonGithubEndpoint(ep GithubEndpoint) (params.Githu func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { userID := auth.UserID(ctx) if userID == "" { - return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "getting UID from context") } asUUID, err := uuid.Parse(userID) if err != nil { - return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") + return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "parsing UID from context") } return asUUID, nil } diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 74fd8c65..82b5cb46 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -48,6 +48,7 @@ type InstancesTestSuite struct { Store dbCommon.Store StoreSQLMocked *sqlDatabase Fixtures *InstancesTestFixtures + adminCtx context.Context } func (s *InstancesTestSuite) equalInstancesByName(expected, actual []params.Instance) { @@ -76,8 +77,14 @@ func (s *InstancesTestSuite) SetupTest() { } s.Store = db + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.adminCtx = adminCtx + + githubEndpoint := garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint) + // create an organization for testing purposes - org, err := s.Store.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) + org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin) if err != nil { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } @@ -94,7 +101,7 @@ func (s *InstancesTestSuite) SetupTest() { } entity, err := org.GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, createPoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams) if err != nil { s.FailNow(fmt.Sprintf("failed to create org pool: %s", err)) } @@ -103,7 +110,7 @@ func (s *InstancesTestSuite) SetupTest() { instances := []params.Instance{} for i := 1; i <= 3; i++ { instance, err := db.CreateInstance( - context.Background(), + s.adminCtx, pool.ID, params.CreateInstanceParams{ Name: fmt.Sprintf("test-instance-%d", i), @@ -179,11 +186,11 @@ func (s *InstancesTestSuite) SetupTest() { func (s *InstancesTestSuite) TestCreateInstance() { // call tested function - instance, err := s.Store.CreateInstance(context.Background(), s.Fixtures.Pool.ID, s.Fixtures.CreateInstanceParams) + instance, err := s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, s.Fixtures.CreateInstanceParams) // assertions s.Require().Nil(err) - storeInstance, err := s.Store.GetInstanceByName(context.Background(), s.Fixtures.CreateInstanceParams.Name) + storeInstance, err := s.Store.GetInstanceByName(s.adminCtx, s.Fixtures.CreateInstanceParams.Name) if err != nil { s.FailNow(fmt.Sprintf("failed to get instance: %v", err)) } @@ -195,7 +202,7 @@ func (s *InstancesTestSuite) TestCreateInstance() { } func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() { - _, err := s.Store.CreateInstance(context.Background(), "dummy-pool-id", params.CreateInstanceParams{}) + _, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{}) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } @@ -216,7 +223,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { WillReturnError(fmt.Errorf("mocked insert instance error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) + _, err := s.StoreSQLMocked.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -226,7 +233,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { func (s *InstancesTestSuite) TestGetPoolInstanceByName() { storeInstance := s.Fixtures.Instances[0] // this is already created in `SetupTest()` - instance, err := s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) + instance, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) s.Require().Nil(err) s.Require().Equal(storeInstance.Name, instance.Name) @@ -237,7 +244,7 @@ func (s *InstancesTestSuite) TestGetPoolInstanceByName() { } func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() { - _, err := s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, "not-existent-instance-name") + _, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, "not-existent-instance-name") s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error()) } @@ -245,7 +252,7 @@ func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() { func (s *InstancesTestSuite) TestGetInstanceByName() { storeInstance := s.Fixtures.Instances[1] - instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name) + instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name) s.Require().Nil(err) s.Require().Equal(storeInstance.Name, instance.Name) @@ -256,7 +263,7 @@ func (s *InstancesTestSuite) TestGetInstanceByName() { } func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() { - _, err := s.Store.GetInstanceByName(context.Background(), "not-existent-instance-name") + _, err := s.Store.GetInstanceByName(s.adminCtx, "not-existent-instance-name") s.Require().Equal("fetching instance: fetching instance by name: not found", err.Error()) } @@ -264,16 +271,16 @@ func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() { func (s *InstancesTestSuite) TestDeleteInstance() { storeInstance := s.Fixtures.Instances[0] - err := s.Store.DeleteInstance(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) + err := s.Store.DeleteInstance(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) s.Require().Nil(err) - _, err = s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) + _, err = s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name) s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error()) } func (s *InstancesTestSuite) TestDeleteInstanceInvalidPoolID() { - err := s.Store.DeleteInstance(context.Background(), "dummy-pool-id", "dummy-instance-name") + err := s.Store.DeleteInstance(s.adminCtx, "dummy-pool-id", "dummy-instance-name") s.Require().Equal("deleting instance: fetching pool: parsing id: invalid request", err.Error()) } @@ -309,7 +316,7 @@ func (s *InstancesTestSuite) TestDeleteInstanceDBRecordNotFoundErr() { WillReturnError(gorm.ErrRecordNotFound) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeleteInstance(context.Background(), pool.ID, instance.Name) + err := s.StoreSQLMocked.DeleteInstance(s.adminCtx, pool.ID, instance.Name) s.assertSQLMockExpectations() s.Require().Nil(err) @@ -346,7 +353,7 @@ func (s *InstancesTestSuite) TestDeleteInstanceDBDeleteErr() { WillReturnError(fmt.Errorf("mocked delete instance error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeleteInstance(context.Background(), pool.ID, instance.Name) + err := s.StoreSQLMocked.DeleteInstance(s.adminCtx, pool.ID, instance.Name) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -357,10 +364,10 @@ func (s *InstancesTestSuite) TestAddInstanceEvent() { storeInstance := s.Fixtures.Instances[0] statusMsg := "test-status-message" - err := s.Store.AddInstanceEvent(context.Background(), storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg) + err := s.Store.AddInstanceEvent(s.adminCtx, storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg) s.Require().Nil(err) - instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name) + instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name) if err != nil { s.FailNow(fmt.Sprintf("failed to get db instance: %s", err)) } @@ -398,7 +405,7 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() { WillReturnError(fmt.Errorf("mocked add status message error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.AddInstanceEvent(context.Background(), instance.Name, params.StatusEvent, params.EventInfo, statusMsg) + err := s.StoreSQLMocked.AddInstanceEvent(s.adminCtx, instance.Name, params.StatusEvent, params.EventInfo, statusMsg) s.Require().NotNil(err) s.Require().Equal("adding status message: mocked add status message error", err.Error()) @@ -406,7 +413,7 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() { } func (s *InstancesTestSuite) TestUpdateInstance() { - instance, err := s.Store.UpdateInstance(context.Background(), s.Fixtures.Instances[0].Name, s.Fixtures.UpdateInstanceParams) + instance, err := s.Store.UpdateInstance(s.adminCtx, s.Fixtures.Instances[0].Name, s.Fixtures.UpdateInstanceParams) s.Require().Nil(err) s.Require().Equal(s.Fixtures.UpdateInstanceParams.ProviderID, instance.ProviderID) @@ -443,7 +450,7 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() { WillReturnError(fmt.Errorf("mocked update instance error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams) + _, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams) s.Require().NotNil(err) s.Require().Equal("updating instance: mocked update instance error", err.Error()) @@ -489,7 +496,7 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() { WillReturnError(fmt.Errorf("update addresses mock error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams) + _, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams) s.Require().NotNil(err) s.Require().Equal("updating addresses: update addresses mock error", err.Error()) @@ -497,20 +504,20 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() { } func (s *InstancesTestSuite) TestListPoolInstances() { - instances, err := s.Store.ListPoolInstances(context.Background(), s.Fixtures.Pool.ID) + instances, err := s.Store.ListPoolInstances(s.adminCtx, s.Fixtures.Pool.ID) s.Require().Nil(err) s.equalInstancesByName(s.Fixtures.Instances, instances) } func (s *InstancesTestSuite) TestListPoolInstancesInvalidPoolID() { - _, err := s.Store.ListPoolInstances(context.Background(), "dummy-pool-id") + _, err := s.Store.ListPoolInstances(s.adminCtx, "dummy-pool-id") s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *InstancesTestSuite) TestListAllInstances() { - instances, err := s.Store.ListAllInstances(context.Background()) + instances, err := s.Store.ListAllInstances(s.adminCtx) s.Require().Nil(err) s.equalInstancesByName(s.Fixtures.Instances, instances) @@ -521,7 +528,7 @@ func (s *InstancesTestSuite) TestListAllInstancesDBFetchErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE `instances`.`deleted_at` IS NULL")). WillReturnError(fmt.Errorf("fetch instances mock error")) - _, err := s.StoreSQLMocked.ListAllInstances(context.Background()) + _, err := s.StoreSQLMocked.ListAllInstances(s.adminCtx) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -529,14 +536,14 @@ func (s *InstancesTestSuite) TestListAllInstancesDBFetchErr() { } func (s *InstancesTestSuite) TestPoolInstanceCount() { - instancesCount, err := s.Store.PoolInstanceCount(context.Background(), s.Fixtures.Pool.ID) + instancesCount, err := s.Store.PoolInstanceCount(s.adminCtx, s.Fixtures.Pool.ID) s.Require().Nil(err) s.Require().Equal(int64(len(s.Fixtures.Instances)), instancesCount) } func (s *InstancesTestSuite) TestPoolInstanceCountInvalidPoolID() { - _, err := s.Store.PoolInstanceCount(context.Background(), "dummy-pool-id") + _, err := s.Store.PoolInstanceCount(s.adminCtx, "dummy-pool-id") s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } @@ -553,7 +560,7 @@ func (s *InstancesTestSuite) TestPoolInstanceCountDBCountErr() { WithArgs(pool.ID). WillReturnError(fmt.Errorf("count mock error")) - _, err := s.StoreSQLMocked.PoolInstanceCount(context.Background(), pool.ID) + _, err := s.StoreSQLMocked.PoolInstanceCount(s.adminCtx, pool.ID) s.assertSQLMockExpectations() s.Require().NotNil(err) diff --git a/database/sql/organizations.go b/database/sql/organizations.go index c67e85d4..0019ad39 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -27,13 +27,13 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateOrganization(_ context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) { +func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) { if webhookSecret == "" { return params.Organization{}, errors.New("creating org: missing secret") } secret, err := util.Seal([]byte(webhookSecret), []byte(s.cfg.Passphrase)) if err != nil { - return params.Organization{}, fmt.Errorf("failed to encrypt string") + return params.Organization{}, errors.Wrap(err, "encoding secret") } newOrg := Organization{ Name: name, @@ -42,12 +42,27 @@ func (s *sqlDatabase) CreateOrganization(_ context.Context, name, credentialsNam PoolBalancerType: poolBalancerType, } - q := s.conn.Create(&newOrg) - if q.Error != nil { - return params.Organization{}, errors.Wrap(q.Error, "creating org") + err = s.conn.Transaction(func(tx *gorm.DB) error { + creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) + if err != nil { + return errors.Wrap(err, "creating org") + } + newOrg.CredentialsID = &creds.ID + + q := tx.Create(&newOrg) + if q.Error != nil { + return errors.Wrap(q.Error, "creating org") + } + + newOrg.Credentials = creds + + return nil + }) + if err != nil { + return params.Organization{}, errors.Wrap(err, "creating org") } - param, err := s.sqlToCommonOrganization(newOrg) + param, err := s.sqlToCommonOrganization(newOrg, true) if err != nil { return params.Organization{}, errors.Wrap(err, "creating org") } @@ -62,7 +77,7 @@ func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params. return params.Organization{}, errors.Wrap(err, "fetching org") } - param, err := s.sqlToCommonOrganization(org) + param, err := s.sqlToCommonOrganization(org, true) if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } @@ -80,7 +95,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio ret := make([]params.Organization, len(orgs)) for idx, val := range orgs { var err error - ret[idx], err = s.sqlToCommonOrganization(val) + ret[idx], err = s.sqlToCommonOrganization(val, true) if err != nil { return nil, errors.Wrap(err, "fetching org") } @@ -90,7 +105,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio } func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error { - org, err := s.getOrgByID(ctx, orgID) + org, err := s.getOrgByID(ctx, s.conn, orgID) if err != nil { return errors.Wrap(err, "fetching org") } @@ -104,33 +119,52 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro } func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID, "Credentials", "Endpoint") - if err != nil { - return params.Organization{}, errors.Wrap(err, "fetching org") - } - - if param.CredentialsName != "" { - org.CredentialsName = param.CredentialsName - } - - if param.WebhookSecret != "" { - secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + var org Organization + var creds GithubCredentials + err := s.conn.Transaction(func(tx *gorm.DB) error { + var err error + org, err = s.getOrgByID(ctx, tx, orgID, "Credentials", "Endpoint") if err != nil { - return params.Organization{}, fmt.Errorf("saving org: failed to encrypt string: %w", err) + return errors.Wrap(err, "fetching org") } - org.WebhookSecret = secret + + if param.CredentialsName != "" { + org.CredentialsName = param.CredentialsName + creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false) + if err != nil { + return errors.Wrap(err, "fetching credentials") + } + org.CredentialsID = &creds.ID + } + + if param.WebhookSecret != "" { + secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + if err != nil { + return fmt.Errorf("saving org: failed to encrypt string: %w", err) + } + org.WebhookSecret = secret + } + + if param.PoolBalancerType != "" { + org.PoolBalancerType = param.PoolBalancerType + } + + q := tx.Save(&org) + if q.Error != nil { + return errors.Wrap(q.Error, "saving org") + } + + if creds.ID != 0 { + org.Credentials = creds + } + + return nil + }) + if err != nil { + return params.Organization{}, errors.Wrap(err, "saving org") } - if param.PoolBalancerType != "" { - org.PoolBalancerType = param.PoolBalancerType - } - - q := s.conn.Save(&org) - if q.Error != nil { - return params.Organization{}, errors.Wrap(q.Error, "saving org") - } - - newParams, err := s.sqlToCommonOrganization(org) + newParams, err := s.sqlToCommonOrganization(org, true) if err != nil { return params.Organization{}, errors.Wrap(err, "saving org") } @@ -138,26 +172,26 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para } func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { - org, err := s.getOrgByID(ctx, orgID, "Pools", "Credentials", "Endpoint") + org, err := s.getOrgByID(ctx, s.conn, orgID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "fetching org") } - param, err := s.sqlToCommonOrganization(org) + param, err := s.sqlToCommonOrganization(org, true) if err != nil { - return params.Organization{}, errors.Wrap(err, "fetching enterprise") + return params.Organization{}, errors.Wrap(err, "fetching org") } return param, nil } -func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { +func (s *sqlDatabase) getOrgByID(_ context.Context, db *gorm.DB, id string, preload ...string) (Organization, error) { u, err := uuid.Parse(id) if err != nil { return Organization{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var org Organization - q := s.conn + q := db if len(preload) > 0 { for _, field := range preload { q = q.Preload(field) diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 28a049e5..55e791db 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -49,6 +49,11 @@ type OrgTestSuite struct { Store dbCommon.Store StoreSQLMocked *sqlDatabase Fixtures *OrgTestFixtures + + adminCtx context.Context + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *OrgTestSuite) equalInstancesByName(expected, actual []params.Instance) { @@ -71,24 +76,32 @@ func (s *OrgTestSuite) assertSQLMockExpectations() { func (s *OrgTestSuite) SetupTest() { // create testing sqlite database - db, err := NewSQLDatabase(context.Background(), garmTesting.GetTestSqliteDBConfig(s.T())) + dbConfig := garmTesting.GetTestSqliteDBConfig(s.T()) + db, err := NewSQLDatabase(context.Background(), dbConfig) if err != nil { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } s.Store = db + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.adminCtx = adminCtx + + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some organization objects in the database, for testing purposes orgs := []params.Organization{} for i := 1; i <= 3; i++ { org, err := db.CreateOrganization( - context.Background(), + s.adminCtx, fmt.Sprintf("test-org-%d", i), - fmt.Sprintf("test-creds-%d", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%d", i), params.PoolBalancerTypeRoundRobin, ) if err != nil { - s.FailNow(fmt.Sprintf("failed to create database object (test-org-%d)", i)) + s.FailNow(fmt.Sprintf("failed to create database object (test-org-%d): %q", i, err)) } orgs = append(orgs, org) @@ -114,7 +127,7 @@ func (s *OrgTestSuite) SetupTest() { } s.StoreSQLMocked = &sqlDatabase{ conn: gormConn, - cfg: garmTesting.GetTestSqliteDBConfig(s.T()), + cfg: dbConfig, } // setup test fixtures @@ -123,8 +136,8 @@ func (s *OrgTestSuite) SetupTest() { fixtures := &OrgTestFixtures{ Orgs: orgs, CreateOrgParams: params.CreateOrgParams{ - Name: "new-test-org", - CredentialsName: "new-creds", + Name: s.testCreds.Name, + CredentialsName: s.testCreds.Name, WebhookSecret: "new-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -143,7 +156,7 @@ func (s *OrgTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-update-creds", + CredentialsName: s.secondaryTestCreds.Name, WebhookSecret: "test-update-repo-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -160,7 +173,7 @@ func (s *OrgTestSuite) SetupTest() { func (s *OrgTestSuite) TestCreateOrganization() { // call tested function org, err := s.Store.CreateOrganization( - context.Background(), + s.adminCtx, s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.WebhookSecret, @@ -168,7 +181,7 @@ func (s *OrgTestSuite) TestCreateOrganization() { // assertions s.Require().Nil(err) - storeOrg, err := s.Store.GetOrganizationByID(context.Background(), org.ID) + storeOrg, err := s.Store.GetOrganizationByID(s.adminCtx, org.ID) if err != nil { s.FailNow(fmt.Sprintf("failed to get organization by id: %v", err)) } @@ -191,37 +204,41 @@ func (s *OrgTestSuite) TestCreateOrganizationInvalidDBPassphrase() { } _, err = sqlDB.CreateOrganization( - context.Background(), + s.adminCtx, s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.WebhookSecret, params.PoolBalancerTypeRoundRobin) s.Require().NotNil(err) - s.Require().Equal("failed to encrypt string", err.Error()) + s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() { s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.Fixtures.Orgs[0].CredentialsName). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID)) s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `organizations`")). WillReturnError(fmt.Errorf("creating org mock error")) s.Fixtures.SQLMock.ExpectRollback() _, err := s.StoreSQLMocked.CreateOrganization( - context.Background(), + s.adminCtx, s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.WebhookSecret, params.PoolBalancerTypeRoundRobin) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating org: creating org mock error", err.Error()) + s.Require().Equal("creating org: creating org: creating org mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestGetOrganization() { - org, err := s.Store.GetOrganization(context.Background(), s.Fixtures.Orgs[0].Name) + org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Orgs[0].Name, org.Name) @@ -229,14 +246,14 @@ func (s *OrgTestSuite) TestGetOrganization() { } func (s *OrgTestSuite) TestGetOrganizationCaseInsensitive() { - org, err := s.Store.GetOrganization(context.Background(), "TeSt-oRg-1") + org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1") s.Require().Nil(err) s.Require().Equal("test-org-1", org.Name) } func (s *OrgTestSuite) TestGetOrganizationNotFound() { - _, err := s.Store.GetOrganization(context.Background(), "dummy-name") + _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name") s.Require().NotNil(err) s.Require().Equal("fetching org: not found", err.Error()) @@ -248,7 +265,7 @@ func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() { WithArgs(s.Fixtures.Orgs[0].Name, 1). WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Orgs[0].Name)) - _, err := s.StoreSQLMocked.GetOrganization(context.Background(), s.Fixtures.Orgs[0].Name) + _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -256,7 +273,7 @@ func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() { } func (s *OrgTestSuite) TestListOrganizations() { - orgs, err := s.Store.ListOrganizations(context.Background()) + orgs, err := s.Store.ListOrganizations(s.adminCtx) s.Require().Nil(err) garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Orgs, orgs) @@ -267,7 +284,7 @@ func (s *OrgTestSuite) TestListOrganizationsDBFetchErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE `organizations`.`deleted_at` IS NULL")). WillReturnError(fmt.Errorf("fetching user from database mock error")) - _, err := s.StoreSQLMocked.ListOrganizations(context.Background()) + _, err := s.StoreSQLMocked.ListOrganizations(s.adminCtx) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -275,16 +292,16 @@ func (s *OrgTestSuite) TestListOrganizationsDBFetchErr() { } func (s *OrgTestSuite) TestDeleteOrganization() { - err := s.Store.DeleteOrganization(context.Background(), s.Fixtures.Orgs[0].ID) + err := s.Store.DeleteOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID) s.Require().Nil(err) - _, err = s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) + _, err = s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID) s.Require().NotNil(err) s.Require().Equal("fetching org: not found", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationInvalidOrgID() { - err := s.Store.DeleteOrganization(context.Background(), "dummy-org-id") + err := s.Store.DeleteOrganization(s.adminCtx, "dummy-org-id") s.Require().NotNil(err) s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) @@ -302,7 +319,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationDBDeleteErr() { WillReturnError(fmt.Errorf("mocked delete org error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeleteOrganization(context.Background(), s.Fixtures.Orgs[0].ID) + err := s.StoreSQLMocked.DeleteOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -310,7 +327,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationDBDeleteErr() { } func (s *OrgTestSuite) TestUpdateOrganization() { - org, err := s.Store.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) + org, err := s.Store.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) @@ -318,70 +335,85 @@ func (s *OrgTestSuite) TestUpdateOrganization() { } func (s *OrgTestSuite) TestUpdateOrganizationInvalidOrgID() { - _, err := s.Store.UpdateOrganization(context.Background(), "dummy-org-id", s.Fixtures.UpdateRepoParams) + _, err := s.Store.UpdateOrganization(s.adminCtx, "dummy-org-id", s.Fixtures.UpdateRepoParams) s.Require().NotNil(err) - s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) + s.Require().Equal("saving org: fetching org: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestUpdateOrganizationDBEncryptErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase - + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). WithArgs(s.Fixtures.Orgs[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("saving org: saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestUpdateOrganizationDBSaveErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). WithArgs(s.Fixtures.Orgs[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) s.Fixtures.SQLMock. ExpectExec(("UPDATE `organizations` SET")). WillReturnError(fmt.Errorf("saving org mock error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving org: saving org mock error", err.Error()) + s.Require().Equal("saving org: saving org: saving org mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestUpdateOrganizationDBDecryptingErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). WithArgs(s.Fixtures.Orgs[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("saving org: saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestGetOrganizationByID() { - org, err := s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) + org, err := s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Orgs[0].ID, org.ID) } func (s *OrgTestSuite) TestGetOrganizationByIDInvalidOrgID() { - _, err := s.Store.GetOrganizationByID(context.Background(), "dummy-org-id") + _, err := s.Store.GetOrganizationByID(s.adminCtx, "dummy-org-id") s.Require().NotNil(err) s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) @@ -397,21 +429,21 @@ func (s *OrgTestSuite) TestGetOrganizationByIDDBDecryptingErr() { WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"}).AddRow(s.Fixtures.Orgs[0].ID)) - _, err := s.StoreSQLMocked.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) + _, err := s.StoreSQLMocked.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID) s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: missing secret", err.Error()) + s.Require().Equal("fetching org: missing secret", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationPool() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) - org, err := s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) + org, err := s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID) if err != nil { s.FailNow(fmt.Sprintf("cannot get org by ID: %v", err)) } @@ -426,7 +458,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -437,7 +469,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -455,7 +487,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBCreateErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) @@ -484,7 +516,7 @@ func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) @@ -511,7 +543,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchTagErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) @@ -547,7 +579,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBAddingPoolErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating pool: mocked adding pool error", err.Error()) @@ -586,7 +618,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBSaveTagErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("associating tags: mocked saving tag error", err.Error()) @@ -635,7 +667,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) @@ -648,13 +680,13 @@ func (s *OrgTestSuite) TestListOrgPools() { s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } orgPools = append(orgPools, pool) } - pools, err := s.Store.ListEntityPools(context.Background(), entity) + pools, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), orgPools, pools) @@ -665,7 +697,7 @@ func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - _, err := s.Store.ListEntityPools(context.Background(), entity) + _, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) @@ -674,12 +706,12 @@ func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { func (s *OrgTestSuite) TestGetOrganizationPool() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - orgPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) + orgPool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) s.Require().Equal(orgPool.ID, pool.ID) @@ -690,7 +722,7 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") + _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) @@ -699,15 +731,15 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() { func (s *OrgTestSuite) TestDeleteOrganizationPool() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) + _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -716,7 +748,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") + err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -726,7 +758,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -738,7 +770,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().NotNil(err) s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) @@ -748,21 +780,21 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { func (s *OrgTestSuite) TestListOrgInstances() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } poolInstances := []params.Instance{} for i := 1; i <= 3; i++ { s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-org-%v", i) - instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) + instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) } poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListEntityInstances(context.Background(), entity) + instances, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) @@ -773,7 +805,7 @@ func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - _, err := s.Store.ListEntityInstances(context.Background(), entity) + _, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) @@ -782,12 +814,12 @@ func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() { func (s *OrgTestSuite) TestUpdateOrganizationPool() { entity, err := s.Fixtures.Orgs[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(s.adminCtx, entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -801,7 +833,7 @@ func (s *OrgTestSuite) TestUpdateOrganizationPoolInvalidOrgID() { ID: "dummy-org-id", EntityType: params.GithubEntityTypeOrganization, } - _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) + _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index af4fa2cf..97dbdf71 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -43,6 +43,7 @@ type PoolsTestSuite struct { Store dbCommon.Store StoreSQLMocked *sqlDatabase Fixtures *PoolsTestFixtures + adminCtx context.Context } func (s *PoolsTestSuite) assertSQLMockExpectations() { @@ -60,8 +61,14 @@ func (s *PoolsTestSuite) SetupTest() { } s.Store = db + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.adminCtx = adminCtx + + githubEndpoint := garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint) + // create an organization for testing purposes - org, err := s.Store.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) + org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin) if err != nil { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } @@ -72,7 +79,7 @@ func (s *PoolsTestSuite) SetupTest() { orgPools := []params.Pool{} for i := 1; i <= 3; i++ { pool, err := db.CreateEntityPool( - context.Background(), + s.adminCtx, entity, params.CreatePoolParams{ ProviderName: "test-provider", @@ -122,7 +129,7 @@ func (s *PoolsTestSuite) SetupTest() { } func (s *PoolsTestSuite) TestListAllPools() { - pools, err := s.Store.ListAllPools(context.Background()) + pools, err := s.Store.ListAllPools(s.adminCtx) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), s.Fixtures.Pools, pools) @@ -133,7 +140,7 @@ func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() { ExpectQuery(regexp.QuoteMeta("SELECT `pools`.`id`,`pools`.`created_at`,`pools`.`updated_at`,`pools`.`deleted_at`,`pools`.`provider_name`,`pools`.`runner_prefix`,`pools`.`max_runners`,`pools`.`min_idle_runners`,`pools`.`runner_bootstrap_timeout`,`pools`.`image`,`pools`.`flavor`,`pools`.`os_type`,`pools`.`os_arch`,`pools`.`enabled`,`pools`.`git_hub_runner_group`,`pools`.`repo_id`,`pools`.`org_id`,`pools`.`enterprise_id`,`pools`.`priority` FROM `pools` WHERE `pools`.`deleted_at` IS NULL")). WillReturnError(fmt.Errorf("mocked fetching all pools error")) - _, err := s.StoreSQLMocked.ListAllPools(context.Background()) + _, err := s.StoreSQLMocked.ListAllPools(s.adminCtx) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -141,29 +148,29 @@ func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() { } func (s *PoolsTestSuite) TestGetPoolByID() { - pool, err := s.Store.GetPoolByID(context.Background(), s.Fixtures.Pools[0].ID) + pool, err := s.Store.GetPoolByID(s.adminCtx, s.Fixtures.Pools[0].ID) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Pools[0].ID, pool.ID) } func (s *PoolsTestSuite) TestGetPoolByIDInvalidPoolID() { - _, err := s.Store.GetPoolByID(context.Background(), "dummy-pool-id") + _, err := s.Store.GetPoolByID(s.adminCtx, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error()) } func (s *PoolsTestSuite) TestDeletePoolByID() { - err := s.Store.DeletePoolByID(context.Background(), s.Fixtures.Pools[0].ID) + err := s.Store.DeletePoolByID(s.adminCtx, s.Fixtures.Pools[0].ID) s.Require().Nil(err) - _, err = s.Store.GetPoolByID(context.Background(), s.Fixtures.Pools[0].ID) + _, err = s.Store.GetPoolByID(s.adminCtx, s.Fixtures.Pools[0].ID) s.Require().Equal("fetching pool by ID: not found", err.Error()) } func (s *PoolsTestSuite) TestDeletePoolByIDInvalidPoolID() { - err := s.Store.DeletePoolByID(context.Background(), "dummy-pool-id") + err := s.Store.DeletePoolByID(s.adminCtx, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error()) @@ -180,7 +187,7 @@ func (s *PoolsTestSuite) TestDeletePoolByIDDBRemoveErr() { WillReturnError(fmt.Errorf("mocked removing pool error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeletePoolByID(context.Background(), s.Fixtures.Pools[0].ID) + err := s.StoreSQLMocked.DeletePoolByID(s.adminCtx, s.Fixtures.Pools[0].ID) s.assertSQLMockExpectations() s.Require().NotNil(err) diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 396b2796..18284e1b 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -36,31 +36,32 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent return params.Repository{}, fmt.Errorf("failed to encrypt string") } - var newRepo Repository + newRepo := Repository{ + Name: name, + Owner: owner, + WebhookSecret: secret, + PoolBalancerType: poolBalancerType, + } err = s.conn.Transaction(func(tx *gorm.DB) error { creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) if err != nil { return errors.Wrap(err, "creating repository") } - - newRepo.Name = name - newRepo.Owner = owner - newRepo.WebhookSecret = secret newRepo.CredentialsID = &creds.ID - newRepo.PoolBalancerType = poolBalancerType q := tx.Create(&newRepo) if q.Error != nil { return errors.Wrap(q.Error, "creating repository") } + newRepo.Credentials = creds return nil }) if err != nil { return params.Repository{}, errors.Wrap(err, "creating repository") } - param, err := s.sqlToCommonRepository(newRepo) + param, err := s.sqlToCommonRepository(newRepo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "creating repository") } @@ -74,7 +75,7 @@ func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (pa return params.Repository{}, errors.Wrap(err, "fetching repo") } - param, err := s.sqlToCommonRepository(repo) + param, err := s.sqlToCommonRepository(repo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -92,7 +93,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, ret := make([]params.Repository, len(repos)) for idx, val := range repos { var err error - ret[idx], err = s.sqlToCommonRepository(val) + ret[idx], err = s.sqlToCommonRepository(val, true) if err != nil { return nil, errors.Wrap(err, "fetching repositories") } @@ -102,7 +103,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, } func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error { - repo, err := s.getRepoByID(ctx, repoID) + repo, err := s.getRepoByID(ctx, s.conn, repoID) if err != nil { return errors.Wrap(err, "fetching repo") } @@ -116,33 +117,51 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error } func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID, "Credentials", "Endpoint") - if err != nil { - return params.Repository{}, errors.Wrap(err, "fetching repo") - } - - if param.CredentialsName != "" { - repo.CredentialsName = param.CredentialsName - } - - if param.WebhookSecret != "" { - secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + var repo Repository + var creds GithubCredentials + err := s.conn.Transaction(func(tx *gorm.DB) error { + var err error + repo, err = s.getRepoByID(ctx, tx, repoID, "Credentials", "Endpoint") if err != nil { - return params.Repository{}, fmt.Errorf("saving repo: failed to encrypt string: %w", err) + return errors.Wrap(err, "fetching repo") } - repo.WebhookSecret = secret + + if param.CredentialsName != "" { + repo.CredentialsName = param.CredentialsName + creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false) + if err != nil { + return errors.Wrap(err, "fetching credentials") + } + repo.CredentialsID = &creds.ID + } + + if param.WebhookSecret != "" { + secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase)) + if err != nil { + return fmt.Errorf("saving repo: failed to encrypt string: %w", err) + } + repo.WebhookSecret = secret + } + + if param.PoolBalancerType != "" { + repo.PoolBalancerType = param.PoolBalancerType + } + + q := tx.Save(&repo) + if q.Error != nil { + return errors.Wrap(q.Error, "saving repo") + } + + if creds.ID != 0 { + repo.Credentials = creds + } + return nil + }) + if err != nil { + return params.Repository{}, errors.Wrap(err, "saving repo") } - if param.PoolBalancerType != "" { - repo.PoolBalancerType = param.PoolBalancerType - } - - q := s.conn.Save(&repo) - if q.Error != nil { - return params.Repository{}, errors.Wrap(q.Error, "saving repo") - } - - newParams, err := s.sqlToCommonRepository(repo) + newParams, err := s.sqlToCommonRepository(repo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "saving repo") } @@ -150,12 +169,12 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param } func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) { - repo, err := s.getRepoByID(ctx, repoID, "Pools", "Credentials", "Endpoint") + repo, err := s.getRepoByID(ctx, s.conn, repoID, "Pools", "Credentials", "Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } - param, err := s.sqlToCommonRepository(repo) + param, err := s.sqlToCommonRepository(repo, true) if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -206,14 +225,14 @@ func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.Git return Pool{}, nil } -func (s *sqlDatabase) getRepoByID(_ context.Context, id string, preload ...string) (Repository, error) { +func (s *sqlDatabase) getRepoByID(_ context.Context, tx *gorm.DB, id string, preload ...string) (Repository, error) { u, err := uuid.Parse(id) if err != nil { return Repository{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var repo Repository - q := s.conn + q := tx if len(preload) > 0 { for _, field := range preload { q = q.Preload(field) diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index fbd68304..09f65a7a 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -48,6 +48,11 @@ type RepoTestSuite struct { Store dbCommon.Store StoreSQLMocked *sqlDatabase Fixtures *RepoTestFixtures + + adminCtx context.Context + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *RepoTestSuite) equalReposByName(expected, actual []params.Repository) { @@ -87,14 +92,21 @@ func (s *RepoTestSuite) SetupTest() { } s.Store = db + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.adminCtx = adminCtx + + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some repository objects in the database, for testing purposes repos := []params.Repository{} for i := 1; i <= 3; i++ { repo, err := db.CreateRepository( - context.Background(), + adminCtx, fmt.Sprintf("test-owner-%d", i), fmt.Sprintf("test-repo-%d", i), - fmt.Sprintf("test-creds-%d", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%d", i), params.PoolBalancerTypeRoundRobin, ) @@ -136,7 +148,7 @@ func (s *RepoTestSuite) SetupTest() { CreateRepoParams: params.CreateRepoParams{ Owner: "test-owner-repo", Name: "test-repo", - CredentialsName: "test-creds-repo", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -155,7 +167,7 @@ func (s *RepoTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-update-creds", + CredentialsName: s.secondaryTestCreds.Name, WebhookSecret: "test-update-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -172,7 +184,7 @@ func (s *RepoTestSuite) SetupTest() { func (s *RepoTestSuite) TestCreateRepository() { // call tested function repo, err := s.Store.CreateRepository( - context.Background(), + s.adminCtx, s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.CredentialsName, @@ -182,7 +194,7 @@ func (s *RepoTestSuite) TestCreateRepository() { // assertions s.Require().Nil(err) - storeRepo, err := s.Store.GetRepositoryByID(context.Background(), repo.ID) + storeRepo, err := s.Store.GetRepositoryByID(s.adminCtx, repo.ID) if err != nil { s.FailNow(fmt.Sprintf("failed to get repository by id: %v", err)) } @@ -206,7 +218,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBPassphrase() { } _, err = sqlDB.CreateRepository( - context.Background(), + s.adminCtx, s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.CredentialsName, @@ -220,13 +232,17 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBPassphrase() { func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() { s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.Fixtures.Repos[0].CredentialsName). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID)) s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `repositories`")). WillReturnError(fmt.Errorf("creating repo mock error")) s.Fixtures.SQLMock.ExpectRollback() _, err := s.StoreSQLMocked.CreateRepository( - context.Background(), + s.adminCtx, s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.CredentialsName, @@ -234,13 +250,13 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() { params.PoolBalancerTypeRoundRobin, ) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating repository: creating repo mock error", err.Error()) + s.Require().Equal("creating repository: creating repository: creating repo mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestGetRepository() { - repo, err := s.Store.GetRepository(context.Background(), s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Repos[0].Owner, repo.Owner) @@ -249,7 +265,7 @@ func (s *RepoTestSuite) TestGetRepository() { } func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { - repo, err := s.Store.GetRepository(context.Background(), "TeSt-oWnEr-1", "TeSt-rEpO-1") + repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1") s.Require().Nil(err) s.Require().Equal("test-owner-1", repo.Owner) @@ -257,7 +273,7 @@ func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { } func (s *RepoTestSuite) TestGetRepositoryNotFound() { - _, err := s.Store.GetRepository(context.Background(), "dummy-owner", "dummy-name") + _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name") s.Require().NotNil(err) s.Require().Equal("fetching repo: not found", err.Error()) @@ -273,15 +289,15 @@ func (s *RepoTestSuite) TestGetRepositoryDBDecryptingErr() { WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) - _, err := s.StoreSQLMocked.GetRepository(context.Background(), s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) + _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching repo: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepositories() { - repos, err := s.Store.ListRepositories((context.Background())) + repos, err := s.Store.ListRepositories(s.adminCtx) s.Require().Nil(err) s.equalReposByName(s.Fixtures.Repos, repos) @@ -292,11 +308,11 @@ func (s *RepoTestSuite) TestListRepositoriesDBFetchErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")). WillReturnError(fmt.Errorf("fetching user from database mock error")) - _, err := s.StoreSQLMocked.ListRepositories(context.Background()) + _, err := s.StoreSQLMocked.ListRepositories(s.adminCtx) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching user from database: fetching user from database mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepositoriesDBDecryptingErr() { @@ -306,24 +322,24 @@ func (s *RepoTestSuite) TestListRepositoriesDBDecryptingErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")). WillReturnRows(sqlmock.NewRows([]string{"id", "webhook_secret"}).AddRow(s.Fixtures.Repos[0].ID, s.Fixtures.Repos[0].WebhookSecret)) - _, err := s.StoreSQLMocked.ListRepositories(context.Background()) + _, err := s.StoreSQLMocked.ListRepositories(s.adminCtx) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching repositories: decrypting secret: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestDeleteRepository() { - err := s.Store.DeleteRepository(context.Background(), s.Fixtures.Repos[0].ID) + err := s.Store.DeleteRepository(s.adminCtx, s.Fixtures.Repos[0].ID) s.Require().Nil(err) - _, err = s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) + _, err = s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID) s.Require().NotNil(err) s.Require().Equal("fetching repo: not found", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryInvalidRepoID() { - err := s.Store.DeleteRepository(context.Background(), "dummy-repo-id") + err := s.Store.DeleteRepository(s.adminCtx, "dummy-repo-id") s.Require().NotNil(err) s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) @@ -341,15 +357,15 @@ func (s *RepoTestSuite) TestDeleteRepositoryDBRemoveErr() { WillReturnError(fmt.Errorf("mocked deleting repo error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.DeleteRepository(context.Background(), s.Fixtures.Repos[0].ID) + err := s.StoreSQLMocked.DeleteRepository(s.adminCtx, s.Fixtures.Repos[0].ID) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("deleting repo: mocked deleting repo error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestUpdateRepository() { - repo, err := s.Store.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) + repo, err := s.Store.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) s.Require().Nil(err) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) @@ -357,69 +373,84 @@ func (s *RepoTestSuite) TestUpdateRepository() { } func (s *RepoTestSuite) TestUpdateRepositoryInvalidRepoID() { - _, err := s.Store.UpdateRepository(context.Background(), "dummy-repo-id", s.Fixtures.UpdateRepoParams) + _, err := s.Store.UpdateRepository(s.adminCtx, "dummy-repo-id", s.Fixtures.UpdateRepoParams) s.Require().NotNil(err) - s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) + s.Require().Equal("saving repo: fetching repo: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestUpdateRepositoryDBEncryptErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase - + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). WithArgs(s.Fixtures.Repos[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - _, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() + + _, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("saving repo: saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestUpdateRepositoryDBSaveErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). WithArgs(s.Fixtures.Repos[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock.ExpectBegin() + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) s.Fixtures.SQLMock. ExpectExec(("UPDATE `repositories` SET")). WillReturnError(fmt.Errorf("saving repo mock error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving repo: saving repo mock error", err.Error()) + s.Require().Equal("saving repo: saving repo: saving repo mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestUpdateRepositoryDBDecryptingErr() { s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret - + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). WithArgs(s.Fixtures.Repos[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")). + WithArgs(s.secondaryTestCreds.Name). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID)) + s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) + _, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.Require().Equal("saving repo: saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestGetRepositoryByID() { - repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) + repo, err := s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID) s.Require().Nil(err) s.Require().Equal(s.Fixtures.Repos[0].ID, repo.ID) } func (s *RepoTestSuite) TestGetRepositoryByIDInvalidRepoID() { - _, err := s.Store.GetRepositoryByID(context.Background(), "dummy-repo-id") + _, err := s.Store.GetRepositoryByID(s.adminCtx, "dummy-repo-id") s.Require().NotNil(err) s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) @@ -435,20 +466,20 @@ func (s *RepoTestSuite) TestGetRepositoryByIDDBDecryptingErr() { WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"}).AddRow(s.Fixtures.Repos[0].ID)) - _, err := s.StoreSQLMocked.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) + _, err := s.StoreSQLMocked.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching repo: missing secret", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPool() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) - repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) + repo, err := s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID) if err != nil { s.FailNow(fmt.Sprintf("cannot get repo by ID: %v", err)) } @@ -463,7 +494,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) @@ -474,7 +505,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -492,7 +523,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBCreateErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) @@ -522,7 +553,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("pool with the same image and flavor already exists on this provider", err.Error()) @@ -550,7 +581,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchTagErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) @@ -587,7 +618,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBAddingPoolErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("creating pool: mocked adding pool error", err.Error()) @@ -627,7 +658,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBSaveTagErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("associating tags: mocked saving tag error", err.Error()) s.assertSQLMockExpectations() @@ -675,7 +706,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) @@ -688,14 +719,14 @@ func (s *RepoTestSuite) TestListRepoPools() { repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } repoPools = append(repoPools, pool) } - pools, err := s.Store.ListEntityPools(context.Background(), entity) + pools, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), repoPools, pools) @@ -706,7 +737,7 @@ func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - _, err := s.Store.ListEntityPools(context.Background(), entity) + _, err := s.Store.ListEntityPools(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) @@ -715,12 +746,12 @@ func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { func (s *RepoTestSuite) TestGetRepositoryPool() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - repoPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) + repoPool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) s.Require().Equal(repoPool.ID, pool.ID) @@ -731,7 +762,7 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") + _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) @@ -740,15 +771,15 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() { func (s *RepoTestSuite) TestDeleteRepositoryPool() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) + _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -757,7 +788,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") + err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("parsing id: invalid request", err.Error()) @@ -767,7 +798,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -779,7 +810,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) + err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID) s.Require().NotNil(err) s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) s.assertSQLMockExpectations() @@ -788,21 +819,21 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { func (s *RepoTestSuite) TestListRepoInstances() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } poolInstances := []params.Instance{} for i := 1; i <= 3; i++ { s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-repo-%d", i) - instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) + instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) } poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListEntityInstances(context.Background(), entity) + instances, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().Nil(err) s.equalInstancesByID(poolInstances, instances) @@ -813,7 +844,7 @@ func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - _, err := s.Store.ListEntityInstances(context.Background(), entity) + _, err := s.Store.ListEntityInstances(s.adminCtx, entity) s.Require().NotNil(err) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) @@ -822,12 +853,12 @@ func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() { func (s *RepoTestSuite) TestUpdateRepositoryPool() { entity, err := s.Fixtures.Repos[0].GetEntity() s.Require().Nil(err) - repoPool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) + repoPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - pool, err := s.Store.UpdateEntityPool(context.Background(), entity, repoPool.ID, s.Fixtures.UpdatePoolParams) + pool, err := s.Store.UpdateEntityPool(s.adminCtx, entity, repoPool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -841,7 +872,7 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() { ID: "dummy-repo-id", EntityType: params.GithubEntityTypeRepository, } - _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams) + _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/util.go b/database/sql/util.go index 30946863..bc09142d 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -105,7 +105,7 @@ func (s *sqlDatabase) sqlAddressToParamsAddress(addr Address) commonParams.Addre } } -func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organization, error) { +func (s *sqlDatabase) sqlToCommonOrganization(org Organization, detailed bool) (params.Organization, error) { if len(org.WebhookSecret) == 0 { return params.Organization{}, errors.New("missing secret") } @@ -114,20 +114,21 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza return params.Organization{}, errors.Wrap(err, "decrypting secret") } - creds, err := s.sqlToCommonGithubCredentials(org.Credentials) - if err != nil { - return params.Organization{}, errors.Wrap(err, "converting credentials") - } - ret := params.Organization{ ID: org.ID.String(), Name: org.Name, - CredentialsName: creds.Name, - Credentials: creds, + CredentialsName: org.Credentials.Name, Pools: make([]params.Pool, len(org.Pools)), WebhookSecret: string(secret), PoolBalancerType: org.PoolBalancerType, } + if detailed { + creds, err := s.sqlToCommonGithubCredentials(org.Credentials) + if err != nil { + return params.Organization{}, errors.Wrap(err, "converting credentials") + } + ret.Credentials = creds + } if ret.PoolBalancerType == "" { ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin @@ -143,7 +144,7 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza return ret, nil } -func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enterprise, error) { +func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise, detailed bool) (params.Enterprise, error) { if len(enterprise.WebhookSecret) == 0 { return params.Enterprise{}, errors.New("missing secret") } @@ -152,20 +153,23 @@ func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enter return params.Enterprise{}, errors.Wrap(err, "decrypting secret") } - creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials) - if err != nil { - return params.Enterprise{}, errors.Wrap(err, "converting credentials") - } ret := params.Enterprise{ ID: enterprise.ID.String(), Name: enterprise.Name, - CredentialsName: creds.Name, - Credentials: creds, + CredentialsName: enterprise.Credentials.Name, Pools: make([]params.Pool, len(enterprise.Pools)), WebhookSecret: string(secret), PoolBalancerType: enterprise.PoolBalancerType, } + if detailed { + creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials) + if err != nil { + return params.Enterprise{}, errors.Wrap(err, "converting credentials") + } + ret.Credentials = creds + } + if ret.PoolBalancerType == "" { ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin } @@ -241,7 +245,7 @@ func (s *sqlDatabase) sqlToCommonTags(tag Tag) params.Tag { } } -func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository, error) { +func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (params.Repository, error) { if len(repo.WebhookSecret) == 0 { return params.Repository{}, errors.New("missing secret") } @@ -250,21 +254,24 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository, return params.Repository{}, errors.Wrap(err, "decrypting secret") } - creds, err := s.sqlToCommonGithubCredentials(repo.Credentials) - if err != nil { - return params.Repository{}, errors.Wrap(err, "converting credentials") - } ret := params.Repository{ ID: repo.ID.String(), Name: repo.Name, Owner: repo.Owner, - CredentialsName: creds.Name, - Credentials: creds, + CredentialsName: repo.Credentials.Name, Pools: make([]params.Pool, len(repo.Pools)), WebhookSecret: string(secret), PoolBalancerType: repo.PoolBalancerType, } + if detailed { + creds, err := s.sqlToCommonGithubCredentials(repo.Credentials) + if err != nil { + return params.Repository{}, errors.Wrap(err, "converting credentials") + } + ret.Credentials = creds + } + if ret.PoolBalancerType == "" { ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin } diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 5f8624e6..d599aca4 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -18,19 +18,79 @@ package testing import ( + "context" "os" "path/filepath" "sort" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/require" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/auth" "github.com/cloudbase/garm/config" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/util/appdefaults" ) //nolint:golangci-lint,gosec var encryptionPassphrase = "bocyasicgatEtenOubwonIbsudNutDom" +func ImpersonateAdminContext(ctx context.Context, db common.Store, s *testing.T) context.Context { + adminUser, err := db.GetAdminUser(ctx) + if err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + s.Fatalf("failed to get admin user: %v", err) + } + newUserParams := params.NewUserParams{ + Email: "admin@localhost", + Username: "admin", + Password: "superSecretAdminPassword@123", + IsAdmin: true, + Enabled: true, + } + adminUser, err = db.CreateUser(ctx, newUserParams) + if err != nil { + s.Fatalf("failed to create admin user: %v", err) + } + } + ctx = auth.PopulateContext(ctx, adminUser) + return ctx +} + +func CreateDefaultGithubEndpoint(ctx context.Context, db common.Store, s *testing.T) params.GithubEndpoint { + endpointParams := params.CreateGithubEndpointParams{ + Name: "github.com", + Description: "github endpoint", + APIBaseURL: appdefaults.GithubDefaultBaseURL, + UploadBaseURL: appdefaults.GithubDefaultUploadBaseURL, + BaseURL: appdefaults.DefaultGithubURL, + } + endpoint, err := db.CreateGithubEndpoint(ctx, endpointParams) + if err != nil { + s.Fatalf("failed to create database object (github.com): %v", err) + } + return endpoint +} + +func CreateTestGithubCredentials(ctx context.Context, credsName string, db common.Store, s *testing.T, endpoint params.GithubEndpoint) params.GithubCredentials { + newCredsParams := params.CreateGithubCredentialsParams{ + Name: credsName, + Description: "Test creds", + AuthType: params.GithubAuthTypePAT, + PAT: params.GithubPAT{ + OAuth2Token: "test-token", + }, + } + newCreds, err := db.CreateGithubCredentials(ctx, endpoint.Name, newCredsParams) + if err != nil { + s.Fatalf("failed to create database object (new-creds): %v", err) + } + return newCreds +} + func GetTestSqliteDBConfig(t *testing.T) config.Database { dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { diff --git a/runner/enterprises.go b/runner/enterprises.go index c5274e09..8f8dbf31 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -25,8 +25,8 @@ func (r *Runner) CreateEnterprise(ctx context.Context, param params.CreateEnterp return params.Enterprise{}, errors.Wrap(err, "validating params") } - creds, ok := r.credentials[param.CredentialsName] - if !ok { + creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true) + if err != nil { return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } @@ -161,25 +161,13 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para r.mux.Lock() defer r.mux.Unlock() - enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID) - if err != nil { - return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") - } - - if param.CredentialsName != "" { - // Check that credentials are set before saving to db - if _, ok := r.credentials[param.CredentialsName]; !ok { - return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", param.CredentialsName, enterprise.Name) - } - } - switch param.PoolBalancerType { case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: default: return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) } - enterprise, err = r.store.UpdateEnterprise(ctx, enterpriseID, param) + enterprise, err := r.store.UpdateEnterprise(ctx, enterpriseID, param) if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index 2ad54e5d..501d96a8 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -19,12 +19,11 @@ import ( "fmt" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" runnerErrors "github.com/cloudbase/garm-provider-common/errors" - "github.com/cloudbase/garm/auth" - "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database" dbCommon "github.com/cloudbase/garm/database/common" garmTesting "github.com/cloudbase/garm/internal/testing" //nolint:typecheck @@ -40,7 +39,7 @@ type EnterpriseTestFixtures struct { Store dbCommon.Store StoreEnterprises map[string]params.Enterprise Providers map[string]common.Provider - Credentials map[string]config.Github + Credentials map[string]params.GithubCredentials CreateEnterpriseParams params.CreateEnterpriseParams CreatePoolParams params.CreatePoolParams CreateInstanceParams params.CreateInstanceParams @@ -57,18 +56,25 @@ type EnterpriseTestSuite struct { suite.Suite Fixtures *EnterpriseTestFixtures Runner *Runner + + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *EnterpriseTestSuite) SetupTest() { - adminCtx := auth.GetAdminContext(context.Background()) - // create testing sqlite database dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) - db, err := database.NewDatabase(adminCtx, dbCfg) + db, err := database.NewDatabase(context.Background(), dbCfg) if err != nil { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some organization objects in the database, for testing purposes enterprises := map[string]params.Enterprise{} for i := 1; i <= 3; i++ { @@ -76,12 +82,12 @@ func (s *EnterpriseTestSuite) SetupTest() { enterprise, err := db.CreateEnterprise( adminCtx, name, - fmt.Sprintf("test-creds-%v", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%v", i), params.PoolBalancerTypeRoundRobin, ) if err != nil { - s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%v)", i)) + s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%v): %+v", i, err)) } enterprises[name] = enterprise } @@ -98,16 +104,13 @@ func (s *EnterpriseTestSuite) SetupTest() { Providers: map[string]common.Provider{ "test-provider": providerMock, }, - Credentials: map[string]config.Github{ - "test-creds": { - Name: "test-creds-name", - Description: "test-creds-description", - OAuth2Token: "test-creds-oauth2-token", - }, + Credentials: map[string]params.GithubCredentials{ + s.testCreds.Name: s.testCreds, + s.secondaryTestCreds.Name: s.secondaryTestCreds, }, CreateEnterpriseParams: params.CreateEnterpriseParams{ Name: "test-enterprise-create", - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-create-enterprise-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -126,7 +129,7 @@ func (s *EnterpriseTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-update-repo-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -148,7 +151,6 @@ func (s *EnterpriseTestSuite) SetupTest() { // setup test runner runner := &Runner{ providers: fixtures.Providers, - credentials: fixtures.Credentials, ctx: fixtures.AdminContext, store: fixtures.Store, poolManagerCtrl: fixtures.PoolMgrCtrlMock, @@ -164,13 +166,13 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() { // call tested function enterprise, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) - // assertions - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name) s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.Credentials.Name) s.Require().Equal(params.PoolBalancerTypeRoundRobin, enterprise.PoolBalancerType) + // assertions + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) } func (s *EnterpriseTestSuite) TestCreateEnterpriseErrUnauthorized() { @@ -322,7 +324,9 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() { _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) - s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreEnterprises["test-enterprise-1"].Name), err) + if !errors.Is(err, runnerErrors.ErrNotFound) { + s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound)) + } } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() { diff --git a/runner/organizations.go b/runner/organizations.go index 40847ccf..c0663505 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -38,8 +38,8 @@ func (r *Runner) CreateOrganization(ctx context.Context, param params.CreateOrgP return params.Organization{}, errors.Wrap(err, "validating params") } - creds, ok := r.credentials[param.CredentialsName] - if !ok { + creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true) + if err != nil { return params.Organization{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } @@ -190,25 +190,13 @@ func (r *Runner) UpdateOrganization(ctx context.Context, orgID string, param par r.mux.Lock() defer r.mux.Unlock() - org, err := r.store.GetOrganizationByID(ctx, orgID) - if err != nil { - return params.Organization{}, errors.Wrap(err, "fetching org") - } - - if param.CredentialsName != "" { - // Check that credentials are set before saving to db - if _, ok := r.credentials[param.CredentialsName]; !ok { - return params.Organization{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for org %s", param.CredentialsName, org.Name) - } - } - switch param.PoolBalancerType { case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: default: return params.Organization{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) } - org, err = r.store.UpdateOrganization(ctx, orgID, param) + org, err := r.store.UpdateOrganization(ctx, orgID, param) if err != nil { return params.Organization{}, errors.Wrap(err, "updating org") } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index 30a58882..7954d2e7 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -19,12 +19,11 @@ import ( "fmt" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" runnerErrors "github.com/cloudbase/garm-provider-common/errors" - "github.com/cloudbase/garm/auth" - "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database" dbCommon "github.com/cloudbase/garm/database/common" garmTesting "github.com/cloudbase/garm/internal/testing" @@ -40,7 +39,7 @@ type OrgTestFixtures struct { Store dbCommon.Store StoreOrgs map[string]params.Organization Providers map[string]common.Provider - Credentials map[string]config.Github + Credentials map[string]params.GithubCredentials CreateOrgParams params.CreateOrgParams CreatePoolParams params.CreatePoolParams CreateInstanceParams params.CreateInstanceParams @@ -57,18 +56,26 @@ type OrgTestSuite struct { suite.Suite Fixtures *OrgTestFixtures Runner *Runner + + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *OrgTestSuite) SetupTest() { - adminCtx := auth.GetAdminContext(context.Background()) - // create testing sqlite database dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) - db, err := database.NewDatabase(adminCtx, dbCfg) + db, err := database.NewDatabase(context.Background(), dbCfg) if err != nil { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some organization objects in the database, for testing purposes orgs := map[string]params.Organization{} for i := 1; i <= 3; i++ { @@ -76,7 +83,7 @@ func (s *OrgTestSuite) SetupTest() { org, err := db.CreateOrganization( adminCtx, name, - fmt.Sprintf("test-creds-%v", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%v", i), params.PoolBalancerTypeRoundRobin, ) @@ -98,16 +105,13 @@ func (s *OrgTestSuite) SetupTest() { Providers: map[string]common.Provider{ "test-provider": providerMock, }, - Credentials: map[string]config.Github{ - "test-creds": { - Name: "test-creds-name", - Description: "test-creds-description", - OAuth2Token: "test-creds-oauth2-token", - }, + Credentials: map[string]params.GithubCredentials{ + s.testCreds.Name: s.testCreds, + s.secondaryTestCreds.Name: s.secondaryTestCreds, }, CreateOrgParams: params.CreateOrgParams{ Name: "test-org-create", - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-create-org-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -126,7 +130,7 @@ func (s *OrgTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-update-repo-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -148,7 +152,6 @@ func (s *OrgTestSuite) SetupTest() { // setup test runner runner := &Runner{ providers: fixtures.Providers, - credentials: fixtures.Credentials, ctx: fixtures.AdminContext, store: fixtures.Store, poolManagerCtrl: fixtures.PoolMgrCtrlMock, @@ -346,8 +349,9 @@ func (s *OrgTestSuite) TestUpdateOrganizationInvalidCreds() { s.Fixtures.UpdateRepoParams.CredentialsName = invalidCredentialsName _, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams) - - s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for org %s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreOrgs["test-org-1"].Name), err) + if !errors.Is(err, runnerErrors.ErrNotFound) { + s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound)) + } } func (s *OrgTestSuite) TestUpdateOrganizationPoolMgrFailed() { diff --git a/runner/pools_test.go b/runner/pools_test.go index e2b269a0..49ca5a5c 100644 --- a/runner/pools_test.go +++ b/runner/pools_test.go @@ -45,6 +45,11 @@ type PoolTestSuite struct { suite.Suite Fixtures *PoolTestFixtures Runner *Runner + + adminCtx context.Context + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *PoolTestSuite) SetupTest() { @@ -57,8 +62,14 @@ func (s *PoolTestSuite) SetupTest() { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } + s.adminCtx = garmTesting.ImpersonateAdminContext(adminCtx, db, s.T()) + + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(s.adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(s.adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(s.adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create an organization for testing purposes - org, err := db.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) + org, err := db.CreateOrganization(s.adminCtx, "test-org", s.testCreds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin) if err != nil { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } @@ -71,7 +82,7 @@ func (s *PoolTestSuite) SetupTest() { orgPools := []params.Pool{} for i := 1; i <= 3; i++ { pool, err := db.CreateEntityPool( - context.Background(), + adminCtx, entity, params.CreatePoolParams{ ProviderName: "test-provider", @@ -112,10 +123,9 @@ func (s *PoolTestSuite) SetupTest() { // setup test runner runner := &Runner{ - providers: fixtures.Providers, - credentials: fixtures.Credentials, - store: fixtures.Store, - ctx: fixtures.AdminContext, + providers: fixtures.Providers, + store: fixtures.Store, + ctx: fixtures.AdminContext, } s.Runner = runner } diff --git a/runner/repositories.go b/runner/repositories.go index f7692b69..b2a6ef54 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -38,8 +38,8 @@ func (r *Runner) CreateRepository(ctx context.Context, param params.CreateRepoPa return params.Repository{}, errors.Wrap(err, "validating params") } - creds, ok := r.credentials[param.CredentialsName] - if !ok { + creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true) + if err != nil { return params.Repository{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) } @@ -189,25 +189,13 @@ func (r *Runner) UpdateRepository(ctx context.Context, repoID string, param para r.mux.Lock() defer r.mux.Unlock() - repo, err := r.store.GetRepositoryByID(ctx, repoID) - if err != nil { - return params.Repository{}, errors.Wrap(err, "fetching repo") - } - - if param.CredentialsName != "" { - // Check that credentials are set before saving to db - if _, ok := r.credentials[param.CredentialsName]; !ok { - return params.Repository{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for repo %s/%s", param.CredentialsName, repo.Owner, repo.Name) - } - } - switch param.PoolBalancerType { case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: default: return params.Repository{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) } - repo, err = r.store.UpdateRepository(ctx, repoID, param) + repo, err := r.store.UpdateRepository(ctx, repoID, param) if err != nil { return params.Repository{}, errors.Wrap(err, "updating repo") } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 74bc8a76..aa8da725 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -19,12 +19,12 @@ import ( "fmt" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/auth" - "github.com/cloudbase/garm/config" "github.com/cloudbase/garm/database" dbCommon "github.com/cloudbase/garm/database/common" garmTesting "github.com/cloudbase/garm/internal/testing" @@ -39,7 +39,7 @@ type RepoTestFixtures struct { Store dbCommon.Store StoreRepos map[string]params.Repository Providers map[string]common.Provider - Credentials map[string]config.Github + Credentials map[string]params.GithubCredentials CreateRepoParams params.CreateRepoParams CreatePoolParams params.CreatePoolParams CreateInstanceParams params.CreateInstanceParams @@ -56,18 +56,25 @@ type RepoTestSuite struct { suite.Suite Fixtures *RepoTestFixtures Runner *Runner + + testCreds params.GithubCredentials + secondaryTestCreds params.GithubCredentials + githubEndpoint params.GithubEndpoint } func (s *RepoTestSuite) SetupTest() { - adminCtx := auth.GetAdminContext(context.Background()) - // create testing sqlite database dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) - db, err := database.NewDatabase(adminCtx, dbCfg) + db, err := database.NewDatabase(context.Background(), dbCfg) if err != nil { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } + adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) + s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T()) + s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint) + s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint) + // create some repository objects in the database, for testing purposes repos := map[string]params.Repository{} for i := 1; i <= 3; i++ { @@ -76,7 +83,7 @@ func (s *RepoTestSuite) SetupTest() { adminCtx, fmt.Sprintf("test-owner-%v", i), name, - fmt.Sprintf("test-creds-%v", i), + s.testCreds.Name, fmt.Sprintf("test-webhook-secret-%v", i), params.PoolBalancerTypeRoundRobin, ) @@ -97,17 +104,14 @@ func (s *RepoTestSuite) SetupTest() { Providers: map[string]common.Provider{ "test-provider": providerMock, }, - Credentials: map[string]config.Github{ - "test-creds": { - Name: "test-creds-name", - Description: "test-creds-description", - OAuth2Token: "test-creds-oauth2-token", - }, + Credentials: map[string]params.GithubCredentials{ + s.testCreds.Name: s.testCreds, + s.secondaryTestCreds.Name: s.secondaryTestCreds, }, CreateRepoParams: params.CreateRepoParams{ Owner: "test-owner-create", Name: "test-repo-create", - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-create-repo-webhook-secret", }, CreatePoolParams: params.CreatePoolParams{ @@ -126,7 +130,7 @@ func (s *RepoTestSuite) SetupTest() { OSType: "linux", }, UpdateRepoParams: params.UpdateEntityParams{ - CredentialsName: "test-creds", + CredentialsName: s.testCreds.Name, WebhookSecret: "test-update-repo-webhook-secret", }, UpdatePoolParams: params.UpdatePoolParams{ @@ -148,7 +152,6 @@ func (s *RepoTestSuite) SetupTest() { // setup test runner runner := &Runner{ providers: fixtures.Providers, - credentials: fixtures.Credentials, ctx: fixtures.AdminContext, store: fixtures.Store, poolManagerCtrl: fixtures.PoolMgrCtrlMock, @@ -167,6 +170,7 @@ func (s *RepoTestSuite) TestCreateRepository() { // assertions s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) s.Require().Equal(s.Fixtures.CreateRepoParams.Owner, repo.Owner) s.Require().Equal(s.Fixtures.CreateRepoParams.Name, repo.Name) @@ -358,7 +362,9 @@ func (s *RepoTestSuite) TestUpdateRepositoryInvalidCreds() { _, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams) - s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for repo %s/%s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name), err) + if !errors.Is(err, runnerErrors.ErrNotFound) { + s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound)) + } } func (s *RepoTestSuite) TestUpdateRepositoryPoolMgrFailed() { diff --git a/runner/runner.go b/runner/runner.go index fd10ad78..78e298d3 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -77,7 +77,6 @@ func NewRunner(ctx context.Context, cfg config.Config, db dbCommon.Store) (*Runn store: db, poolManagerCtrl: poolManagerCtrl, providers: providers, - credentials: creds, controllerID: ctrlID.ControllerID, } @@ -355,8 +354,7 @@ type Runner struct { poolManagerCtrl PoolManagerController - providers map[string]common.Provider - credentials map[string]config.Github + providers map[string]common.Provider controllerInfo params.ControllerInfo controllerID uuid.UUID