From c601f88cf7fb68477f15a9be9812f0e88de16d19 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Fri, 2 May 2025 12:22:04 +0000 Subject: [PATCH] Add more tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/instances_test.go | 6 + database/sql/pools_test.go | 141 +++++++++++++++++++++- database/watcher/watcher.go | 9 ++ database/watcher/watcher_store_test.go | 5 +- util/github/scalesets/message_sessions.go | 2 +- 5 files changed, 160 insertions(+), 3 deletions(-) diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 9d000cef..90418be7 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -119,6 +119,12 @@ func (s *InstancesTestSuite) SetupTest() { CallbackURL: "https://garm.example.com/", Status: commonParams.InstanceRunning, RunnerStatus: params.RunnerIdle, + JitConfiguration: map[string]string{ + "secret": fmt.Sprintf("secret-%d", i), + }, + AditionalLabels: []string{ + fmt.Sprintf("label-%d", i), + }, }, ) if err != nil { diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index e6cf7f4a..990d6808 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -16,6 +16,7 @@ package sql import ( "context" + "encoding/json" "flag" "fmt" "regexp" @@ -27,7 +28,10 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" + commonParams "github.com/cloudbase/garm-provider-common/params" + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" "github.com/cloudbase/garm/params" ) @@ -40,7 +44,9 @@ type PoolsTestFixtures struct { type PoolsTestSuite struct { suite.Suite - Store dbCommon.Store + Store dbCommon.Store + ctx context.Context + StoreSQLMocked *sqlDatabase Fixtures *PoolsTestFixtures adminCtx context.Context @@ -53,13 +59,21 @@ func (s *PoolsTestSuite) assertSQLMockExpectations() { } } +func (s *PoolsTestSuite) TearDownTest() { + watcher.CloseWatcher() +} + func (s *PoolsTestSuite) SetupTest() { // create testing sqlite database + ctx := context.Background() + watcher.InitWatcher(ctx) + db, err := NewSQLDatabase(context.Background(), garmTesting.GetTestSqliteDBConfig(s.T())) if err != nil { s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) } s.Store = db + s.ctx = garmTesting.ImpersonateAdminContext(ctx, s.Store, s.T()) adminCtx := garmTesting.ImpersonateAdminContext(context.Background(), db, s.T()) s.adminCtx = adminCtx @@ -194,6 +208,131 @@ func (s *PoolsTestSuite) TestDeletePoolByIDDBRemoveErr() { s.Require().Equal("removing pool: mocked removing pool error", err.Error()) } +func (s *PoolsTestSuite) TestEntityPoolOperations() { + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.Store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.Store, s.T(), ep) + s.T().Cleanup(func() { s.Store.DeleteGithubCredentials(s.ctx, creds.ID) }) + repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.Store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + + pool, err := s.Store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + s.T().Cleanup(func() { s.Store.DeleteEntityPool(s.ctx, entity, pool.ID) }) + + entityPool, err := s.Store.GetEntityPool(s.ctx, entity, pool.ID) + s.Require().NoError(err) + s.Require().Equal(pool.ID, entityPool.ID) + s.Require().Equal(pool.ProviderName, entityPool.ProviderName) + + updatePoolParams := params.UpdatePoolParams{ + Enabled: garmTesting.Ptr(true), + Flavor: "new-flavor", + Image: "new-image", + RunnerPrefix: params.RunnerPrefix{ + Prefix: "new-prefix", + }, + MaxRunners: garmTesting.Ptr(uint(100)), + MinIdleRunners: garmTesting.Ptr(uint(50)), + OSType: commonParams.Windows, + OSArch: commonParams.Amd64, + Tags: []string{"new-tag"}, + RunnerBootstrapTimeout: garmTesting.Ptr(uint(10)), + ExtraSpecs: json.RawMessage(`{"extra": "specs"}`), + GitHubRunnerGroup: garmTesting.Ptr("new-group"), + Priority: garmTesting.Ptr(uint(1)), + } + pool, err = s.Store.UpdateEntityPool(s.ctx, entity, pool.ID, updatePoolParams) + s.Require().NoError(err) + s.Require().Equal(*updatePoolParams.Enabled, pool.Enabled) + s.Require().Equal(updatePoolParams.Flavor, pool.Flavor) + s.Require().Equal(updatePoolParams.Image, pool.Image) + s.Require().Equal(updatePoolParams.RunnerPrefix.Prefix, pool.RunnerPrefix.Prefix) + s.Require().Equal(*updatePoolParams.MaxRunners, pool.MaxRunners) + s.Require().Equal(*updatePoolParams.MinIdleRunners, pool.MinIdleRunners) + s.Require().Equal(updatePoolParams.OSType, pool.OSType) + s.Require().Equal(updatePoolParams.OSArch, pool.OSArch) + s.Require().Equal(*updatePoolParams.RunnerBootstrapTimeout, pool.RunnerBootstrapTimeout) + s.Require().Equal(updatePoolParams.ExtraSpecs, pool.ExtraSpecs) + s.Require().Equal(*updatePoolParams.GitHubRunnerGroup, pool.GitHubRunnerGroup) + s.Require().Equal(*updatePoolParams.Priority, pool.Priority) + + entityPools, err := s.Store.ListEntityPools(s.ctx, entity) + s.Require().NoError(err) + s.Require().Len(entityPools, 1) + s.Require().Equal(pool.ID, entityPools[0].ID) + + tagsToMatch := []string{"new-tag"} + pools, err := s.Store.FindPoolsMatchingAllTags(s.ctx, entity.EntityType, entity.ID, tagsToMatch) + s.Require().NoError(err) + s.Require().Len(pools, 1) + s.Require().Equal(pool.ID, pools[0].ID) + + invalidTagsToMatch := []string{"invalid-tag"} + pools, err = s.Store.FindPoolsMatchingAllTags(s.ctx, entity.EntityType, entity.ID, invalidTagsToMatch) + s.Require().NoError(err) + s.Require().Len(pools, 0) +} + +func (s *PoolsTestSuite) TestListEntityInstances() { + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.Store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.Store, s.T(), ep) + s.T().Cleanup(func() { s.Store.DeleteGithubCredentials(s.ctx, creds.ID) }) + repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.Store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + + pool, err := s.Store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + s.T().Cleanup(func() { s.Store.DeleteEntityPool(s.ctx, entity, pool.ID) }) + + createInstanceParams := params.CreateInstanceParams{ + Name: "test-instance", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Status: commonParams.InstanceCreating, + } + instance, err := s.Store.CreateInstance(s.ctx, pool.ID, createInstanceParams) + s.Require().NoError(err) + s.Require().NotEmpty(instance.ID) + + s.T().Cleanup(func() { s.Store.DeleteInstance(s.ctx, pool.ID, instance.ID) }) + + instances, err := s.Store.ListEntityInstances(s.ctx, entity) + s.Require().NoError(err) + s.Require().Len(instances, 1) + s.Require().Equal(instance.ID, instances[0].ID) + s.Require().Equal(instance.Name, instances[0].Name) +} + func TestPoolsTestSuite(t *testing.T) { t.Parallel() suite.Run(t, new(PoolsTestSuite)) diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index 2ef1aeee..fda318c6 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -29,6 +29,15 @@ func InitWatcher(ctx context.Context) { databaseWatcher = w } +func CloseWatcher() error { + if databaseWatcher == nil { + return nil + } + databaseWatcher.Close() + databaseWatcher = nil + return nil +} + func RegisterProducer(ctx context.Context, id string) (common.Producer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index a0845b9c..af3185db 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -748,8 +748,11 @@ func consumeEvents(consumer common.Consumer) { consume: for { select { - case <-consumer.Watch(): + case _, ok := <-consumer.Watch(): // throw away event. + if !ok { + return + } case <-time.After(100 * time.Millisecond): break consume } diff --git a/util/github/scalesets/message_sessions.go b/util/github/scalesets/message_sessions.go index 79d5c26e..8fafc2c4 100644 --- a/util/github/scalesets/message_sessions.go +++ b/util/github/scalesets/message_sessions.go @@ -132,7 +132,7 @@ func (m *MessageSession) Refresh(ctx context.Context) error { if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil { return fmt.Errorf("failed to decode response: %w", err) } - slog.DebugContext(ctx, "refreshed message session token", "session", refreshedSession) + slog.DebugContext(ctx, "refreshed message session token") m.session = &refreshedSession return nil }