From 0c8c6f5668b058361bddaa477b38da1c1c2c7b48 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 19 Jun 2024 12:19:58 +0000 Subject: [PATCH] Add more tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/instances.go | 36 +++- database/sql/pools.go | 12 +- database/watcher/watcher_store_test.go | 217 +++++++++++++++++++++++++ database/watcher/watcher_test.go | 2 - internal/testing/testing.go | 2 +- 5 files changed, 253 insertions(+), 16 deletions(-) diff --git a/database/sql/instances.go b/database/sql/instances.go index c09b60f3..3f113669 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -25,15 +25,22 @@ import ( "gorm.io/gorm/clause" runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { +func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } + defer func() { + if err == nil { + s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance) + } + }() + var labels datatypes.JSON if len(param.AditionalLabels) > 0 { labels, err = json.Marshal(param.AditionalLabels) @@ -134,11 +141,28 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string return s.sqlToParamsInstance(instance) } -func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) error { +func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) (err error) { instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") } + + defer func() { + if err == nil { + var providerID string + if instance.ProviderID != nil { + providerID = *instance.ProviderID + } + s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{ + ID: instance.ID.String(), + Name: instance.Name, + ProviderID: providerID, + AgentID: instance.AgentID, + PoolID: instance.PoolID.String(), + }) + } + }() + if q := s.conn.Unscoped().Delete(&instance); q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return nil @@ -230,8 +254,12 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, p return params.Instance{}, errors.Wrap(err, "updating addresses") } } - - return s.sqlToParamsInstance(instance) + inst, err := s.sqlToParamsInstance(instance) + if err != nil { + return params.Instance{}, errors.Wrap(err, "converting instance") + } + s.sendNotify(common.InstanceEntityType, common.UpdateOperation, inst) + return inst, nil } func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]params.Instance, error) { diff --git a/database/sql/pools.go b/database/sql/pools.go index 89500ed9..0cb4a094 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -17,7 +17,6 @@ package sql import ( "context" "fmt" - "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -74,16 +73,11 @@ func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err erro return errors.Wrap(err, "fetching pool by ID") } - defer func(pool Pool) { + defer func() { if err == nil { - asParams, innerErr := s.sqlToCommonPool(pool) - if innerErr == nil { - s.sendNotify(common.PoolEntityType, common.DeleteOperation, asParams) - } else { - slog.With(slog.Any("error", innerErr)).ErrorContext(s.ctx, "error sending delete notification", "pool", poolID) - } + s.sendNotify(common.PoolEntityType, common.DeleteOperation, params.Pool{ID: poolID}) } - }(pool) + }() if q := s.conn.Unscoped().Delete(&pool); q.Error != nil { return errors.Wrap(q.Error, "removing pool") diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index fa82a339..80f71325 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" + commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" @@ -19,6 +20,221 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestInstanceWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "instance-test", + watcher.WithEntityTypeFilter(common.InstanceEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + + pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + s.T().Cleanup(func() { s.store.DeleteEntityPool(s.ctx, entity, pool.ID) }) + + createInstanceParams := params.CreateInstanceParams{ + Name: "test-instance", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Status: commonParams.InstanceCreating, + } + instance, err := s.store.CreateInstance(s.ctx, pool.ID, createInstanceParams) + s.Require().NoError(err) + s.Require().NotEmpty(instance.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.CreateOperation, + Payload: instance, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateInstanceParams{ + RunnerStatus: params.RunnerActive, + } + + updatedInstance, err := s.store.UpdateInstance(s.ctx, instance.Name, updateParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.UpdateOperation, + Payload: updatedInstance, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteInstance(s.ctx, pool.ID, updatedInstance.Name) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.InstanceEntityType, + Operation: common.DeleteOperation, + Payload: params.Instance{ + ID: updatedInstance.ID, + Name: updatedInstance.Name, + ProviderID: updatedInstance.ProviderID, + AgentID: updatedInstance.AgentID, + PoolID: updatedInstance.PoolID, + }, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestPoolWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "pool-test", + watcher.WithEntityTypeFilter(common.PoolEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { + if err := s.store.DeleteGithubCredentials(s.ctx, creds.ID); err != nil { + s.T().Logf("failed to delete Github credentials: %v", err) + } + }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) }) + + entity, err := repo.GetEntity() + s.Require().NoError(err) + + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + Image: "test-image", + Flavor: "test-flavor", + OSType: commonParams.Linux, + OSArch: commonParams.Amd64, + Tags: []string{"test-tag"}, + } + pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.CreateOperation, + Payload: pool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdatePoolParams{ + Tags: []string{"updated-tag"}, + } + + updatedPool, err := s.store.UpdateEntityPool(s.ctx, entity, pool.ID, updateParams) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.UpdateOperation, + Payload: updatedPool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteEntityPool(s.ctx, entity, pool.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.DeleteOperation, + Payload: params.Pool{ID: pool.ID}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + // Also test DeletePoolByID + pool, err = s.store.CreateEntityPool(s.ctx, entity, createPoolParams) + s.Require().NoError(err) + s.Require().NotEmpty(pool.ID) + + // Consume the create event + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.CreateOperation, + Payload: pool, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeletePoolByID(s.ctx, pool.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.PoolEntityType, + Operation: common.DeleteOperation, + Payload: params.Pool{ID: pool.ID}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestControllerWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "controller-test", @@ -275,6 +491,7 @@ func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() { ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams) s.Require().NoError(err) s.Require().NotEmpty(ghCred.ID) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, ghCred.ID) }) select { case event := <-consumer.Watch(): diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index 6d1091ed..c5b56fe2 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -4,7 +4,6 @@ package watcher_test import ( "context" - "fmt" "testing" "time" @@ -26,7 +25,6 @@ type WatcherTestSuite struct { func (s *WatcherTestSuite) SetupTest() { ctx := context.TODO() watcher.InitWatcher(ctx) - fmt.Printf("creating store: %v\n", s.store) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) if err != nil { s.T().Fatalf("failed to create db connection: %s", err) diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 8949f7cf..6e76956f 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -122,7 +122,7 @@ func CreateTestGithubCredentials(ctx context.Context, credsName string, db commo } newCreds, err := db.CreateGithubCredentials(ctx, newCredsParams) if err != nil { - s.Fatalf("failed to create database object (new-creds): %v", err) + s.Fatalf("failed to create database object (%s): %v", credsName, err) } return newCreds }