Fix tests

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-04-16 17:05:18 +00:00
parent 90870c11be
commit 032d40f5f9
19 changed files with 760 additions and 498 deletions

View file

@ -1,3 +1,17 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package sql package sql
import ( import (
@ -12,7 +26,7 @@ import (
"github.com/cloudbase/garm/params" "github.com/cloudbase/garm/params"
) )
func (s *sqlDatabase) CreateEnterprise(_ context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) { func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) {
if webhookSecret == "" { if webhookSecret == "" {
return params.Enterprise{}, errors.New("creating enterprise: missing secret") return params.Enterprise{}, errors.New("creating enterprise: missing secret")
} }
@ -26,13 +40,27 @@ func (s *sqlDatabase) CreateEnterprise(_ context.Context, name, credentialsName,
CredentialsName: credentialsName, CredentialsName: credentialsName,
PoolBalancerType: poolBalancerType, PoolBalancerType: poolBalancerType,
} }
err = s.conn.Transaction(func(tx *gorm.DB) error {
creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false)
if err != nil {
return errors.Wrap(err, "creating enterprise")
}
newEnterprise.CredentialsID = &creds.ID
q := s.conn.Create(&newEnterprise) q := tx.Create(&newEnterprise)
if q.Error != nil { if q.Error != nil {
return params.Enterprise{}, errors.Wrap(q.Error, "creating enterprise") return errors.Wrap(q.Error, "creating enterprise")
}
newEnterprise.Credentials = creds
return nil
})
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "creating enterprise")
} }
param, err := s.sqlToCommonEnterprise(newEnterprise) param, err := s.sqlToCommonEnterprise(newEnterprise, true)
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "creating enterprise") return params.Enterprise{}, errors.Wrap(err, "creating enterprise")
} }
@ -46,7 +74,7 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
} }
param, err := s.sqlToCommonEnterprise(enterprise) param, err := s.sqlToCommonEnterprise(enterprise, true)
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
} }
@ -54,12 +82,12 @@ func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.En
} }
func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) { func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools", "Credentials", "Endpoint") enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Pools", "Credentials", "Endpoint")
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
} }
param, err := s.sqlToCommonEnterprise(enterprise) param, err := s.sqlToCommonEnterprise(enterprise, true)
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
} }
@ -76,7 +104,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
ret := make([]params.Enterprise, len(enterprises)) ret := make([]params.Enterprise, len(enterprises))
for idx, val := range enterprises { for idx, val := range enterprises {
var err error var err error
ret[idx], err = s.sqlToCommonEnterprise(val) ret[idx], err = s.sqlToCommonEnterprise(val, true)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching enterprises") return nil, errors.Wrap(err, "fetching enterprises")
} }
@ -86,7 +114,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
} }
func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching enterprise") return errors.Wrap(err, "fetching enterprise")
} }
@ -100,33 +128,50 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string)
} }
func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Credentials", "Endpoint") var enterprise Enterprise
if err != nil { var creds GithubCredentials
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise") err := s.conn.Transaction(func(tx *gorm.DB) error {
} var err error
enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID, "Credentials", "Endpoint")
if param.CredentialsName != "" {
enterprise.CredentialsName = param.CredentialsName
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "encoding secret") return errors.Wrap(err, "fetching enterprise")
} }
enterprise.WebhookSecret = secret
if param.CredentialsName != "" {
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
enterprise.CredentialsID = &creds.ID
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil {
return errors.Wrap(err, "encoding secret")
}
enterprise.WebhookSecret = secret
}
if param.PoolBalancerType != "" {
enterprise.PoolBalancerType = param.PoolBalancerType
}
q := tx.Save(&enterprise)
if q.Error != nil {
return errors.Wrap(q.Error, "saving enterprise")
}
if creds.ID != 0 {
enterprise.Credentials = creds
}
return nil
})
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
} }
if param.PoolBalancerType != "" { newParams, err := s.sqlToCommonEnterprise(enterprise, true)
enterprise.PoolBalancerType = param.PoolBalancerType
}
q := s.conn.Save(&enterprise)
if q.Error != nil {
return params.Enterprise{}, errors.Wrap(q.Error, "saving enterprise")
}
newParams, err := s.sqlToCommonEnterprise(enterprise)
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise") return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
} }
@ -149,14 +194,14 @@ func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise,
return enterprise, nil return enterprise, nil
} }
func (s *sqlDatabase) getEnterpriseByID(_ context.Context, id string, preload ...string) (Enterprise, error) { func (s *sqlDatabase) getEnterpriseByID(_ context.Context, tx *gorm.DB, id string, preload ...string) (Enterprise, error) {
u, err := uuid.Parse(id) u, err := uuid.Parse(id)
if err != nil { if err != nil {
return Enterprise{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") return Enterprise{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
} }
var enterprise Enterprise var enterprise Enterprise
q := s.conn q := tx
if len(preload) > 0 { if len(preload) > 0 {
for _, field := range preload { for _, field := range preload {
q = q.Preload(field) q = q.Preload(field)

View file

@ -49,6 +49,11 @@ type EnterpriseTestSuite struct {
Store dbCommon.Store Store dbCommon.Store
StoreSQLMocked *sqlDatabase StoreSQLMocked *sqlDatabase
Fixtures *EnterpriseTestFixtures Fixtures *EnterpriseTestFixtures
adminCtx context.Context
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *EnterpriseTestSuite) equalInstancesByName(expected, actual []params.Instance) { func (s *EnterpriseTestSuite) equalInstancesByName(expected, actual []params.Instance) {
@ -77,18 +82,25 @@ func (s *EnterpriseTestSuite) SetupTest() {
} }
s.Store = db s.Store = db
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.adminCtx = adminCtx
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some enterprise objects in the database, for testing purposes // create some enterprise objects in the database, for testing purposes
enterprises := []params.Enterprise{} enterprises := []params.Enterprise{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
enterprise, err := db.CreateEnterprise( enterprise, err := db.CreateEnterprise(
context.Background(), s.adminCtx,
fmt.Sprintf("test-enterprise-%d", i), fmt.Sprintf("test-enterprise-%d", i),
fmt.Sprintf("test-creds-%d", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%d", i), fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%d)", i)) s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%d): %q", i, err))
} }
enterprises = append(enterprises, enterprise) enterprises = append(enterprises, enterprise)
@ -124,7 +136,7 @@ func (s *EnterpriseTestSuite) SetupTest() {
Enterprises: enterprises, Enterprises: enterprises,
CreateEnterpriseParams: params.CreateEnterpriseParams{ CreateEnterpriseParams: params.CreateEnterpriseParams{
Name: "new-test-enterprise", Name: "new-test-enterprise",
CredentialsName: "new-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "new-webhook-secret", WebhookSecret: "new-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -143,7 +155,7 @@ func (s *EnterpriseTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-update-creds", CredentialsName: s.secondaryTestCreds.Name,
WebhookSecret: "test-update-repo-webhook-secret", WebhookSecret: "test-update-repo-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -160,7 +172,7 @@ func (s *EnterpriseTestSuite) SetupTest() {
func (s *EnterpriseTestSuite) TestCreateEnterprise() { func (s *EnterpriseTestSuite) TestCreateEnterprise() {
// call tested function // call tested function
enterprise, err := s.Store.CreateEnterprise( enterprise, err := s.Store.CreateEnterprise(
context.Background(), s.adminCtx,
s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.Name,
s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.CredentialsName,
s.Fixtures.CreateEnterpriseParams.WebhookSecret, s.Fixtures.CreateEnterpriseParams.WebhookSecret,
@ -168,7 +180,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() {
// assertions // assertions
s.Require().Nil(err) s.Require().Nil(err)
storeEnterprise, err := s.Store.GetEnterpriseByID(context.Background(), enterprise.ID) storeEnterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, enterprise.ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to get enterprise by id: %v", err)) s.FailNow(fmt.Sprintf("failed to get enterprise by id: %v", err))
} }
@ -191,7 +203,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseInvalidDBPassphrase() {
} }
_, err = sqlDB.CreateEnterprise( _, err = sqlDB.CreateEnterprise(
context.Background(), s.adminCtx,
s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.Name,
s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.CredentialsName,
s.Fixtures.CreateEnterpriseParams.WebhookSecret, s.Fixtures.CreateEnterpriseParams.WebhookSecret,
@ -203,25 +215,29 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseInvalidDBPassphrase() {
func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() { func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() {
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].CredentialsName).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `enterprises`")). ExpectExec(regexp.QuoteMeta("INSERT INTO `enterprises`")).
WillReturnError(fmt.Errorf("creating enterprise mock error")) WillReturnError(fmt.Errorf("creating enterprise mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.CreateEnterprise( _, err := s.StoreSQLMocked.CreateEnterprise(
context.Background(), s.adminCtx,
s.Fixtures.CreateEnterpriseParams.Name, s.Fixtures.CreateEnterpriseParams.Name,
s.Fixtures.CreateEnterpriseParams.CredentialsName, s.Fixtures.CreateEnterpriseParams.CredentialsName,
s.Fixtures.CreateEnterpriseParams.WebhookSecret, s.Fixtures.CreateEnterpriseParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin) params.PoolBalancerTypeRoundRobin)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating enterprise: creating enterprise mock error", err.Error()) s.Require().Equal("creating enterprise: creating enterprise: creating enterprise mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *EnterpriseTestSuite) TestGetEnterprise() { func (s *EnterpriseTestSuite) TestGetEnterprise() {
enterprise, err := s.Store.GetEnterprise(context.Background(), s.Fixtures.Enterprises[0].Name) enterprise, err := s.Store.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Enterprises[0].Name, enterprise.Name) s.Require().Equal(s.Fixtures.Enterprises[0].Name, enterprise.Name)
@ -229,14 +245,14 @@ func (s *EnterpriseTestSuite) TestGetEnterprise() {
} }
func (s *EnterpriseTestSuite) TestGetEnterpriseCaseInsensitive() { func (s *EnterpriseTestSuite) TestGetEnterpriseCaseInsensitive() {
enterprise, err := s.Store.GetEnterprise(context.Background(), "TeSt-eNtErPriSe-1") enterprise, err := s.Store.GetEnterprise(s.adminCtx, "TeSt-eNtErPriSe-1")
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal("test-enterprise-1", enterprise.Name) s.Require().Equal("test-enterprise-1", enterprise.Name)
} }
func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() { func (s *EnterpriseTestSuite) TestGetEnterpriseNotFound() {
_, err := s.Store.GetEnterprise(context.Background(), "dummy-name") _, err := s.Store.GetEnterprise(s.adminCtx, "dummy-name")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: not found", err.Error()) s.Require().Equal("fetching enterprise: not found", err.Error())
@ -248,7 +264,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() {
WithArgs(s.Fixtures.Enterprises[0].Name, 1). WithArgs(s.Fixtures.Enterprises[0].Name, 1).
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Enterprises[0].Name)) WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Enterprises[0].Name))
_, err := s.StoreSQLMocked.GetEnterprise(context.Background(), s.Fixtures.Enterprises[0].Name) _, err := s.StoreSQLMocked.GetEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].Name)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -256,7 +272,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseDBDecryptingErr() {
} }
func (s *EnterpriseTestSuite) TestListEnterprises() { func (s *EnterpriseTestSuite) TestListEnterprises() {
enterprises, err := s.Store.ListEnterprises(context.Background()) enterprises, err := s.Store.ListEnterprises(s.adminCtx)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Enterprises, enterprises) garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Enterprises, enterprises)
@ -267,7 +283,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisesDBFetchErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE `enterprises`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE `enterprises`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("fetching user from database mock error")) WillReturnError(fmt.Errorf("fetching user from database mock error"))
_, err := s.StoreSQLMocked.ListEnterprises(context.Background()) _, err := s.StoreSQLMocked.ListEnterprises(s.adminCtx)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -275,16 +291,16 @@ func (s *EnterpriseTestSuite) TestListEnterprisesDBFetchErr() {
} }
func (s *EnterpriseTestSuite) TestDeleteEnterprise() { func (s *EnterpriseTestSuite) TestDeleteEnterprise() {
err := s.Store.DeleteEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID) err := s.Store.DeleteEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) _, err = s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: not found", err.Error()) s.Require().Equal("fetching enterprise: not found", err.Error())
} }
func (s *EnterpriseTestSuite) TestDeleteEnterpriseInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestDeleteEnterpriseInvalidEnterpriseID() {
err := s.Store.DeleteEnterprise(context.Background(), "dummy-enterprise-id") err := s.Store.DeleteEnterprise(s.adminCtx, "dummy-enterprise-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error())
@ -302,15 +318,15 @@ func (s *EnterpriseTestSuite) TestDeleteEnterpriseDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked delete enterprise error")) WillReturnError(fmt.Errorf("mocked delete enterprise error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeleteEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID) err := s.StoreSQLMocked.DeleteEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("deleting enterprise: mocked delete enterprise error", err.Error()) s.Require().Equal("deleting enterprise: mocked delete enterprise error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *EnterpriseTestSuite) TestUpdateEnterprise() { func (s *EnterpriseTestSuite) TestUpdateEnterprise() {
enterprise, err := s.Store.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) enterprise, err := s.Store.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, enterprise.Credentials.Name)
@ -318,70 +334,85 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprise() {
} }
func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidEnterpriseID() {
_, err := s.Store.UpdateEnterprise(context.Background(), "dummy-enterprise-id", s.Fixtures.UpdateRepoParams) _, err := s.Store.UpdateEnterprise(s.adminCtx, "dummy-enterprise-id", s.Fixtures.UpdateRepoParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("updating enterprise: fetching enterprise: parsing id: invalid request", err.Error())
} }
func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBEncryptErr() { func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBEncryptErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Enterprises[0].ID, 1). WithArgs(s.Fixtures.Enterprises[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("updating enterprise: encoding secret: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBSaveErr() { func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBSaveErr() {
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Enterprises[0].ID, 1). WithArgs(s.Fixtures.Enterprises[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(("UPDATE `enterprises` SET")). ExpectExec(("UPDATE `enterprises` SET")).
WillReturnError(fmt.Errorf("saving enterprise mock error")) WillReturnError(fmt.Errorf("saving enterprise mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving enterprise: saving enterprise mock error", err.Error()) s.Require().Equal("updating enterprise: saving enterprise: saving enterprise mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBDecryptingErr() { func (s *EnterpriseTestSuite) TestUpdateEnterpriseDBDecryptingErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Enterprises[0].ID, 1). WithArgs(s.Fixtures.Enterprises[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateEnterprise(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateEnterprise(s.adminCtx, s.Fixtures.Enterprises[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("updating enterprise: encoding secret: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *EnterpriseTestSuite) TestGetEnterpriseByID() { func (s *EnterpriseTestSuite) TestGetEnterpriseByID() {
enterprise, err := s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) enterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Enterprises[0].ID, enterprise.ID) s.Require().Equal(s.Fixtures.Enterprises[0].ID, enterprise.ID)
} }
func (s *EnterpriseTestSuite) TestGetEnterpriseByIDInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestGetEnterpriseByIDInvalidEnterpriseID() {
_, err := s.Store.GetEnterpriseByID(context.Background(), "dummy-enterprise-id") _, err := s.Store.GetEnterpriseByID(s.adminCtx, "dummy-enterprise-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error())
@ -397,7 +428,7 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() {
WithArgs(s.Fixtures.Enterprises[0].ID). WithArgs(s.Fixtures.Enterprises[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"}).AddRow(s.Fixtures.Enterprises[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"}).AddRow(s.Fixtures.Enterprises[0].ID))
_, err := s.StoreSQLMocked.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) _, err := s.StoreSQLMocked.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -407,11 +438,11 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() {
func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { func (s *EnterpriseTestSuite) TestCreateEnterprisePool() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
enterprise, err := s.Store.GetEnterpriseByID(context.Background(), s.Fixtures.Enterprises[0].ID) enterprise, err := s.Store.GetEnterpriseByID(s.adminCtx, s.Fixtures.Enterprises[0].ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot get enterprise by ID: %v", err)) s.FailNow(fmt.Sprintf("cannot get enterprise by ID: %v", err))
} }
@ -426,7 +457,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMissingTags() {
s.Fixtures.CreatePoolParams.Tags = []string{} s.Fixtures.CreatePoolParams.Tags = []string{}
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("no tags specified", err.Error()) s.Require().Equal("no tags specified", err.Error())
@ -437,7 +468,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
_, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -455,7 +486,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBCreateErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error())
@ -484,7 +515,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err)
@ -511,7 +542,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchTagErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error())
@ -546,7 +577,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBAddingPoolErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating pool: mocked adding pool error", err.Error()) s.Require().Equal("creating pool: mocked adding pool error", err.Error())
@ -585,7 +616,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBSaveTagErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("associating tags: mocked saving tag error", err.Error()) s.Require().Equal("associating tags: mocked saving tag error", err.Error())
@ -633,7 +664,7 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: not found", err.Error())
@ -646,14 +677,14 @@ func (s *EnterpriseTestSuite) TestListEnterprisePools() {
s.Require().Nil(err) s.Require().Nil(err)
for i := 1; i <= 2; i++ { for i := 1; i <= 2; i++ {
s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
enterprisePools = append(enterprisePools, pool) enterprisePools = append(enterprisePools, pool)
} }
pools, err := s.Store.ListEntityPools(context.Background(), entity) pools, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools) garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools)
@ -664,7 +695,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
_, err := s.Store.ListEntityPools(context.Background(), entity) _, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error())
@ -673,12 +704,12 @@ func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() {
func (s *EnterpriseTestSuite) TestGetEnterprisePool() { func (s *EnterpriseTestSuite) TestGetEnterprisePool() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
enterprisePool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) enterprisePool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(enterprisePool.ID, pool.ID) s.Require().Equal(enterprisePool.ID, pool.ID)
@ -689,7 +720,7 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
_, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
@ -698,15 +729,15 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() {
func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Equal("fetching pool: finding pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
@ -715,7 +746,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -724,7 +755,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() {
func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
@ -736,7 +767,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked deleting pool error")) WillReturnError(fmt.Errorf("mocked deleting pool error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) s.Require().Equal("removing pool: mocked deleting pool error", err.Error())
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
@ -745,21 +776,21 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { func (s *EnterpriseTestSuite) TestListEnterpriseInstances() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
poolInstances := []params.Instance{} poolInstances := []params.Instance{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-enterprise-%v", i) s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-enterprise-%v", i)
instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) s.FailNow(fmt.Sprintf("cannot create instance: %s", err))
} }
poolInstances = append(poolInstances, instance) poolInstances = append(poolInstances, instance)
} }
instances, err := s.Store.ListEntityInstances(context.Background(), entity) instances, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
s.equalInstancesByName(poolInstances, instances) s.equalInstancesByName(poolInstances, instances)
@ -770,7 +801,7 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
_, err := s.Store.ListEntityInstances(context.Background(), entity) _, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error())
@ -779,12 +810,12 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() {
func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() {
entity, err := s.Fixtures.Enterprises[0].GetEntity() entity, err := s.Fixtures.Enterprises[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err))
} }
pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) pool, err = s.Store.UpdateEntityPool(s.adminCtx, entity, pool.ID, s.Fixtures.UpdatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners)
@ -798,7 +829,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolInvalidEnterpriseID() {
ID: "dummy-enterprise-id", ID: "dummy-enterprise-id",
EntityType: params.GithubEntityTypeEnterprise, EntityType: params.GithubEntityTypeEnterprise,
} }
_, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())

View file

@ -14,6 +14,9 @@ import (
) )
func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (params.GithubCredentials, error) { func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (params.GithubCredentials, error) {
if len(creds.Payload) == 0 {
return params.GithubCredentials{}, errors.New("empty credentials payload")
}
data, err := util.Unseal(creds.Payload, []byte(s.cfg.Passphrase)) data, err := util.Unseal(creds.Payload, []byte(s.cfg.Passphrase))
if err != nil { if err != nil {
return params.GithubCredentials{}, errors.Wrap(err, "unsealing credentials") return params.GithubCredentials{}, errors.Wrap(err, "unsealing credentials")
@ -33,7 +36,7 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par
} }
for _, repo := range creds.Repositories { for _, repo := range creds.Repositories {
commonRepo, err := s.sqlToCommonRepository(repo) commonRepo, err := s.sqlToCommonRepository(repo, false)
if err != nil { if err != nil {
return params.GithubCredentials{}, errors.Wrap(err, "converting github repository") return params.GithubCredentials{}, errors.Wrap(err, "converting github repository")
} }
@ -41,7 +44,7 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par
} }
for _, org := range creds.Organizations { for _, org := range creds.Organizations {
commonOrg, err := s.sqlToCommonOrganization(org) commonOrg, err := s.sqlToCommonOrganization(org, false)
if err != nil { if err != nil {
return params.GithubCredentials{}, errors.Wrap(err, "converting github organization") return params.GithubCredentials{}, errors.Wrap(err, "converting github organization")
} }
@ -49,9 +52,9 @@ func (s *sqlDatabase) sqlToCommonGithubCredentials(creds GithubCredentials) (par
} }
for _, ent := range creds.Enterprises { for _, ent := range creds.Enterprises {
commonEnt, err := s.sqlToCommonEnterprise(ent) commonEnt, err := s.sqlToCommonEnterprise(ent, false)
if err != nil { if err != nil {
return params.GithubCredentials{}, errors.Wrap(err, "converting github enterprise") return params.GithubCredentials{}, errors.Wrapf(err, "converting github enterprise: %s", ent.Name)
} }
commonCreds.Enterprises = append(commonCreds.Enterprises, commonEnt) commonCreds.Enterprises = append(commonCreds.Enterprises, commonEnt)
} }
@ -73,12 +76,12 @@ func (s *sqlDatabase) sqlToCommonGithubEndpoint(ep GithubEndpoint) (params.Githu
func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { func getUIDFromContext(ctx context.Context) (uuid.UUID, error) {
userID := auth.UserID(ctx) userID := auth.UserID(ctx)
if userID == "" { if userID == "" {
return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "getting UID from context")
} }
asUUID, err := uuid.Parse(userID) asUUID, err := uuid.Parse(userID)
if err != nil { if err != nil {
return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "creating github endpoint") return uuid.Nil, errors.Wrap(runnerErrors.ErrUnauthorized, "parsing UID from context")
} }
return asUUID, nil return asUUID, nil
} }

View file

@ -48,6 +48,7 @@ type InstancesTestSuite struct {
Store dbCommon.Store Store dbCommon.Store
StoreSQLMocked *sqlDatabase StoreSQLMocked *sqlDatabase
Fixtures *InstancesTestFixtures Fixtures *InstancesTestFixtures
adminCtx context.Context
} }
func (s *InstancesTestSuite) equalInstancesByName(expected, actual []params.Instance) { func (s *InstancesTestSuite) equalInstancesByName(expected, actual []params.Instance) {
@ -76,8 +77,14 @@ func (s *InstancesTestSuite) SetupTest() {
} }
s.Store = db s.Store = db
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.adminCtx = adminCtx
githubEndpoint := garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint)
// create an organization for testing purposes // create an organization for testing purposes
org, err := s.Store.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err)) s.FailNow(fmt.Sprintf("failed to create org: %s", err))
} }
@ -94,7 +101,7 @@ func (s *InstancesTestSuite) SetupTest() {
} }
entity, err := org.GetEntity() entity, err := org.GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, createPoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create org pool: %s", err)) s.FailNow(fmt.Sprintf("failed to create org pool: %s", err))
} }
@ -103,7 +110,7 @@ func (s *InstancesTestSuite) SetupTest() {
instances := []params.Instance{} instances := []params.Instance{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
instance, err := db.CreateInstance( instance, err := db.CreateInstance(
context.Background(), s.adminCtx,
pool.ID, pool.ID,
params.CreateInstanceParams{ params.CreateInstanceParams{
Name: fmt.Sprintf("test-instance-%d", i), Name: fmt.Sprintf("test-instance-%d", i),
@ -179,11 +186,11 @@ func (s *InstancesTestSuite) SetupTest() {
func (s *InstancesTestSuite) TestCreateInstance() { func (s *InstancesTestSuite) TestCreateInstance() {
// call tested function // call tested function
instance, err := s.Store.CreateInstance(context.Background(), s.Fixtures.Pool.ID, s.Fixtures.CreateInstanceParams) instance, err := s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, s.Fixtures.CreateInstanceParams)
// assertions // assertions
s.Require().Nil(err) s.Require().Nil(err)
storeInstance, err := s.Store.GetInstanceByName(context.Background(), s.Fixtures.CreateInstanceParams.Name) storeInstance, err := s.Store.GetInstanceByName(s.adminCtx, s.Fixtures.CreateInstanceParams.Name)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to get instance: %v", err)) s.FailNow(fmt.Sprintf("failed to get instance: %v", err))
} }
@ -195,7 +202,7 @@ func (s *InstancesTestSuite) TestCreateInstance() {
} }
func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() { func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() {
_, err := s.Store.CreateInstance(context.Background(), "dummy-pool-id", params.CreateInstanceParams{}) _, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{})
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
@ -216,7 +223,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
WillReturnError(fmt.Errorf("mocked insert instance error")) WillReturnError(fmt.Errorf("mocked insert instance error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) _, err := s.StoreSQLMocked.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -226,7 +233,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
func (s *InstancesTestSuite) TestGetPoolInstanceByName() { func (s *InstancesTestSuite) TestGetPoolInstanceByName() {
storeInstance := s.Fixtures.Instances[0] // this is already created in `SetupTest()` storeInstance := s.Fixtures.Instances[0] // this is already created in `SetupTest()`
instance, err := s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) instance, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(storeInstance.Name, instance.Name) s.Require().Equal(storeInstance.Name, instance.Name)
@ -237,7 +244,7 @@ func (s *InstancesTestSuite) TestGetPoolInstanceByName() {
} }
func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() { func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() {
_, err := s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, "not-existent-instance-name") _, err := s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, "not-existent-instance-name")
s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error()) s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error())
} }
@ -245,7 +252,7 @@ func (s *InstancesTestSuite) TestGetPoolInstanceByNameNotFound() {
func (s *InstancesTestSuite) TestGetInstanceByName() { func (s *InstancesTestSuite) TestGetInstanceByName() {
storeInstance := s.Fixtures.Instances[1] storeInstance := s.Fixtures.Instances[1]
instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name) instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(storeInstance.Name, instance.Name) s.Require().Equal(storeInstance.Name, instance.Name)
@ -256,7 +263,7 @@ func (s *InstancesTestSuite) TestGetInstanceByName() {
} }
func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() { func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() {
_, err := s.Store.GetInstanceByName(context.Background(), "not-existent-instance-name") _, err := s.Store.GetInstanceByName(s.adminCtx, "not-existent-instance-name")
s.Require().Equal("fetching instance: fetching instance by name: not found", err.Error()) s.Require().Equal("fetching instance: fetching instance by name: not found", err.Error())
} }
@ -264,16 +271,16 @@ func (s *InstancesTestSuite) TestGetInstanceByNameFetchInstanceFailed() {
func (s *InstancesTestSuite) TestDeleteInstance() { func (s *InstancesTestSuite) TestDeleteInstance() {
storeInstance := s.Fixtures.Instances[0] storeInstance := s.Fixtures.Instances[0]
err := s.Store.DeleteInstance(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) err := s.Store.DeleteInstance(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetPoolInstanceByName(context.Background(), s.Fixtures.Pool.ID, storeInstance.Name) _, err = s.Store.GetPoolInstanceByName(s.adminCtx, s.Fixtures.Pool.ID, storeInstance.Name)
s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error()) s.Require().Equal("fetching instance: fetching pool instance by name: not found", err.Error())
} }
func (s *InstancesTestSuite) TestDeleteInstanceInvalidPoolID() { func (s *InstancesTestSuite) TestDeleteInstanceInvalidPoolID() {
err := s.Store.DeleteInstance(context.Background(), "dummy-pool-id", "dummy-instance-name") err := s.Store.DeleteInstance(s.adminCtx, "dummy-pool-id", "dummy-instance-name")
s.Require().Equal("deleting instance: fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("deleting instance: fetching pool: parsing id: invalid request", err.Error())
} }
@ -309,7 +316,7 @@ func (s *InstancesTestSuite) TestDeleteInstanceDBRecordNotFoundErr() {
WillReturnError(gorm.ErrRecordNotFound) WillReturnError(gorm.ErrRecordNotFound)
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeleteInstance(context.Background(), pool.ID, instance.Name) err := s.StoreSQLMocked.DeleteInstance(s.adminCtx, pool.ID, instance.Name)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().Nil(err) s.Require().Nil(err)
@ -346,7 +353,7 @@ func (s *InstancesTestSuite) TestDeleteInstanceDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked delete instance error")) WillReturnError(fmt.Errorf("mocked delete instance error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeleteInstance(context.Background(), pool.ID, instance.Name) err := s.StoreSQLMocked.DeleteInstance(s.adminCtx, pool.ID, instance.Name)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -357,10 +364,10 @@ func (s *InstancesTestSuite) TestAddInstanceEvent() {
storeInstance := s.Fixtures.Instances[0] storeInstance := s.Fixtures.Instances[0]
statusMsg := "test-status-message" statusMsg := "test-status-message"
err := s.Store.AddInstanceEvent(context.Background(), storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg) err := s.Store.AddInstanceEvent(s.adminCtx, storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg)
s.Require().Nil(err) s.Require().Nil(err)
instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name) instance, err := s.Store.GetInstanceByName(s.adminCtx, storeInstance.Name)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to get db instance: %s", err)) s.FailNow(fmt.Sprintf("failed to get db instance: %s", err))
} }
@ -398,7 +405,7 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() {
WillReturnError(fmt.Errorf("mocked add status message error")) WillReturnError(fmt.Errorf("mocked add status message error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.AddInstanceEvent(context.Background(), instance.Name, params.StatusEvent, params.EventInfo, statusMsg) err := s.StoreSQLMocked.AddInstanceEvent(s.adminCtx, instance.Name, params.StatusEvent, params.EventInfo, statusMsg)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("adding status message: mocked add status message error", err.Error()) s.Require().Equal("adding status message: mocked add status message error", err.Error())
@ -406,7 +413,7 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() {
} }
func (s *InstancesTestSuite) TestUpdateInstance() { func (s *InstancesTestSuite) TestUpdateInstance() {
instance, err := s.Store.UpdateInstance(context.Background(), s.Fixtures.Instances[0].Name, s.Fixtures.UpdateInstanceParams) instance, err := s.Store.UpdateInstance(s.adminCtx, s.Fixtures.Instances[0].Name, s.Fixtures.UpdateInstanceParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.UpdateInstanceParams.ProviderID, instance.ProviderID) s.Require().Equal(s.Fixtures.UpdateInstanceParams.ProviderID, instance.ProviderID)
@ -443,7 +450,7 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
WillReturnError(fmt.Errorf("mocked update instance error")) WillReturnError(fmt.Errorf("mocked update instance error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams) _, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("updating instance: mocked update instance error", err.Error()) s.Require().Equal("updating instance: mocked update instance error", err.Error())
@ -489,7 +496,7 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
WillReturnError(fmt.Errorf("update addresses mock error")) WillReturnError(fmt.Errorf("update addresses mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams) _, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("updating addresses: update addresses mock error", err.Error()) s.Require().Equal("updating addresses: update addresses mock error", err.Error())
@ -497,20 +504,20 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
} }
func (s *InstancesTestSuite) TestListPoolInstances() { func (s *InstancesTestSuite) TestListPoolInstances() {
instances, err := s.Store.ListPoolInstances(context.Background(), s.Fixtures.Pool.ID) instances, err := s.Store.ListPoolInstances(s.adminCtx, s.Fixtures.Pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
s.equalInstancesByName(s.Fixtures.Instances, instances) s.equalInstancesByName(s.Fixtures.Instances, instances)
} }
func (s *InstancesTestSuite) TestListPoolInstancesInvalidPoolID() { func (s *InstancesTestSuite) TestListPoolInstancesInvalidPoolID() {
_, err := s.Store.ListPoolInstances(context.Background(), "dummy-pool-id") _, err := s.Store.ListPoolInstances(s.adminCtx, "dummy-pool-id")
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
} }
func (s *InstancesTestSuite) TestListAllInstances() { func (s *InstancesTestSuite) TestListAllInstances() {
instances, err := s.Store.ListAllInstances(context.Background()) instances, err := s.Store.ListAllInstances(s.adminCtx)
s.Require().Nil(err) s.Require().Nil(err)
s.equalInstancesByName(s.Fixtures.Instances, instances) s.equalInstancesByName(s.Fixtures.Instances, instances)
@ -521,7 +528,7 @@ func (s *InstancesTestSuite) TestListAllInstancesDBFetchErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE `instances`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE `instances`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("fetch instances mock error")) WillReturnError(fmt.Errorf("fetch instances mock error"))
_, err := s.StoreSQLMocked.ListAllInstances(context.Background()) _, err := s.StoreSQLMocked.ListAllInstances(s.adminCtx)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -529,14 +536,14 @@ func (s *InstancesTestSuite) TestListAllInstancesDBFetchErr() {
} }
func (s *InstancesTestSuite) TestPoolInstanceCount() { func (s *InstancesTestSuite) TestPoolInstanceCount() {
instancesCount, err := s.Store.PoolInstanceCount(context.Background(), s.Fixtures.Pool.ID) instancesCount, err := s.Store.PoolInstanceCount(s.adminCtx, s.Fixtures.Pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(int64(len(s.Fixtures.Instances)), instancesCount) s.Require().Equal(int64(len(s.Fixtures.Instances)), instancesCount)
} }
func (s *InstancesTestSuite) TestPoolInstanceCountInvalidPoolID() { func (s *InstancesTestSuite) TestPoolInstanceCountInvalidPoolID() {
_, err := s.Store.PoolInstanceCount(context.Background(), "dummy-pool-id") _, err := s.Store.PoolInstanceCount(s.adminCtx, "dummy-pool-id")
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
@ -553,7 +560,7 @@ func (s *InstancesTestSuite) TestPoolInstanceCountDBCountErr() {
WithArgs(pool.ID). WithArgs(pool.ID).
WillReturnError(fmt.Errorf("count mock error")) WillReturnError(fmt.Errorf("count mock error"))
_, err := s.StoreSQLMocked.PoolInstanceCount(context.Background(), pool.ID) _, err := s.StoreSQLMocked.PoolInstanceCount(s.adminCtx, pool.ID)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)

View file

@ -27,13 +27,13 @@ import (
"github.com/cloudbase/garm/params" "github.com/cloudbase/garm/params"
) )
func (s *sqlDatabase) CreateOrganization(_ context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) { func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) {
if webhookSecret == "" { if webhookSecret == "" {
return params.Organization{}, errors.New("creating org: missing secret") return params.Organization{}, errors.New("creating org: missing secret")
} }
secret, err := util.Seal([]byte(webhookSecret), []byte(s.cfg.Passphrase)) secret, err := util.Seal([]byte(webhookSecret), []byte(s.cfg.Passphrase))
if err != nil { if err != nil {
return params.Organization{}, fmt.Errorf("failed to encrypt string") return params.Organization{}, errors.Wrap(err, "encoding secret")
} }
newOrg := Organization{ newOrg := Organization{
Name: name, Name: name,
@ -42,12 +42,27 @@ func (s *sqlDatabase) CreateOrganization(_ context.Context, name, credentialsNam
PoolBalancerType: poolBalancerType, PoolBalancerType: poolBalancerType,
} }
q := s.conn.Create(&newOrg) err = s.conn.Transaction(func(tx *gorm.DB) error {
if q.Error != nil { creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false)
return params.Organization{}, errors.Wrap(q.Error, "creating org") if err != nil {
return errors.Wrap(err, "creating org")
}
newOrg.CredentialsID = &creds.ID
q := tx.Create(&newOrg)
if q.Error != nil {
return errors.Wrap(q.Error, "creating org")
}
newOrg.Credentials = creds
return nil
})
if err != nil {
return params.Organization{}, errors.Wrap(err, "creating org")
} }
param, err := s.sqlToCommonOrganization(newOrg) param, err := s.sqlToCommonOrganization(newOrg, true)
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "creating org") return params.Organization{}, errors.Wrap(err, "creating org")
} }
@ -62,7 +77,7 @@ func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.
return params.Organization{}, errors.Wrap(err, "fetching org") return params.Organization{}, errors.Wrap(err, "fetching org")
} }
param, err := s.sqlToCommonOrganization(org) param, err := s.sqlToCommonOrganization(org, true)
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "fetching org") return params.Organization{}, errors.Wrap(err, "fetching org")
} }
@ -80,7 +95,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
ret := make([]params.Organization, len(orgs)) ret := make([]params.Organization, len(orgs))
for idx, val := range orgs { for idx, val := range orgs {
var err error var err error
ret[idx], err = s.sqlToCommonOrganization(val) ret[idx], err = s.sqlToCommonOrganization(val, true)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching org") return nil, errors.Wrap(err, "fetching org")
} }
@ -90,7 +105,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
} }
func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error { func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error {
org, err := s.getOrgByID(ctx, orgID) org, err := s.getOrgByID(ctx, s.conn, orgID)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching org") return errors.Wrap(err, "fetching org")
} }
@ -104,33 +119,52 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro
} }
func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) {
org, err := s.getOrgByID(ctx, orgID, "Credentials", "Endpoint") var org Organization
if err != nil { var creds GithubCredentials
return params.Organization{}, errors.Wrap(err, "fetching org") err := s.conn.Transaction(func(tx *gorm.DB) error {
} var err error
org, err = s.getOrgByID(ctx, tx, orgID, "Credentials", "Endpoint")
if param.CredentialsName != "" {
org.CredentialsName = param.CredentialsName
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil { if err != nil {
return params.Organization{}, fmt.Errorf("saving org: failed to encrypt string: %w", err) return errors.Wrap(err, "fetching org")
} }
org.WebhookSecret = secret
if param.CredentialsName != "" {
org.CredentialsName = param.CredentialsName
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
org.CredentialsID = &creds.ID
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil {
return fmt.Errorf("saving org: failed to encrypt string: %w", err)
}
org.WebhookSecret = secret
}
if param.PoolBalancerType != "" {
org.PoolBalancerType = param.PoolBalancerType
}
q := tx.Save(&org)
if q.Error != nil {
return errors.Wrap(q.Error, "saving org")
}
if creds.ID != 0 {
org.Credentials = creds
}
return nil
})
if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org")
} }
if param.PoolBalancerType != "" { newParams, err := s.sqlToCommonOrganization(org, true)
org.PoolBalancerType = param.PoolBalancerType
}
q := s.conn.Save(&org)
if q.Error != nil {
return params.Organization{}, errors.Wrap(q.Error, "saving org")
}
newParams, err := s.sqlToCommonOrganization(org)
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org") return params.Organization{}, errors.Wrap(err, "saving org")
} }
@ -138,26 +172,26 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
} }
func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) {
org, err := s.getOrgByID(ctx, orgID, "Pools", "Credentials", "Endpoint") org, err := s.getOrgByID(ctx, s.conn, orgID, "Pools", "Credentials", "Endpoint")
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "fetching org") return params.Organization{}, errors.Wrap(err, "fetching org")
} }
param, err := s.sqlToCommonOrganization(org) param, err := s.sqlToCommonOrganization(org, true)
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "fetching enterprise") return params.Organization{}, errors.Wrap(err, "fetching org")
} }
return param, nil return param, nil
} }
func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { func (s *sqlDatabase) getOrgByID(_ context.Context, db *gorm.DB, id string, preload ...string) (Organization, error) {
u, err := uuid.Parse(id) u, err := uuid.Parse(id)
if err != nil { if err != nil {
return Organization{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") return Organization{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
} }
var org Organization var org Organization
q := s.conn q := db
if len(preload) > 0 { if len(preload) > 0 {
for _, field := range preload { for _, field := range preload {
q = q.Preload(field) q = q.Preload(field)

View file

@ -49,6 +49,11 @@ type OrgTestSuite struct {
Store dbCommon.Store Store dbCommon.Store
StoreSQLMocked *sqlDatabase StoreSQLMocked *sqlDatabase
Fixtures *OrgTestFixtures Fixtures *OrgTestFixtures
adminCtx context.Context
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *OrgTestSuite) equalInstancesByName(expected, actual []params.Instance) { func (s *OrgTestSuite) equalInstancesByName(expected, actual []params.Instance) {
@ -71,24 +76,32 @@ func (s *OrgTestSuite) assertSQLMockExpectations() {
func (s *OrgTestSuite) SetupTest() { func (s *OrgTestSuite) SetupTest() {
// create testing sqlite database // create testing sqlite database
db, err := NewSQLDatabase(context.Background(), garmTesting.GetTestSqliteDBConfig(s.T())) dbConfig := garmTesting.GetTestSqliteDBConfig(s.T())
db, err := NewSQLDatabase(context.Background(), dbConfig)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) s.FailNow(fmt.Sprintf("failed to create db connection: %s", err))
} }
s.Store = db s.Store = db
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.adminCtx = adminCtx
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some organization objects in the database, for testing purposes // create some organization objects in the database, for testing purposes
orgs := []params.Organization{} orgs := []params.Organization{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
org, err := db.CreateOrganization( org, err := db.CreateOrganization(
context.Background(), s.adminCtx,
fmt.Sprintf("test-org-%d", i), fmt.Sprintf("test-org-%d", i),
fmt.Sprintf("test-creds-%d", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%d", i), fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-org-%d)", i)) s.FailNow(fmt.Sprintf("failed to create database object (test-org-%d): %q", i, err))
} }
orgs = append(orgs, org) orgs = append(orgs, org)
@ -114,7 +127,7 @@ func (s *OrgTestSuite) SetupTest() {
} }
s.StoreSQLMocked = &sqlDatabase{ s.StoreSQLMocked = &sqlDatabase{
conn: gormConn, conn: gormConn,
cfg: garmTesting.GetTestSqliteDBConfig(s.T()), cfg: dbConfig,
} }
// setup test fixtures // setup test fixtures
@ -123,8 +136,8 @@ func (s *OrgTestSuite) SetupTest() {
fixtures := &OrgTestFixtures{ fixtures := &OrgTestFixtures{
Orgs: orgs, Orgs: orgs,
CreateOrgParams: params.CreateOrgParams{ CreateOrgParams: params.CreateOrgParams{
Name: "new-test-org", Name: s.testCreds.Name,
CredentialsName: "new-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "new-webhook-secret", WebhookSecret: "new-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -143,7 +156,7 @@ func (s *OrgTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-update-creds", CredentialsName: s.secondaryTestCreds.Name,
WebhookSecret: "test-update-repo-webhook-secret", WebhookSecret: "test-update-repo-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -160,7 +173,7 @@ func (s *OrgTestSuite) SetupTest() {
func (s *OrgTestSuite) TestCreateOrganization() { func (s *OrgTestSuite) TestCreateOrganization() {
// call tested function // call tested function
org, err := s.Store.CreateOrganization( org, err := s.Store.CreateOrganization(
context.Background(), s.adminCtx,
s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.Name,
s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.CredentialsName,
s.Fixtures.CreateOrgParams.WebhookSecret, s.Fixtures.CreateOrgParams.WebhookSecret,
@ -168,7 +181,7 @@ func (s *OrgTestSuite) TestCreateOrganization() {
// assertions // assertions
s.Require().Nil(err) s.Require().Nil(err)
storeOrg, err := s.Store.GetOrganizationByID(context.Background(), org.ID) storeOrg, err := s.Store.GetOrganizationByID(s.adminCtx, org.ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to get organization by id: %v", err)) s.FailNow(fmt.Sprintf("failed to get organization by id: %v", err))
} }
@ -191,37 +204,41 @@ func (s *OrgTestSuite) TestCreateOrganizationInvalidDBPassphrase() {
} }
_, err = sqlDB.CreateOrganization( _, err = sqlDB.CreateOrganization(
context.Background(), s.adminCtx,
s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.Name,
s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.CredentialsName,
s.Fixtures.CreateOrgParams.WebhookSecret, s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin) params.PoolBalancerTypeRoundRobin)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("failed to encrypt string", err.Error()) s.Require().Equal("encoding secret: invalid passphrase length (expected length 32 characters)", err.Error())
} }
func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() { func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() {
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].CredentialsName).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `organizations`")). ExpectExec(regexp.QuoteMeta("INSERT INTO `organizations`")).
WillReturnError(fmt.Errorf("creating org mock error")) WillReturnError(fmt.Errorf("creating org mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.CreateOrganization( _, err := s.StoreSQLMocked.CreateOrganization(
context.Background(), s.adminCtx,
s.Fixtures.CreateOrgParams.Name, s.Fixtures.CreateOrgParams.Name,
s.Fixtures.CreateOrgParams.CredentialsName, s.Fixtures.CreateOrgParams.CredentialsName,
s.Fixtures.CreateOrgParams.WebhookSecret, s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin) params.PoolBalancerTypeRoundRobin)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating org: creating org mock error", err.Error()) s.Require().Equal("creating org: creating org: creating org mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *OrgTestSuite) TestGetOrganization() { func (s *OrgTestSuite) TestGetOrganization() {
org, err := s.Store.GetOrganization(context.Background(), s.Fixtures.Orgs[0].Name) org, err := s.Store.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Orgs[0].Name, org.Name) s.Require().Equal(s.Fixtures.Orgs[0].Name, org.Name)
@ -229,14 +246,14 @@ func (s *OrgTestSuite) TestGetOrganization() {
} }
func (s *OrgTestSuite) TestGetOrganizationCaseInsensitive() { func (s *OrgTestSuite) TestGetOrganizationCaseInsensitive() {
org, err := s.Store.GetOrganization(context.Background(), "TeSt-oRg-1") org, err := s.Store.GetOrganization(s.adminCtx, "TeSt-oRg-1")
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal("test-org-1", org.Name) s.Require().Equal("test-org-1", org.Name)
} }
func (s *OrgTestSuite) TestGetOrganizationNotFound() { func (s *OrgTestSuite) TestGetOrganizationNotFound() {
_, err := s.Store.GetOrganization(context.Background(), "dummy-name") _, err := s.Store.GetOrganization(s.adminCtx, "dummy-name")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching org: not found", err.Error()) s.Require().Equal("fetching org: not found", err.Error())
@ -248,7 +265,7 @@ func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() {
WithArgs(s.Fixtures.Orgs[0].Name, 1). WithArgs(s.Fixtures.Orgs[0].Name, 1).
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Orgs[0].Name)) WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(s.Fixtures.Orgs[0].Name))
_, err := s.StoreSQLMocked.GetOrganization(context.Background(), s.Fixtures.Orgs[0].Name) _, err := s.StoreSQLMocked.GetOrganization(s.adminCtx, s.Fixtures.Orgs[0].Name)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -256,7 +273,7 @@ func (s *OrgTestSuite) TestGetOrganizationDBDecryptingErr() {
} }
func (s *OrgTestSuite) TestListOrganizations() { func (s *OrgTestSuite) TestListOrganizations() {
orgs, err := s.Store.ListOrganizations(context.Background()) orgs, err := s.Store.ListOrganizations(s.adminCtx)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Orgs, orgs) garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Orgs, orgs)
@ -267,7 +284,7 @@ func (s *OrgTestSuite) TestListOrganizationsDBFetchErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE `organizations`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE `organizations`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("fetching user from database mock error")) WillReturnError(fmt.Errorf("fetching user from database mock error"))
_, err := s.StoreSQLMocked.ListOrganizations(context.Background()) _, err := s.StoreSQLMocked.ListOrganizations(s.adminCtx)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -275,16 +292,16 @@ func (s *OrgTestSuite) TestListOrganizationsDBFetchErr() {
} }
func (s *OrgTestSuite) TestDeleteOrganization() { func (s *OrgTestSuite) TestDeleteOrganization() {
err := s.Store.DeleteOrganization(context.Background(), s.Fixtures.Orgs[0].ID) err := s.Store.DeleteOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) _, err = s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching org: not found", err.Error()) s.Require().Equal("fetching org: not found", err.Error())
} }
func (s *OrgTestSuite) TestDeleteOrganizationInvalidOrgID() { func (s *OrgTestSuite) TestDeleteOrganizationInvalidOrgID() {
err := s.Store.DeleteOrganization(context.Background(), "dummy-org-id") err := s.Store.DeleteOrganization(s.adminCtx, "dummy-org-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("fetching org: parsing id: invalid request", err.Error())
@ -302,7 +319,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked delete org error")) WillReturnError(fmt.Errorf("mocked delete org error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeleteOrganization(context.Background(), s.Fixtures.Orgs[0].ID) err := s.StoreSQLMocked.DeleteOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -310,7 +327,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationDBDeleteErr() {
} }
func (s *OrgTestSuite) TestUpdateOrganization() { func (s *OrgTestSuite) TestUpdateOrganization() {
org, err := s.Store.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) org, err := s.Store.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.Credentials.Name)
@ -318,70 +335,85 @@ func (s *OrgTestSuite) TestUpdateOrganization() {
} }
func (s *OrgTestSuite) TestUpdateOrganizationInvalidOrgID() { func (s *OrgTestSuite) TestUpdateOrganizationInvalidOrgID() {
_, err := s.Store.UpdateOrganization(context.Background(), "dummy-org-id", s.Fixtures.UpdateRepoParams) _, err := s.Store.UpdateOrganization(s.adminCtx, "dummy-org-id", s.Fixtures.UpdateRepoParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("saving org: fetching org: parsing id: invalid request", err.Error())
} }
func (s *OrgTestSuite) TestUpdateOrganizationDBEncryptErr() { func (s *OrgTestSuite) TestUpdateOrganizationDBEncryptErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Orgs[0].ID, 1). WithArgs(s.Fixtures.Orgs[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("saving org: saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *OrgTestSuite) TestUpdateOrganizationDBSaveErr() { func (s *OrgTestSuite) TestUpdateOrganizationDBSaveErr() {
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Orgs[0].ID, 1). WithArgs(s.Fixtures.Orgs[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(("UPDATE `organizations` SET")). ExpectExec(("UPDATE `organizations` SET")).
WillReturnError(fmt.Errorf("saving org mock error")) WillReturnError(fmt.Errorf("saving org mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving org: saving org mock error", err.Error()) s.Require().Equal("saving org: saving org: saving org mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *OrgTestSuite) TestUpdateOrganizationDBDecryptingErr() { func (s *OrgTestSuite) TestUpdateOrganizationDBDecryptingErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Orgs[0].ID, 1). WithArgs(s.Fixtures.Orgs[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateOrganization(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateOrganization(s.adminCtx, s.Fixtures.Orgs[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("saving org: saving org: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *OrgTestSuite) TestGetOrganizationByID() { func (s *OrgTestSuite) TestGetOrganizationByID() {
org, err := s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) org, err := s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Orgs[0].ID, org.ID) s.Require().Equal(s.Fixtures.Orgs[0].ID, org.ID)
} }
func (s *OrgTestSuite) TestGetOrganizationByIDInvalidOrgID() { func (s *OrgTestSuite) TestGetOrganizationByIDInvalidOrgID() {
_, err := s.Store.GetOrganizationByID(context.Background(), "dummy-org-id") _, err := s.Store.GetOrganizationByID(s.adminCtx, "dummy-org-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("fetching org: parsing id: invalid request", err.Error())
@ -397,21 +429,21 @@ func (s *OrgTestSuite) TestGetOrganizationByIDDBDecryptingErr() {
WithArgs(s.Fixtures.Orgs[0].ID). WithArgs(s.Fixtures.Orgs[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"org_id"}).AddRow(s.Fixtures.Orgs[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"org_id"}).AddRow(s.Fixtures.Orgs[0].ID))
_, err := s.StoreSQLMocked.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) _, err := s.StoreSQLMocked.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching enterprise: missing secret", err.Error()) s.Require().Equal("fetching org: missing secret", err.Error())
} }
func (s *OrgTestSuite) TestCreateOrganizationPool() { func (s *OrgTestSuite) TestCreateOrganizationPool() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
org, err := s.Store.GetOrganizationByID(context.Background(), s.Fixtures.Orgs[0].ID) org, err := s.Store.GetOrganizationByID(s.adminCtx, s.Fixtures.Orgs[0].ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot get org by ID: %v", err)) s.FailNow(fmt.Sprintf("cannot get org by ID: %v", err))
} }
@ -426,7 +458,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolMissingTags() {
s.Fixtures.CreatePoolParams.Tags = []string{} s.Fixtures.CreatePoolParams.Tags = []string{}
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("no tags specified", err.Error()) s.Require().Equal("no tags specified", err.Error())
@ -437,7 +469,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
_, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -455,7 +487,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBCreateErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error())
@ -484,7 +516,7 @@ func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err)
@ -511,7 +543,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchTagErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error())
@ -547,7 +579,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBAddingPoolErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating pool: mocked adding pool error", err.Error()) s.Require().Equal("creating pool: mocked adding pool error", err.Error())
@ -586,7 +618,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBSaveTagErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("associating tags: mocked saving tag error", err.Error()) s.Require().Equal("associating tags: mocked saving tag error", err.Error())
@ -635,7 +667,7 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: not found", err.Error())
@ -648,13 +680,13 @@ func (s *OrgTestSuite) TestListOrgPools() {
s.Require().Nil(err) s.Require().Nil(err)
for i := 1; i <= 2; i++ { for i := 1; i <= 2; i++ {
s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
orgPools = append(orgPools, pool) orgPools = append(orgPools, pool)
} }
pools, err := s.Store.ListEntityPools(context.Background(), entity) pools, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityID(s.T(), orgPools, pools) garmTesting.EqualDBEntityID(s.T(), orgPools, pools)
@ -665,7 +697,7 @@ func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
_, err := s.Store.ListEntityPools(context.Background(), entity) _, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error())
@ -674,12 +706,12 @@ func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() {
func (s *OrgTestSuite) TestGetOrganizationPool() { func (s *OrgTestSuite) TestGetOrganizationPool() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
orgPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) orgPool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(orgPool.ID, pool.ID) s.Require().Equal(orgPool.ID, pool.ID)
@ -690,7 +722,7 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
_, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
@ -699,15 +731,15 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() {
func (s *OrgTestSuite) TestDeleteOrganizationPool() { func (s *OrgTestSuite) TestDeleteOrganizationPool() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Equal("fetching pool: finding pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
@ -716,7 +748,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -726,7 +758,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
@ -738,7 +770,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked deleting pool error")) WillReturnError(fmt.Errorf("mocked deleting pool error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) s.Require().Equal("removing pool: mocked deleting pool error", err.Error())
@ -748,21 +780,21 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
func (s *OrgTestSuite) TestListOrgInstances() { func (s *OrgTestSuite) TestListOrgInstances() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
poolInstances := []params.Instance{} poolInstances := []params.Instance{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-org-%v", i) s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-org-%v", i)
instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) s.FailNow(fmt.Sprintf("cannot create instance: %s", err))
} }
poolInstances = append(poolInstances, instance) poolInstances = append(poolInstances, instance)
} }
instances, err := s.Store.ListEntityInstances(context.Background(), entity) instances, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
s.equalInstancesByName(poolInstances, instances) s.equalInstancesByName(poolInstances, instances)
@ -773,7 +805,7 @@ func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
_, err := s.Store.ListEntityInstances(context.Background(), entity) _, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error())
@ -782,12 +814,12 @@ func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() {
func (s *OrgTestSuite) TestUpdateOrganizationPool() { func (s *OrgTestSuite) TestUpdateOrganizationPool() {
entity, err := s.Fixtures.Orgs[0].GetEntity() entity, err := s.Fixtures.Orgs[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create org pool: %v", err))
} }
pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) pool, err = s.Store.UpdateEntityPool(s.adminCtx, entity, pool.ID, s.Fixtures.UpdatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners)
@ -801,7 +833,7 @@ func (s *OrgTestSuite) TestUpdateOrganizationPoolInvalidOrgID() {
ID: "dummy-org-id", ID: "dummy-org-id",
EntityType: params.GithubEntityTypeOrganization, EntityType: params.GithubEntityTypeOrganization,
} }
_, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())

View file

@ -43,6 +43,7 @@ type PoolsTestSuite struct {
Store dbCommon.Store Store dbCommon.Store
StoreSQLMocked *sqlDatabase StoreSQLMocked *sqlDatabase
Fixtures *PoolsTestFixtures Fixtures *PoolsTestFixtures
adminCtx context.Context
} }
func (s *PoolsTestSuite) assertSQLMockExpectations() { func (s *PoolsTestSuite) assertSQLMockExpectations() {
@ -60,8 +61,14 @@ func (s *PoolsTestSuite) SetupTest() {
} }
s.Store = db s.Store = db
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.adminCtx = adminCtx
githubEndpoint := garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint)
// create an organization for testing purposes // create an organization for testing purposes
org, err := s.Store.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err)) s.FailNow(fmt.Sprintf("failed to create org: %s", err))
} }
@ -72,7 +79,7 @@ func (s *PoolsTestSuite) SetupTest() {
orgPools := []params.Pool{} orgPools := []params.Pool{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
pool, err := db.CreateEntityPool( pool, err := db.CreateEntityPool(
context.Background(), s.adminCtx,
entity, entity,
params.CreatePoolParams{ params.CreatePoolParams{
ProviderName: "test-provider", ProviderName: "test-provider",
@ -122,7 +129,7 @@ func (s *PoolsTestSuite) SetupTest() {
} }
func (s *PoolsTestSuite) TestListAllPools() { func (s *PoolsTestSuite) TestListAllPools() {
pools, err := s.Store.ListAllPools(context.Background()) pools, err := s.Store.ListAllPools(s.adminCtx)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityID(s.T(), s.Fixtures.Pools, pools) garmTesting.EqualDBEntityID(s.T(), s.Fixtures.Pools, pools)
@ -133,7 +140,7 @@ func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() {
ExpectQuery(regexp.QuoteMeta("SELECT `pools`.`id`,`pools`.`created_at`,`pools`.`updated_at`,`pools`.`deleted_at`,`pools`.`provider_name`,`pools`.`runner_prefix`,`pools`.`max_runners`,`pools`.`min_idle_runners`,`pools`.`runner_bootstrap_timeout`,`pools`.`image`,`pools`.`flavor`,`pools`.`os_type`,`pools`.`os_arch`,`pools`.`enabled`,`pools`.`git_hub_runner_group`,`pools`.`repo_id`,`pools`.`org_id`,`pools`.`enterprise_id`,`pools`.`priority` FROM `pools` WHERE `pools`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT `pools`.`id`,`pools`.`created_at`,`pools`.`updated_at`,`pools`.`deleted_at`,`pools`.`provider_name`,`pools`.`runner_prefix`,`pools`.`max_runners`,`pools`.`min_idle_runners`,`pools`.`runner_bootstrap_timeout`,`pools`.`image`,`pools`.`flavor`,`pools`.`os_type`,`pools`.`os_arch`,`pools`.`enabled`,`pools`.`git_hub_runner_group`,`pools`.`repo_id`,`pools`.`org_id`,`pools`.`enterprise_id`,`pools`.`priority` FROM `pools` WHERE `pools`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("mocked fetching all pools error")) WillReturnError(fmt.Errorf("mocked fetching all pools error"))
_, err := s.StoreSQLMocked.ListAllPools(context.Background()) _, err := s.StoreSQLMocked.ListAllPools(s.adminCtx)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
@ -141,29 +148,29 @@ func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() {
} }
func (s *PoolsTestSuite) TestGetPoolByID() { func (s *PoolsTestSuite) TestGetPoolByID() {
pool, err := s.Store.GetPoolByID(context.Background(), s.Fixtures.Pools[0].ID) pool, err := s.Store.GetPoolByID(s.adminCtx, s.Fixtures.Pools[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Pools[0].ID, pool.ID) s.Require().Equal(s.Fixtures.Pools[0].ID, pool.ID)
} }
func (s *PoolsTestSuite) TestGetPoolByIDInvalidPoolID() { func (s *PoolsTestSuite) TestGetPoolByIDInvalidPoolID() {
_, err := s.Store.GetPoolByID(context.Background(), "dummy-pool-id") _, err := s.Store.GetPoolByID(s.adminCtx, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error())
} }
func (s *PoolsTestSuite) TestDeletePoolByID() { func (s *PoolsTestSuite) TestDeletePoolByID() {
err := s.Store.DeletePoolByID(context.Background(), s.Fixtures.Pools[0].ID) err := s.Store.DeletePoolByID(s.adminCtx, s.Fixtures.Pools[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetPoolByID(context.Background(), s.Fixtures.Pools[0].ID) _, err = s.Store.GetPoolByID(s.adminCtx, s.Fixtures.Pools[0].ID)
s.Require().Equal("fetching pool by ID: not found", err.Error()) s.Require().Equal("fetching pool by ID: not found", err.Error())
} }
func (s *PoolsTestSuite) TestDeletePoolByIDInvalidPoolID() { func (s *PoolsTestSuite) TestDeletePoolByIDInvalidPoolID() {
err := s.Store.DeletePoolByID(context.Background(), "dummy-pool-id") err := s.Store.DeletePoolByID(s.adminCtx, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool by ID: parsing id: invalid request", err.Error())
@ -180,7 +187,7 @@ func (s *PoolsTestSuite) TestDeletePoolByIDDBRemoveErr() {
WillReturnError(fmt.Errorf("mocked removing pool error")) WillReturnError(fmt.Errorf("mocked removing pool error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeletePoolByID(context.Background(), s.Fixtures.Pools[0].ID) err := s.StoreSQLMocked.DeletePoolByID(s.adminCtx, s.Fixtures.Pools[0].ID)
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)

View file

@ -36,31 +36,32 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent
return params.Repository{}, fmt.Errorf("failed to encrypt string") return params.Repository{}, fmt.Errorf("failed to encrypt string")
} }
var newRepo Repository newRepo := Repository{
Name: name,
Owner: owner,
WebhookSecret: secret,
PoolBalancerType: poolBalancerType,
}
err = s.conn.Transaction(func(tx *gorm.DB) error { err = s.conn.Transaction(func(tx *gorm.DB) error {
creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false) creds, err := s.getGithubCredentialsByName(ctx, tx, credentialsName, false)
if err != nil { if err != nil {
return errors.Wrap(err, "creating repository") return errors.Wrap(err, "creating repository")
} }
newRepo.Name = name
newRepo.Owner = owner
newRepo.WebhookSecret = secret
newRepo.CredentialsID = &creds.ID newRepo.CredentialsID = &creds.ID
newRepo.PoolBalancerType = poolBalancerType
q := tx.Create(&newRepo) q := tx.Create(&newRepo)
if q.Error != nil { if q.Error != nil {
return errors.Wrap(q.Error, "creating repository") return errors.Wrap(q.Error, "creating repository")
} }
newRepo.Credentials = creds
return nil return nil
}) })
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "creating repository") return params.Repository{}, errors.Wrap(err, "creating repository")
} }
param, err := s.sqlToCommonRepository(newRepo) param, err := s.sqlToCommonRepository(newRepo, true)
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "creating repository") return params.Repository{}, errors.Wrap(err, "creating repository")
} }
@ -74,7 +75,7 @@ func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (pa
return params.Repository{}, errors.Wrap(err, "fetching repo") return params.Repository{}, errors.Wrap(err, "fetching repo")
} }
param, err := s.sqlToCommonRepository(repo) param, err := s.sqlToCommonRepository(repo, true)
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "fetching repo") return params.Repository{}, errors.Wrap(err, "fetching repo")
} }
@ -92,7 +93,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
ret := make([]params.Repository, len(repos)) ret := make([]params.Repository, len(repos))
for idx, val := range repos { for idx, val := range repos {
var err error var err error
ret[idx], err = s.sqlToCommonRepository(val) ret[idx], err = s.sqlToCommonRepository(val, true)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching repositories") return nil, errors.Wrap(err, "fetching repositories")
} }
@ -102,7 +103,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
} }
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error { func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error {
repo, err := s.getRepoByID(ctx, repoID) repo, err := s.getRepoByID(ctx, s.conn, repoID)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching repo") return errors.Wrap(err, "fetching repo")
} }
@ -116,33 +117,51 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error
} }
func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) {
repo, err := s.getRepoByID(ctx, repoID, "Credentials", "Endpoint") var repo Repository
if err != nil { var creds GithubCredentials
return params.Repository{}, errors.Wrap(err, "fetching repo") err := s.conn.Transaction(func(tx *gorm.DB) error {
} var err error
repo, err = s.getRepoByID(ctx, tx, repoID, "Credentials", "Endpoint")
if param.CredentialsName != "" {
repo.CredentialsName = param.CredentialsName
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil { if err != nil {
return params.Repository{}, fmt.Errorf("saving repo: failed to encrypt string: %w", err) return errors.Wrap(err, "fetching repo")
} }
repo.WebhookSecret = secret
if param.CredentialsName != "" {
repo.CredentialsName = param.CredentialsName
creds, err = s.getGithubCredentialsByName(ctx, tx, param.CredentialsName, false)
if err != nil {
return errors.Wrap(err, "fetching credentials")
}
repo.CredentialsID = &creds.ID
}
if param.WebhookSecret != "" {
secret, err := util.Seal([]byte(param.WebhookSecret), []byte(s.cfg.Passphrase))
if err != nil {
return fmt.Errorf("saving repo: failed to encrypt string: %w", err)
}
repo.WebhookSecret = secret
}
if param.PoolBalancerType != "" {
repo.PoolBalancerType = param.PoolBalancerType
}
q := tx.Save(&repo)
if q.Error != nil {
return errors.Wrap(q.Error, "saving repo")
}
if creds.ID != 0 {
repo.Credentials = creds
}
return nil
})
if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo")
} }
if param.PoolBalancerType != "" { newParams, err := s.sqlToCommonRepository(repo, true)
repo.PoolBalancerType = param.PoolBalancerType
}
q := s.conn.Save(&repo)
if q.Error != nil {
return params.Repository{}, errors.Wrap(q.Error, "saving repo")
}
newParams, err := s.sqlToCommonRepository(repo)
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo") return params.Repository{}, errors.Wrap(err, "saving repo")
} }
@ -150,12 +169,12 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
} }
func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) { func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error) {
repo, err := s.getRepoByID(ctx, repoID, "Pools", "Credentials", "Endpoint") repo, err := s.getRepoByID(ctx, s.conn, repoID, "Pools", "Credentials", "Endpoint")
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "fetching repo") return params.Repository{}, errors.Wrap(err, "fetching repo")
} }
param, err := s.sqlToCommonRepository(repo) param, err := s.sqlToCommonRepository(repo, true)
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "fetching repo") return params.Repository{}, errors.Wrap(err, "fetching repo")
} }
@ -206,14 +225,14 @@ func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.Git
return Pool{}, nil return Pool{}, nil
} }
func (s *sqlDatabase) getRepoByID(_ context.Context, id string, preload ...string) (Repository, error) { func (s *sqlDatabase) getRepoByID(_ context.Context, tx *gorm.DB, id string, preload ...string) (Repository, error) {
u, err := uuid.Parse(id) u, err := uuid.Parse(id)
if err != nil { if err != nil {
return Repository{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") return Repository{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
} }
var repo Repository var repo Repository
q := s.conn q := tx
if len(preload) > 0 { if len(preload) > 0 {
for _, field := range preload { for _, field := range preload {
q = q.Preload(field) q = q.Preload(field)

View file

@ -48,6 +48,11 @@ type RepoTestSuite struct {
Store dbCommon.Store Store dbCommon.Store
StoreSQLMocked *sqlDatabase StoreSQLMocked *sqlDatabase
Fixtures *RepoTestFixtures Fixtures *RepoTestFixtures
adminCtx context.Context
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *RepoTestSuite) equalReposByName(expected, actual []params.Repository) { func (s *RepoTestSuite) equalReposByName(expected, actual []params.Repository) {
@ -87,14 +92,21 @@ func (s *RepoTestSuite) SetupTest() {
} }
s.Store = db s.Store = db
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.adminCtx = adminCtx
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some repository objects in the database, for testing purposes // create some repository objects in the database, for testing purposes
repos := []params.Repository{} repos := []params.Repository{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
repo, err := db.CreateRepository( repo, err := db.CreateRepository(
context.Background(), adminCtx,
fmt.Sprintf("test-owner-%d", i), fmt.Sprintf("test-owner-%d", i),
fmt.Sprintf("test-repo-%d", i), fmt.Sprintf("test-repo-%d", i),
fmt.Sprintf("test-creds-%d", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%d", i), fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
@ -136,7 +148,7 @@ func (s *RepoTestSuite) SetupTest() {
CreateRepoParams: params.CreateRepoParams{ CreateRepoParams: params.CreateRepoParams{
Owner: "test-owner-repo", Owner: "test-owner-repo",
Name: "test-repo", Name: "test-repo",
CredentialsName: "test-creds-repo", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-webhook-secret", WebhookSecret: "test-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -155,7 +167,7 @@ func (s *RepoTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-update-creds", CredentialsName: s.secondaryTestCreds.Name,
WebhookSecret: "test-update-webhook-secret", WebhookSecret: "test-update-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -172,7 +184,7 @@ func (s *RepoTestSuite) SetupTest() {
func (s *RepoTestSuite) TestCreateRepository() { func (s *RepoTestSuite) TestCreateRepository() {
// call tested function // call tested function
repo, err := s.Store.CreateRepository( repo, err := s.Store.CreateRepository(
context.Background(), s.adminCtx,
s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Owner,
s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.Name,
s.Fixtures.CreateRepoParams.CredentialsName, s.Fixtures.CreateRepoParams.CredentialsName,
@ -182,7 +194,7 @@ func (s *RepoTestSuite) TestCreateRepository() {
// assertions // assertions
s.Require().Nil(err) s.Require().Nil(err)
storeRepo, err := s.Store.GetRepositoryByID(context.Background(), repo.ID) storeRepo, err := s.Store.GetRepositoryByID(s.adminCtx, repo.ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to get repository by id: %v", err)) s.FailNow(fmt.Sprintf("failed to get repository by id: %v", err))
} }
@ -206,7 +218,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBPassphrase() {
} }
_, err = sqlDB.CreateRepository( _, err = sqlDB.CreateRepository(
context.Background(), s.adminCtx,
s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Owner,
s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.Name,
s.Fixtures.CreateRepoParams.CredentialsName, s.Fixtures.CreateRepoParams.CredentialsName,
@ -220,13 +232,17 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBPassphrase() {
func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() { func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() {
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Repos[0].CredentialsName).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.testCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `repositories`")). ExpectExec(regexp.QuoteMeta("INSERT INTO `repositories`")).
WillReturnError(fmt.Errorf("creating repo mock error")) WillReturnError(fmt.Errorf("creating repo mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.CreateRepository( _, err := s.StoreSQLMocked.CreateRepository(
context.Background(), s.adminCtx,
s.Fixtures.CreateRepoParams.Owner, s.Fixtures.CreateRepoParams.Owner,
s.Fixtures.CreateRepoParams.Name, s.Fixtures.CreateRepoParams.Name,
s.Fixtures.CreateRepoParams.CredentialsName, s.Fixtures.CreateRepoParams.CredentialsName,
@ -234,13 +250,13 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() {
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating repository: creating repo mock error", err.Error()) s.Require().Equal("creating repository: creating repository: creating repo mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestGetRepository() { func (s *RepoTestSuite) TestGetRepository() {
repo, err := s.Store.GetRepository(context.Background(), s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) repo, err := s.Store.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Repos[0].Owner, repo.Owner) s.Require().Equal(s.Fixtures.Repos[0].Owner, repo.Owner)
@ -249,7 +265,7 @@ func (s *RepoTestSuite) TestGetRepository() {
} }
func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() { func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() {
repo, err := s.Store.GetRepository(context.Background(), "TeSt-oWnEr-1", "TeSt-rEpO-1") repo, err := s.Store.GetRepository(s.adminCtx, "TeSt-oWnEr-1", "TeSt-rEpO-1")
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal("test-owner-1", repo.Owner) s.Require().Equal("test-owner-1", repo.Owner)
@ -257,7 +273,7 @@ func (s *RepoTestSuite) TestGetRepositoryCaseInsensitive() {
} }
func (s *RepoTestSuite) TestGetRepositoryNotFound() { func (s *RepoTestSuite) TestGetRepositoryNotFound() {
_, err := s.Store.GetRepository(context.Background(), "dummy-owner", "dummy-name") _, err := s.Store.GetRepository(s.adminCtx, "dummy-owner", "dummy-name")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: not found", err.Error()) s.Require().Equal("fetching repo: not found", err.Error())
@ -273,15 +289,15 @@ func (s *RepoTestSuite) TestGetRepositoryDBDecryptingErr() {
WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1). WithArgs(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner, 1).
WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner)) WillReturnRows(sqlmock.NewRows([]string{"name", "owner"}).AddRow(s.Fixtures.Repos[0].Name, s.Fixtures.Repos[0].Owner))
_, err := s.StoreSQLMocked.GetRepository(context.Background(), s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name) _, err := s.StoreSQLMocked.GetRepository(s.adminCtx, s.Fixtures.Repos[0].Owner, s.Fixtures.Repos[0].Name)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: missing secret", err.Error()) s.Require().Equal("fetching repo: missing secret", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestListRepositories() { func (s *RepoTestSuite) TestListRepositories() {
repos, err := s.Store.ListRepositories((context.Background())) repos, err := s.Store.ListRepositories(s.adminCtx)
s.Require().Nil(err) s.Require().Nil(err)
s.equalReposByName(s.Fixtures.Repos, repos) s.equalReposByName(s.Fixtures.Repos, repos)
@ -292,11 +308,11 @@ func (s *RepoTestSuite) TestListRepositoriesDBFetchErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("fetching user from database mock error")) WillReturnError(fmt.Errorf("fetching user from database mock error"))
_, err := s.StoreSQLMocked.ListRepositories(context.Background()) _, err := s.StoreSQLMocked.ListRepositories(s.adminCtx)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching user from database: fetching user from database mock error", err.Error()) s.Require().Equal("fetching user from database: fetching user from database mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestListRepositoriesDBDecryptingErr() { func (s *RepoTestSuite) TestListRepositoriesDBDecryptingErr() {
@ -306,24 +322,24 @@ func (s *RepoTestSuite) TestListRepositoriesDBDecryptingErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE `repositories`.`deleted_at` IS NULL")).
WillReturnRows(sqlmock.NewRows([]string{"id", "webhook_secret"}).AddRow(s.Fixtures.Repos[0].ID, s.Fixtures.Repos[0].WebhookSecret)) WillReturnRows(sqlmock.NewRows([]string{"id", "webhook_secret"}).AddRow(s.Fixtures.Repos[0].ID, s.Fixtures.Repos[0].WebhookSecret))
_, err := s.StoreSQLMocked.ListRepositories(context.Background()) _, err := s.StoreSQLMocked.ListRepositories(s.adminCtx)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repositories: decrypting secret: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("fetching repositories: decrypting secret: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestDeleteRepository() { func (s *RepoTestSuite) TestDeleteRepository() {
err := s.Store.DeleteRepository(context.Background(), s.Fixtures.Repos[0].ID) err := s.Store.DeleteRepository(s.adminCtx, s.Fixtures.Repos[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) _, err = s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: not found", err.Error()) s.Require().Equal("fetching repo: not found", err.Error())
} }
func (s *RepoTestSuite) TestDeleteRepositoryInvalidRepoID() { func (s *RepoTestSuite) TestDeleteRepositoryInvalidRepoID() {
err := s.Store.DeleteRepository(context.Background(), "dummy-repo-id") err := s.Store.DeleteRepository(s.adminCtx, "dummy-repo-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("fetching repo: parsing id: invalid request", err.Error())
@ -341,15 +357,15 @@ func (s *RepoTestSuite) TestDeleteRepositoryDBRemoveErr() {
WillReturnError(fmt.Errorf("mocked deleting repo error")) WillReturnError(fmt.Errorf("mocked deleting repo error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.DeleteRepository(context.Background(), s.Fixtures.Repos[0].ID) err := s.StoreSQLMocked.DeleteRepository(s.adminCtx, s.Fixtures.Repos[0].ID)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("deleting repo: mocked deleting repo error", err.Error()) s.Require().Equal("deleting repo: mocked deleting repo error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestUpdateRepository() { func (s *RepoTestSuite) TestUpdateRepository() {
repo, err := s.Store.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) repo, err := s.Store.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name) s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, repo.Credentials.Name)
@ -357,69 +373,84 @@ func (s *RepoTestSuite) TestUpdateRepository() {
} }
func (s *RepoTestSuite) TestUpdateRepositoryInvalidRepoID() { func (s *RepoTestSuite) TestUpdateRepositoryInvalidRepoID() {
_, err := s.Store.UpdateRepository(context.Background(), "dummy-repo-id", s.Fixtures.UpdateRepoParams) _, err := s.Store.UpdateRepository(s.adminCtx, "dummy-repo-id", s.Fixtures.UpdateRepoParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("saving repo: fetching repo: parsing id: invalid request", err.Error())
} }
func (s *RepoTestSuite) TestUpdateRepositoryDBEncryptErr() { func (s *RepoTestSuite) TestUpdateRepositoryDBEncryptErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Repos[0].ID, 1). WithArgs(s.Fixtures.Repos[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID))
_, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("saving repo: saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestUpdateRepositoryDBSaveErr() { func (s *RepoTestSuite) TestUpdateRepositoryDBSaveErr() {
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Repos[0].ID, 1). WithArgs(s.Fixtures.Repos[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(("UPDATE `repositories` SET")). ExpectExec(("UPDATE `repositories` SET")).
WillReturnError(fmt.Errorf("saving repo mock error")) WillReturnError(fmt.Errorf("saving repo mock error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving repo: saving repo mock error", err.Error()) s.Require().Equal("saving repo: saving repo: saving repo mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestUpdateRepositoryDBDecryptingErr() { func (s *RepoTestSuite) TestUpdateRepositoryDBDecryptingErr() {
s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase s.StoreSQLMocked.cfg.Passphrase = wrongPassphrase
s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret s.Fixtures.UpdateRepoParams.WebhookSecret = webhookSecret
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Repos[0].ID, 1). WithArgs(s.Fixtures.Repos[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `github_credentials` WHERE name = ? AND `github_credentials`.`deleted_at` IS NULL ORDER BY `github_credentials`.`id` LIMIT 1")).
WithArgs(s.secondaryTestCreds.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.secondaryTestCreds.ID))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateRepository(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams) _, err := s.StoreSQLMocked.UpdateRepository(s.adminCtx, s.Fixtures.Repos[0].ID, s.Fixtures.UpdateRepoParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error()) s.Require().Equal("saving repo: saving repo: failed to encrypt string: invalid passphrase length (expected length 32 characters)", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestGetRepositoryByID() { func (s *RepoTestSuite) TestGetRepositoryByID() {
repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) repo, err := s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.Repos[0].ID, repo.ID) s.Require().Equal(s.Fixtures.Repos[0].ID, repo.ID)
} }
func (s *RepoTestSuite) TestGetRepositoryByIDInvalidRepoID() { func (s *RepoTestSuite) TestGetRepositoryByIDInvalidRepoID() {
_, err := s.Store.GetRepositoryByID(context.Background(), "dummy-repo-id") _, err := s.Store.GetRepositoryByID(s.adminCtx, "dummy-repo-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("fetching repo: parsing id: invalid request", err.Error())
@ -435,20 +466,20 @@ func (s *RepoTestSuite) TestGetRepositoryByIDDBDecryptingErr() {
WithArgs(s.Fixtures.Repos[0].ID). WithArgs(s.Fixtures.Repos[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"repo_id"}).AddRow(s.Fixtures.Repos[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"repo_id"}).AddRow(s.Fixtures.Repos[0].ID))
_, err := s.StoreSQLMocked.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) _, err := s.StoreSQLMocked.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching repo: missing secret", err.Error()) s.Require().Equal("fetching repo: missing secret", err.Error())
s.assertSQLMockExpectations()
} }
func (s *RepoTestSuite) TestCreateRepositoryPool() { func (s *RepoTestSuite) TestCreateRepositoryPool() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) repo, err := s.Store.GetRepositoryByID(s.adminCtx, s.Fixtures.Repos[0].ID)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot get repo by ID: %v", err)) s.FailNow(fmt.Sprintf("cannot get repo by ID: %v", err))
} }
@ -463,7 +494,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolMissingTags() {
s.Fixtures.CreatePoolParams.Tags = []string{} s.Fixtures.CreatePoolParams.Tags = []string{}
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("no tags specified", err.Error()) s.Require().Equal("no tags specified", err.Error())
@ -474,7 +505,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
_, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -492,7 +523,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBCreateErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) s.Require().Equal("checking pool existence: mocked creating pool error", err.Error())
@ -522,7 +553,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("pool with the same image and flavor already exists on this provider", err.Error()) s.Require().Equal("pool with the same image and flavor already exists on this provider", err.Error())
@ -550,7 +581,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchTagErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error())
@ -587,7 +618,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBAddingPoolErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating pool: mocked adding pool error", err.Error()) s.Require().Equal("creating pool: mocked adding pool error", err.Error())
@ -627,7 +658,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBSaveTagErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("associating tags: mocked saving tag error", err.Error()) s.Require().Equal("associating tags: mocked saving tag error", err.Error())
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
@ -675,7 +706,7 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) _, err = s.StoreSQLMocked.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: not found", err.Error())
@ -688,14 +719,14 @@ func (s *RepoTestSuite) TestListRepoPools() {
repoPools := []params.Pool{} repoPools := []params.Pool{}
for i := 1; i <= 2; i++ { for i := 1; i <= 2; i++ {
s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i) s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
repoPools = append(repoPools, pool) repoPools = append(repoPools, pool)
} }
pools, err := s.Store.ListEntityPools(context.Background(), entity) pools, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
garmTesting.EqualDBEntityID(s.T(), repoPools, pools) garmTesting.EqualDBEntityID(s.T(), repoPools, pools)
@ -706,7 +737,7 @@ func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
_, err := s.Store.ListEntityPools(context.Background(), entity) _, err := s.Store.ListEntityPools(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error())
@ -715,12 +746,12 @@ func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() {
func (s *RepoTestSuite) TestGetRepositoryPool() { func (s *RepoTestSuite) TestGetRepositoryPool() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
repoPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) repoPool, err := s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(repoPool.ID, pool.ID) s.Require().Equal(repoPool.ID, pool.ID)
@ -731,7 +762,7 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
_, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") _, err := s.Store.GetEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
@ -740,15 +771,15 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() {
func (s *RepoTestSuite) TestDeleteRepositoryPool() { func (s *RepoTestSuite) TestDeleteRepositoryPool() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.Store.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) _, err = s.Store.GetEntityPool(s.adminCtx, entity, pool.ID)
s.Require().Equal("fetching pool: finding pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
@ -757,7 +788,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") err := s.Store.DeleteEntityPool(s.adminCtx, entity, "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("parsing id: invalid request", err.Error()) s.Require().Equal("parsing id: invalid request", err.Error())
@ -767,7 +798,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
@ -779,7 +810,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
WillReturnError(fmt.Errorf("mocked deleting pool error")) WillReturnError(fmt.Errorf("mocked deleting pool error"))
s.Fixtures.SQLMock.ExpectRollback() s.Fixtures.SQLMock.ExpectRollback()
err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) err = s.StoreSQLMocked.DeleteEntityPool(s.adminCtx, entity, pool.ID)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) s.Require().Equal("removing pool: mocked deleting pool error", err.Error())
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
@ -788,21 +819,21 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
func (s *RepoTestSuite) TestListRepoInstances() { func (s *RepoTestSuite) TestListRepoInstances() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) pool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
poolInstances := []params.Instance{} poolInstances := []params.Instance{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-repo-%d", i) s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-repo-%d", i)
instance, err := s.Store.CreateInstance(context.Background(), pool.ID, s.Fixtures.CreateInstanceParams) instance, err := s.Store.CreateInstance(s.adminCtx, pool.ID, s.Fixtures.CreateInstanceParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) s.FailNow(fmt.Sprintf("cannot create instance: %s", err))
} }
poolInstances = append(poolInstances, instance) poolInstances = append(poolInstances, instance)
} }
instances, err := s.Store.ListEntityInstances(context.Background(), entity) instances, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().Nil(err) s.Require().Nil(err)
s.equalInstancesByID(poolInstances, instances) s.equalInstancesByID(poolInstances, instances)
@ -813,7 +844,7 @@ func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
_, err := s.Store.ListEntityInstances(context.Background(), entity) _, err := s.Store.ListEntityInstances(s.adminCtx, entity)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) s.Require().Equal("fetching entity: parsing id: invalid request", err.Error())
@ -822,12 +853,12 @@ func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() {
func (s *RepoTestSuite) TestUpdateRepositoryPool() { func (s *RepoTestSuite) TestUpdateRepositoryPool() {
entity, err := s.Fixtures.Repos[0].GetEntity() entity, err := s.Fixtures.Repos[0].GetEntity()
s.Require().Nil(err) s.Require().Nil(err)
repoPool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) repoPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, s.Fixtures.CreatePoolParams)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err))
} }
pool, err := s.Store.UpdateEntityPool(context.Background(), entity, repoPool.ID, s.Fixtures.UpdatePoolParams) pool, err := s.Store.UpdateEntityPool(s.adminCtx, entity, repoPool.ID, s.Fixtures.UpdatePoolParams)
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners)
@ -841,7 +872,7 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() {
ID: "dummy-repo-id", ID: "dummy-repo-id",
EntityType: params.GithubEntityTypeRepository, EntityType: params.GithubEntityTypeRepository,
} }
_, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateEntityPool(s.adminCtx, entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())

View file

@ -105,7 +105,7 @@ func (s *sqlDatabase) sqlAddressToParamsAddress(addr Address) commonParams.Addre
} }
} }
func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organization, error) { func (s *sqlDatabase) sqlToCommonOrganization(org Organization, detailed bool) (params.Organization, error) {
if len(org.WebhookSecret) == 0 { if len(org.WebhookSecret) == 0 {
return params.Organization{}, errors.New("missing secret") return params.Organization{}, errors.New("missing secret")
} }
@ -114,20 +114,21 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza
return params.Organization{}, errors.Wrap(err, "decrypting secret") return params.Organization{}, errors.Wrap(err, "decrypting secret")
} }
creds, err := s.sqlToCommonGithubCredentials(org.Credentials)
if err != nil {
return params.Organization{}, errors.Wrap(err, "converting credentials")
}
ret := params.Organization{ ret := params.Organization{
ID: org.ID.String(), ID: org.ID.String(),
Name: org.Name, Name: org.Name,
CredentialsName: creds.Name, CredentialsName: org.Credentials.Name,
Credentials: creds,
Pools: make([]params.Pool, len(org.Pools)), Pools: make([]params.Pool, len(org.Pools)),
WebhookSecret: string(secret), WebhookSecret: string(secret),
PoolBalancerType: org.PoolBalancerType, PoolBalancerType: org.PoolBalancerType,
} }
if detailed {
creds, err := s.sqlToCommonGithubCredentials(org.Credentials)
if err != nil {
return params.Organization{}, errors.Wrap(err, "converting credentials")
}
ret.Credentials = creds
}
if ret.PoolBalancerType == "" { if ret.PoolBalancerType == "" {
ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin
@ -143,7 +144,7 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) (params.Organiza
return ret, nil return ret, nil
} }
func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enterprise, error) { func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise, detailed bool) (params.Enterprise, error) {
if len(enterprise.WebhookSecret) == 0 { if len(enterprise.WebhookSecret) == 0 {
return params.Enterprise{}, errors.New("missing secret") return params.Enterprise{}, errors.New("missing secret")
} }
@ -152,20 +153,23 @@ func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) (params.Enter
return params.Enterprise{}, errors.Wrap(err, "decrypting secret") return params.Enterprise{}, errors.Wrap(err, "decrypting secret")
} }
creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "converting credentials")
}
ret := params.Enterprise{ ret := params.Enterprise{
ID: enterprise.ID.String(), ID: enterprise.ID.String(),
Name: enterprise.Name, Name: enterprise.Name,
CredentialsName: creds.Name, CredentialsName: enterprise.Credentials.Name,
Credentials: creds,
Pools: make([]params.Pool, len(enterprise.Pools)), Pools: make([]params.Pool, len(enterprise.Pools)),
WebhookSecret: string(secret), WebhookSecret: string(secret),
PoolBalancerType: enterprise.PoolBalancerType, PoolBalancerType: enterprise.PoolBalancerType,
} }
if detailed {
creds, err := s.sqlToCommonGithubCredentials(enterprise.Credentials)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "converting credentials")
}
ret.Credentials = creds
}
if ret.PoolBalancerType == "" { if ret.PoolBalancerType == "" {
ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin
} }
@ -241,7 +245,7 @@ func (s *sqlDatabase) sqlToCommonTags(tag Tag) params.Tag {
} }
} }
func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository, error) { func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (params.Repository, error) {
if len(repo.WebhookSecret) == 0 { if len(repo.WebhookSecret) == 0 {
return params.Repository{}, errors.New("missing secret") return params.Repository{}, errors.New("missing secret")
} }
@ -250,21 +254,24 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) (params.Repository,
return params.Repository{}, errors.Wrap(err, "decrypting secret") return params.Repository{}, errors.Wrap(err, "decrypting secret")
} }
creds, err := s.sqlToCommonGithubCredentials(repo.Credentials)
if err != nil {
return params.Repository{}, errors.Wrap(err, "converting credentials")
}
ret := params.Repository{ ret := params.Repository{
ID: repo.ID.String(), ID: repo.ID.String(),
Name: repo.Name, Name: repo.Name,
Owner: repo.Owner, Owner: repo.Owner,
CredentialsName: creds.Name, CredentialsName: repo.Credentials.Name,
Credentials: creds,
Pools: make([]params.Pool, len(repo.Pools)), Pools: make([]params.Pool, len(repo.Pools)),
WebhookSecret: string(secret), WebhookSecret: string(secret),
PoolBalancerType: repo.PoolBalancerType, PoolBalancerType: repo.PoolBalancerType,
} }
if detailed {
creds, err := s.sqlToCommonGithubCredentials(repo.Credentials)
if err != nil {
return params.Repository{}, errors.Wrap(err, "converting credentials")
}
ret.Credentials = creds
}
if ret.PoolBalancerType == "" { if ret.PoolBalancerType == "" {
ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin ret.PoolBalancerType = params.PoolBalancerTypeRoundRobin
} }

View file

@ -18,19 +18,79 @@
package testing package testing
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
"testing" "testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config" "github.com/cloudbase/garm/config"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/util/appdefaults"
) )
//nolint:golangci-lint,gosec //nolint:golangci-lint,gosec
var encryptionPassphrase = "bocyasicgatEtenOubwonIbsudNutDom" var encryptionPassphrase = "bocyasicgatEtenOubwonIbsudNutDom"
func ImpersonateAdminContext(ctx context.Context, db common.Store, s *testing.T) context.Context {
adminUser, err := db.GetAdminUser(ctx)
if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
s.Fatalf("failed to get admin user: %v", err)
}
newUserParams := params.NewUserParams{
Email: "admin@localhost",
Username: "admin",
Password: "superSecretAdminPassword@123",
IsAdmin: true,
Enabled: true,
}
adminUser, err = db.CreateUser(ctx, newUserParams)
if err != nil {
s.Fatalf("failed to create admin user: %v", err)
}
}
ctx = auth.PopulateContext(ctx, adminUser)
return ctx
}
func CreateDefaultGithubEndpoint(ctx context.Context, db common.Store, s *testing.T) params.GithubEndpoint {
endpointParams := params.CreateGithubEndpointParams{
Name: "github.com",
Description: "github endpoint",
APIBaseURL: appdefaults.GithubDefaultBaseURL,
UploadBaseURL: appdefaults.GithubDefaultUploadBaseURL,
BaseURL: appdefaults.DefaultGithubURL,
}
endpoint, err := db.CreateGithubEndpoint(ctx, endpointParams)
if err != nil {
s.Fatalf("failed to create database object (github.com): %v", err)
}
return endpoint
}
func CreateTestGithubCredentials(ctx context.Context, credsName string, db common.Store, s *testing.T, endpoint params.GithubEndpoint) params.GithubCredentials {
newCredsParams := params.CreateGithubCredentialsParams{
Name: credsName,
Description: "Test creds",
AuthType: params.GithubAuthTypePAT,
PAT: params.GithubPAT{
OAuth2Token: "test-token",
},
}
newCreds, err := db.CreateGithubCredentials(ctx, endpoint.Name, newCredsParams)
if err != nil {
s.Fatalf("failed to create database object (new-creds): %v", err)
}
return newCreds
}
func GetTestSqliteDBConfig(t *testing.T) config.Database { func GetTestSqliteDBConfig(t *testing.T) config.Database {
dir, err := os.MkdirTemp("", "garm-config-test") dir, err := os.MkdirTemp("", "garm-config-test")
if err != nil { if err != nil {

View file

@ -25,8 +25,8 @@ func (r *Runner) CreateEnterprise(ctx context.Context, param params.CreateEnterp
return params.Enterprise{}, errors.Wrap(err, "validating params") return params.Enterprise{}, errors.Wrap(err, "validating params")
} }
creds, ok := r.credentials[param.CredentialsName] creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true)
if !ok { if err != nil {
return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName)
} }
@ -161,25 +161,13 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para
r.mux.Lock() r.mux.Lock()
defer r.mux.Unlock() defer r.mux.Unlock()
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
if param.CredentialsName != "" {
// Check that credentials are set before saving to db
if _, ok := r.credentials[param.CredentialsName]; !ok {
return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", param.CredentialsName, enterprise.Name)
}
}
switch param.PoolBalancerType { switch param.PoolBalancerType {
case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone:
default: default:
return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType)
} }
enterprise, err = r.store.UpdateEnterprise(ctx, enterpriseID, param) enterprise, err := r.store.UpdateEnterprise(ctx, enterpriseID, param)
if err != nil { if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise") return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
} }

View file

@ -19,12 +19,11 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
runnerErrors "github.com/cloudbase/garm-provider-common/errors" runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config"
"github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database"
dbCommon "github.com/cloudbase/garm/database/common" dbCommon "github.com/cloudbase/garm/database/common"
garmTesting "github.com/cloudbase/garm/internal/testing" //nolint:typecheck garmTesting "github.com/cloudbase/garm/internal/testing" //nolint:typecheck
@ -40,7 +39,7 @@ type EnterpriseTestFixtures struct {
Store dbCommon.Store Store dbCommon.Store
StoreEnterprises map[string]params.Enterprise StoreEnterprises map[string]params.Enterprise
Providers map[string]common.Provider Providers map[string]common.Provider
Credentials map[string]config.Github Credentials map[string]params.GithubCredentials
CreateEnterpriseParams params.CreateEnterpriseParams CreateEnterpriseParams params.CreateEnterpriseParams
CreatePoolParams params.CreatePoolParams CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams CreateInstanceParams params.CreateInstanceParams
@ -57,18 +56,25 @@ type EnterpriseTestSuite struct {
suite.Suite suite.Suite
Fixtures *EnterpriseTestFixtures Fixtures *EnterpriseTestFixtures
Runner *Runner Runner *Runner
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *EnterpriseTestSuite) SetupTest() { func (s *EnterpriseTestSuite) SetupTest() {
adminCtx := auth.GetAdminContext(context.Background())
// create testing sqlite database // create testing sqlite database
dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) dbCfg := garmTesting.GetTestSqliteDBConfig(s.T())
db, err := database.NewDatabase(adminCtx, dbCfg) db, err := database.NewDatabase(context.Background(), dbCfg)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) s.FailNow(fmt.Sprintf("failed to create db connection: %s", err))
} }
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some organization objects in the database, for testing purposes // create some organization objects in the database, for testing purposes
enterprises := map[string]params.Enterprise{} enterprises := map[string]params.Enterprise{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
@ -76,12 +82,12 @@ func (s *EnterpriseTestSuite) SetupTest() {
enterprise, err := db.CreateEnterprise( enterprise, err := db.CreateEnterprise(
adminCtx, adminCtx,
name, name,
fmt.Sprintf("test-creds-%v", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%v", i), fmt.Sprintf("test-webhook-secret-%v", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%v)", i)) s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%v): %+v", i, err))
} }
enterprises[name] = enterprise enterprises[name] = enterprise
} }
@ -98,16 +104,13 @@ func (s *EnterpriseTestSuite) SetupTest() {
Providers: map[string]common.Provider{ Providers: map[string]common.Provider{
"test-provider": providerMock, "test-provider": providerMock,
}, },
Credentials: map[string]config.Github{ Credentials: map[string]params.GithubCredentials{
"test-creds": { s.testCreds.Name: s.testCreds,
Name: "test-creds-name", s.secondaryTestCreds.Name: s.secondaryTestCreds,
Description: "test-creds-description",
OAuth2Token: "test-creds-oauth2-token",
},
}, },
CreateEnterpriseParams: params.CreateEnterpriseParams{ CreateEnterpriseParams: params.CreateEnterpriseParams{
Name: "test-enterprise-create", Name: "test-enterprise-create",
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-create-enterprise-webhook-secret", WebhookSecret: "test-create-enterprise-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -126,7 +129,7 @@ func (s *EnterpriseTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-update-repo-webhook-secret", WebhookSecret: "test-update-repo-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -148,7 +151,6 @@ func (s *EnterpriseTestSuite) SetupTest() {
// setup test runner // setup test runner
runner := &Runner{ runner := &Runner{
providers: fixtures.Providers, providers: fixtures.Providers,
credentials: fixtures.Credentials,
ctx: fixtures.AdminContext, ctx: fixtures.AdminContext,
store: fixtures.Store, store: fixtures.Store,
poolManagerCtrl: fixtures.PoolMgrCtrlMock, poolManagerCtrl: fixtures.PoolMgrCtrlMock,
@ -164,13 +166,13 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() {
// call tested function // call tested function
enterprise, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) enterprise, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams)
// assertions
s.Fixtures.PoolMgrMock.AssertExpectations(s.T())
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name) s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name)
s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.Credentials.Name) s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.Credentials.Name)
s.Require().Equal(params.PoolBalancerTypeRoundRobin, enterprise.PoolBalancerType) s.Require().Equal(params.PoolBalancerTypeRoundRobin, enterprise.PoolBalancerType)
// assertions
s.Fixtures.PoolMgrMock.AssertExpectations(s.T())
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
} }
func (s *EnterpriseTestSuite) TestCreateEnterpriseErrUnauthorized() { func (s *EnterpriseTestSuite) TestCreateEnterpriseErrUnauthorized() {
@ -322,7 +324,9 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() {
_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)
s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreEnterprises["test-enterprise-1"].Name), err) if !errors.Is(err, runnerErrors.ErrNotFound) {
s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound))
}
} }
func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() { func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() {

View file

@ -38,8 +38,8 @@ func (r *Runner) CreateOrganization(ctx context.Context, param params.CreateOrgP
return params.Organization{}, errors.Wrap(err, "validating params") return params.Organization{}, errors.Wrap(err, "validating params")
} }
creds, ok := r.credentials[param.CredentialsName] creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true)
if !ok { if err != nil {
return params.Organization{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) return params.Organization{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName)
} }
@ -190,25 +190,13 @@ func (r *Runner) UpdateOrganization(ctx context.Context, orgID string, param par
r.mux.Lock() r.mux.Lock()
defer r.mux.Unlock() defer r.mux.Unlock()
org, err := r.store.GetOrganizationByID(ctx, orgID)
if err != nil {
return params.Organization{}, errors.Wrap(err, "fetching org")
}
if param.CredentialsName != "" {
// Check that credentials are set before saving to db
if _, ok := r.credentials[param.CredentialsName]; !ok {
return params.Organization{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for org %s", param.CredentialsName, org.Name)
}
}
switch param.PoolBalancerType { switch param.PoolBalancerType {
case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone:
default: default:
return params.Organization{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) return params.Organization{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType)
} }
org, err = r.store.UpdateOrganization(ctx, orgID, param) org, err := r.store.UpdateOrganization(ctx, orgID, param)
if err != nil { if err != nil {
return params.Organization{}, errors.Wrap(err, "updating org") return params.Organization{}, errors.Wrap(err, "updating org")
} }

View file

@ -19,12 +19,11 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
runnerErrors "github.com/cloudbase/garm-provider-common/errors" runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config"
"github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database"
dbCommon "github.com/cloudbase/garm/database/common" dbCommon "github.com/cloudbase/garm/database/common"
garmTesting "github.com/cloudbase/garm/internal/testing" garmTesting "github.com/cloudbase/garm/internal/testing"
@ -40,7 +39,7 @@ type OrgTestFixtures struct {
Store dbCommon.Store Store dbCommon.Store
StoreOrgs map[string]params.Organization StoreOrgs map[string]params.Organization
Providers map[string]common.Provider Providers map[string]common.Provider
Credentials map[string]config.Github Credentials map[string]params.GithubCredentials
CreateOrgParams params.CreateOrgParams CreateOrgParams params.CreateOrgParams
CreatePoolParams params.CreatePoolParams CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams CreateInstanceParams params.CreateInstanceParams
@ -57,18 +56,26 @@ type OrgTestSuite struct {
suite.Suite suite.Suite
Fixtures *OrgTestFixtures Fixtures *OrgTestFixtures
Runner *Runner Runner *Runner
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *OrgTestSuite) SetupTest() { func (s *OrgTestSuite) SetupTest() {
adminCtx := auth.GetAdminContext(context.Background())
// create testing sqlite database // create testing sqlite database
dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) dbCfg := garmTesting.GetTestSqliteDBConfig(s.T())
db, err := database.NewDatabase(adminCtx, dbCfg) db, err := database.NewDatabase(context.Background(), dbCfg)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) s.FailNow(fmt.Sprintf("failed to create db connection: %s", err))
} }
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some organization objects in the database, for testing purposes // create some organization objects in the database, for testing purposes
orgs := map[string]params.Organization{} orgs := map[string]params.Organization{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
@ -76,7 +83,7 @@ func (s *OrgTestSuite) SetupTest() {
org, err := db.CreateOrganization( org, err := db.CreateOrganization(
adminCtx, adminCtx,
name, name,
fmt.Sprintf("test-creds-%v", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%v", i), fmt.Sprintf("test-webhook-secret-%v", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
@ -98,16 +105,13 @@ func (s *OrgTestSuite) SetupTest() {
Providers: map[string]common.Provider{ Providers: map[string]common.Provider{
"test-provider": providerMock, "test-provider": providerMock,
}, },
Credentials: map[string]config.Github{ Credentials: map[string]params.GithubCredentials{
"test-creds": { s.testCreds.Name: s.testCreds,
Name: "test-creds-name", s.secondaryTestCreds.Name: s.secondaryTestCreds,
Description: "test-creds-description",
OAuth2Token: "test-creds-oauth2-token",
},
}, },
CreateOrgParams: params.CreateOrgParams{ CreateOrgParams: params.CreateOrgParams{
Name: "test-org-create", Name: "test-org-create",
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-create-org-webhook-secret", WebhookSecret: "test-create-org-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -126,7 +130,7 @@ func (s *OrgTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-update-repo-webhook-secret", WebhookSecret: "test-update-repo-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -148,7 +152,6 @@ func (s *OrgTestSuite) SetupTest() {
// setup test runner // setup test runner
runner := &Runner{ runner := &Runner{
providers: fixtures.Providers, providers: fixtures.Providers,
credentials: fixtures.Credentials,
ctx: fixtures.AdminContext, ctx: fixtures.AdminContext,
store: fixtures.Store, store: fixtures.Store,
poolManagerCtrl: fixtures.PoolMgrCtrlMock, poolManagerCtrl: fixtures.PoolMgrCtrlMock,
@ -346,8 +349,9 @@ func (s *OrgTestSuite) TestUpdateOrganizationInvalidCreds() {
s.Fixtures.UpdateRepoParams.CredentialsName = invalidCredentialsName s.Fixtures.UpdateRepoParams.CredentialsName = invalidCredentialsName
_, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams) _, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams)
if !errors.Is(err, runnerErrors.ErrNotFound) {
s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for org %s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreOrgs["test-org-1"].Name), err) s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound))
}
} }
func (s *OrgTestSuite) TestUpdateOrganizationPoolMgrFailed() { func (s *OrgTestSuite) TestUpdateOrganizationPoolMgrFailed() {

View file

@ -45,6 +45,11 @@ type PoolTestSuite struct {
suite.Suite suite.Suite
Fixtures *PoolTestFixtures Fixtures *PoolTestFixtures
Runner *Runner Runner *Runner
adminCtx context.Context
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *PoolTestSuite) SetupTest() { func (s *PoolTestSuite) SetupTest() {
@ -57,8 +62,14 @@ func (s *PoolTestSuite) SetupTest() {
s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) s.FailNow(fmt.Sprintf("failed to create db connection: %s", err))
} }
s.adminCtx = garmTesting.ImpersonateAdminContext(adminCtx, db, s.T())
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(s.adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(s.adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(s.adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create an organization for testing purposes // create an organization for testing purposes
org, err := db.CreateOrganization(context.Background(), "test-org", "test-creds", "test-webhookSecret", params.PoolBalancerTypeRoundRobin) org, err := db.CreateOrganization(s.adminCtx, "test-org", s.testCreds.Name, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err)) s.FailNow(fmt.Sprintf("failed to create org: %s", err))
} }
@ -71,7 +82,7 @@ func (s *PoolTestSuite) SetupTest() {
orgPools := []params.Pool{} orgPools := []params.Pool{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
pool, err := db.CreateEntityPool( pool, err := db.CreateEntityPool(
context.Background(), adminCtx,
entity, entity,
params.CreatePoolParams{ params.CreatePoolParams{
ProviderName: "test-provider", ProviderName: "test-provider",
@ -112,10 +123,9 @@ func (s *PoolTestSuite) SetupTest() {
// setup test runner // setup test runner
runner := &Runner{ runner := &Runner{
providers: fixtures.Providers, providers: fixtures.Providers,
credentials: fixtures.Credentials, store: fixtures.Store,
store: fixtures.Store, ctx: fixtures.AdminContext,
ctx: fixtures.AdminContext,
} }
s.Runner = runner s.Runner = runner
} }

View file

@ -38,8 +38,8 @@ func (r *Runner) CreateRepository(ctx context.Context, param params.CreateRepoPa
return params.Repository{}, errors.Wrap(err, "validating params") return params.Repository{}, errors.Wrap(err, "validating params")
} }
creds, ok := r.credentials[param.CredentialsName] creds, err := r.store.GetGithubCredentialsByName(ctx, param.CredentialsName, true)
if !ok { if err != nil {
return params.Repository{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName) return params.Repository{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName)
} }
@ -189,25 +189,13 @@ func (r *Runner) UpdateRepository(ctx context.Context, repoID string, param para
r.mux.Lock() r.mux.Lock()
defer r.mux.Unlock() defer r.mux.Unlock()
repo, err := r.store.GetRepositoryByID(ctx, repoID)
if err != nil {
return params.Repository{}, errors.Wrap(err, "fetching repo")
}
if param.CredentialsName != "" {
// Check that credentials are set before saving to db
if _, ok := r.credentials[param.CredentialsName]; !ok {
return params.Repository{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for repo %s/%s", param.CredentialsName, repo.Owner, repo.Name)
}
}
switch param.PoolBalancerType { switch param.PoolBalancerType {
case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone: case params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypePack, params.PoolBalancerTypeNone:
default: default:
return params.Repository{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) return params.Repository{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType)
} }
repo, err = r.store.UpdateRepository(ctx, repoID, param) repo, err := r.store.UpdateRepository(ctx, repoID, param)
if err != nil { if err != nil {
return params.Repository{}, errors.Wrap(err, "updating repo") return params.Repository{}, errors.Wrap(err, "updating repo")
} }

View file

@ -19,12 +19,12 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
runnerErrors "github.com/cloudbase/garm-provider-common/errors" runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/auth" "github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config"
"github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database"
dbCommon "github.com/cloudbase/garm/database/common" dbCommon "github.com/cloudbase/garm/database/common"
garmTesting "github.com/cloudbase/garm/internal/testing" garmTesting "github.com/cloudbase/garm/internal/testing"
@ -39,7 +39,7 @@ type RepoTestFixtures struct {
Store dbCommon.Store Store dbCommon.Store
StoreRepos map[string]params.Repository StoreRepos map[string]params.Repository
Providers map[string]common.Provider Providers map[string]common.Provider
Credentials map[string]config.Github Credentials map[string]params.GithubCredentials
CreateRepoParams params.CreateRepoParams CreateRepoParams params.CreateRepoParams
CreatePoolParams params.CreatePoolParams CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams CreateInstanceParams params.CreateInstanceParams
@ -56,18 +56,25 @@ type RepoTestSuite struct {
suite.Suite suite.Suite
Fixtures *RepoTestFixtures Fixtures *RepoTestFixtures
Runner *Runner Runner *Runner
testCreds params.GithubCredentials
secondaryTestCreds params.GithubCredentials
githubEndpoint params.GithubEndpoint
} }
func (s *RepoTestSuite) SetupTest() { func (s *RepoTestSuite) SetupTest() {
adminCtx := auth.GetAdminContext(context.Background())
// create testing sqlite database // create testing sqlite database
dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) dbCfg := garmTesting.GetTestSqliteDBConfig(s.T())
db, err := database.NewDatabase(adminCtx, dbCfg) db, err := database.NewDatabase(context.Background(), dbCfg)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) s.FailNow(fmt.Sprintf("failed to create db connection: %s", err))
} }
adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T())
s.githubEndpoint = garmTesting.CreateDefaultGithubEndpoint(adminCtx, db, s.T())
s.testCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), s.githubEndpoint)
s.secondaryTestCreds = garmTesting.CreateTestGithubCredentials(adminCtx, "secondary-creds", db, s.T(), s.githubEndpoint)
// create some repository objects in the database, for testing purposes // create some repository objects in the database, for testing purposes
repos := map[string]params.Repository{} repos := map[string]params.Repository{}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
@ -76,7 +83,7 @@ func (s *RepoTestSuite) SetupTest() {
adminCtx, adminCtx,
fmt.Sprintf("test-owner-%v", i), fmt.Sprintf("test-owner-%v", i),
name, name,
fmt.Sprintf("test-creds-%v", i), s.testCreds.Name,
fmt.Sprintf("test-webhook-secret-%v", i), fmt.Sprintf("test-webhook-secret-%v", i),
params.PoolBalancerTypeRoundRobin, params.PoolBalancerTypeRoundRobin,
) )
@ -97,17 +104,14 @@ func (s *RepoTestSuite) SetupTest() {
Providers: map[string]common.Provider{ Providers: map[string]common.Provider{
"test-provider": providerMock, "test-provider": providerMock,
}, },
Credentials: map[string]config.Github{ Credentials: map[string]params.GithubCredentials{
"test-creds": { s.testCreds.Name: s.testCreds,
Name: "test-creds-name", s.secondaryTestCreds.Name: s.secondaryTestCreds,
Description: "test-creds-description",
OAuth2Token: "test-creds-oauth2-token",
},
}, },
CreateRepoParams: params.CreateRepoParams{ CreateRepoParams: params.CreateRepoParams{
Owner: "test-owner-create", Owner: "test-owner-create",
Name: "test-repo-create", Name: "test-repo-create",
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-create-repo-webhook-secret", WebhookSecret: "test-create-repo-webhook-secret",
}, },
CreatePoolParams: params.CreatePoolParams{ CreatePoolParams: params.CreatePoolParams{
@ -126,7 +130,7 @@ func (s *RepoTestSuite) SetupTest() {
OSType: "linux", OSType: "linux",
}, },
UpdateRepoParams: params.UpdateEntityParams{ UpdateRepoParams: params.UpdateEntityParams{
CredentialsName: "test-creds", CredentialsName: s.testCreds.Name,
WebhookSecret: "test-update-repo-webhook-secret", WebhookSecret: "test-update-repo-webhook-secret",
}, },
UpdatePoolParams: params.UpdatePoolParams{ UpdatePoolParams: params.UpdatePoolParams{
@ -148,7 +152,6 @@ func (s *RepoTestSuite) SetupTest() {
// setup test runner // setup test runner
runner := &Runner{ runner := &Runner{
providers: fixtures.Providers, providers: fixtures.Providers,
credentials: fixtures.Credentials,
ctx: fixtures.AdminContext, ctx: fixtures.AdminContext,
store: fixtures.Store, store: fixtures.Store,
poolManagerCtrl: fixtures.PoolMgrCtrlMock, poolManagerCtrl: fixtures.PoolMgrCtrlMock,
@ -167,6 +170,7 @@ func (s *RepoTestSuite) TestCreateRepository() {
// assertions // assertions
s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrMock.AssertExpectations(s.T())
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Nil(err) s.Require().Nil(err)
s.Require().Equal(s.Fixtures.CreateRepoParams.Owner, repo.Owner) s.Require().Equal(s.Fixtures.CreateRepoParams.Owner, repo.Owner)
s.Require().Equal(s.Fixtures.CreateRepoParams.Name, repo.Name) s.Require().Equal(s.Fixtures.CreateRepoParams.Name, repo.Name)
@ -358,7 +362,9 @@ func (s *RepoTestSuite) TestUpdateRepositoryInvalidCreds() {
_, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams) _, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams)
s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for repo %s/%s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreRepos["test-repo-1"].Owner, s.Fixtures.StoreRepos["test-repo-1"].Name), err) if !errors.Is(err, runnerErrors.ErrNotFound) {
s.FailNow(fmt.Sprintf("expected error: %v", runnerErrors.ErrNotFound))
}
} }
func (s *RepoTestSuite) TestUpdateRepositoryPoolMgrFailed() { func (s *RepoTestSuite) TestUpdateRepositoryPoolMgrFailed() {

View file

@ -77,7 +77,6 @@ func NewRunner(ctx context.Context, cfg config.Config, db dbCommon.Store) (*Runn
store: db, store: db,
poolManagerCtrl: poolManagerCtrl, poolManagerCtrl: poolManagerCtrl,
providers: providers, providers: providers,
credentials: creds,
controllerID: ctrlID.ControllerID, controllerID: ctrlID.ControllerID,
} }
@ -355,8 +354,7 @@ type Runner struct {
poolManagerCtrl PoolManagerController poolManagerCtrl PoolManagerController
providers map[string]common.Provider providers map[string]common.Provider
credentials map[string]config.Github
controllerInfo params.ControllerInfo controllerInfo params.ControllerInfo
controllerID uuid.UUID controllerID uuid.UUID