From 8a79d9e8f95eb4f7e45671501edb1df98511589e Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 18 Jun 2024 17:45:48 +0000 Subject: [PATCH] Add more watcher tests Signed-off-by: Gabriel Adrian Samfira --- database/sql/enterprise.go | 38 +++- database/sql/organizations.go | 44 +++- database/sql/repositories.go | 5 +- database/watcher/watcher_store_test.go | 294 ++++++++++++++++++++++++- database/watcher/watcher_test.go | 8 +- 5 files changed, 366 insertions(+), 23 deletions(-) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index 7d20d2e8..c5af3bc4 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -16,6 +16,7 @@ package sql import ( "context" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -23,10 +24,11 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) { +func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (paramEnt params.Enterprise, err error) { if webhookSecret == "" { return params.Enterprise{}, errors.New("creating enterprise: missing secret") } @@ -34,6 +36,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam if err != nil { return params.Enterprise{}, errors.Wrap(err, "encoding secret") } + + defer func() { + if err == nil { + s.sendNotify(common.EnterpriseEntityType, common.CreateOperation, paramEnt) + } + }() newEnterprise := Enterprise{ Name: name, WebhookSecret: secret, @@ -66,12 +74,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } - param, err := s.sqlToCommonEnterprise(newEnterprise, true) + paramEnt, err = s.sqlToCommonEnterprise(newEnterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "creating enterprise") } - return param, nil + return paramEnt, nil } func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { @@ -124,11 +132,22 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e } func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { - enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID) + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching enterprise") } + defer func(ent Enterprise) { + if err == nil { + asParams, innerErr := s.sqlToCommonEnterprise(ent, true) + if innerErr == nil { + s.sendNotify(common.EnterpriseEntityType, common.DeleteOperation, asParams) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "enterprise", enterpriseID) + } + } + }(enterprise) + q := s.conn.Unscoped().Delete(&enterprise) if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { return errors.Wrap(q.Error, "deleting enterprise") @@ -137,10 +156,15 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) return nil } -func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) { +func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (newParams params.Enterprise, err error) { + defer func() { + if err == nil { + s.sendNotify(common.EnterpriseEntityType, common.UpdateOperation, newParams) + } + }() var enterprise Enterprise var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { var err error enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID) if err != nil { @@ -196,7 +220,7 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } - newParams, err := s.sqlToCommonEnterprise(enterprise, true) + newParams, err = s.sqlToCommonEnterprise(enterprise, true) if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 1192c843..0f3d58a3 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -17,6 +17,7 @@ package sql import ( "context" "fmt" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -24,10 +25,11 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm-provider-common/util" + "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) { +func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (org params.Organization, err error) { if webhookSecret == "" { return params.Organization{}, errors.New("creating org: missing secret") } @@ -35,6 +37,12 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN if err != nil { return params.Organization{}, errors.Wrap(err, "encoding secret") } + + defer func() { + if err == nil { + s.sendNotify(common.OrganizationEntityType, common.CreateOperation, org) + } + }() newOrg := Organization{ Name: name, WebhookSecret: secret, @@ -68,13 +76,13 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN return params.Organization{}, errors.Wrap(err, "creating org") } - param, err := s.sqlToCommonOrganization(newOrg, true) + org, err = s.sqlToCommonOrganization(newOrg, true) if err != nil { return params.Organization{}, errors.Wrap(err, "creating org") } - param.WebhookSecret = webhookSecret + org.WebhookSecret = webhookSecret - return param, nil + return org, nil } func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) { @@ -114,12 +122,23 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio return ret, nil } -func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error { - org, err := s.getOrgByID(ctx, s.conn, orgID) +func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) { + org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching org") } + defer func(org Organization) { + if err == nil { + asParam, innerErr := s.sqlToCommonOrganization(org, true) + if innerErr == nil { + s.sendNotify(common.OrganizationEntityType, common.DeleteOperation, asParam) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "org", orgID) + } + } + }(org) + q := s.conn.Unscoped().Delete(&org) if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { return errors.Wrap(q.Error, "deleting org") @@ -128,10 +147,15 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro return nil } -func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) { +func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (paramOrg params.Organization, err error) { + defer func() { + if err == nil { + s.sendNotify(common.OrganizationEntityType, common.UpdateOperation, paramOrg) + } + }() var org Organization var creds GithubCredentials - err := s.conn.Transaction(func(tx *gorm.DB) error { + err = s.conn.Transaction(func(tx *gorm.DB) error { var err error org, err = s.getOrgByID(ctx, tx, orgID) if err != nil { @@ -188,11 +212,11 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para if err != nil { return params.Organization{}, errors.Wrap(err, "updating enterprise") } - newParams, err := s.sqlToCommonOrganization(org, true) + paramOrg, err = s.sqlToCommonOrganization(org, true) if err != nil { return params.Organization{}, errors.Wrap(err, "saving org") } - return newParams, nil + return paramOrg, nil } func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 7ab1c522..5469950f 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -17,6 +17,7 @@ package sql import ( "context" "fmt" + "log/slog" "github.com/google/uuid" "github.com/pkg/errors" @@ -121,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, } func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) { - repo, err := s.getRepoByID(ctx, s.conn, repoID) + repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials") if err != nil { return errors.Wrap(err, "fetching repo") } @@ -131,6 +132,8 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err asParam, innerErr := s.sqlToCommonRepository(repo, true) if innerErr == nil { s.sendNotify(common.RepositoryEntityType, common.DeleteOperation, asParam) + } else { + slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "repo", repoID) } } }(repo) diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index f7a2e4c3..895dab9d 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -8,6 +8,7 @@ import ( "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" "github.com/cloudbase/garm/params" ) @@ -18,16 +19,292 @@ type WatcherStoreTestSuite struct { ctx context.Context } +func (s *WatcherStoreTestSuite) TestEnterpriseWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "enterprise-test", + watcher.WithEntityTypeFilter(common.EnterpriseEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + ent, err := s.store.CreateEnterprise(s.ctx, "test-enterprise", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(ent.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.CreateOperation, + Payload: ent, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateEntityParams{ + WebhookSecret: "updated", + } + + updatedEnt, err := s.store.UpdateEnterprise(s.ctx, ent.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal("updated", updatedEnt.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.UpdateOperation, + Payload: updatedEnt, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteEnterprise(s.ctx, ent.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.EnterpriseEntityType, + Operation: common.DeleteOperation, + Payload: updatedEnt, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestOrgWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "org-test", + watcher.WithEntityTypeFilter(common.OrganizationEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + org, err := s.store.CreateOrganization(s.ctx, "test-org", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(org.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.CreateOperation, + Payload: org, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + updateParams := params.UpdateEntityParams{ + WebhookSecret: "updated", + } + + updatedOrg, err := s.store.UpdateOrganization(s.ctx, org.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal("updated", updatedOrg.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.UpdateOperation, + Payload: updatedOrg, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteOrganization(s.ctx, org.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.OrganizationEntityType, + Operation: common.DeleteOperation, + Payload: updatedOrg, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestRepoWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "repo-test", + watcher.WithEntityTypeFilter(common.RepositoryEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T()) + creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep) + s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) }) + + repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin) + s.Require().NoError(err) + s.Require().NotEmpty(repo.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.CreateOperation, + Payload: repo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + newSecret := "updated" + updateParams := params.UpdateEntityParams{ + WebhookSecret: newSecret, + } + + updatedRepo, err := s.store.UpdateRepository(s.ctx, repo.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal(newSecret, updatedRepo.WebhookSecret) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.UpdateOperation, + Payload: updatedRepo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteRepository(s.ctx, repo.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.RepositoryEntityType, + Operation: common.DeleteOperation, + Payload: updatedRepo, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + +func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() { + consumer, err := watcher.RegisterConsumer( + s.ctx, "gh-cred-test", + watcher.WithEntityTypeFilter(common.GithubCredentialsEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + + ghCredParams := params.CreateGithubCredentialsParams{ + Name: "test-creds", + Description: "test credentials", + Endpoint: "github.com", + AuthType: params.GithubAuthTypePAT, + PAT: params.GithubPAT{ + OAuth2Token: "bogus", + }, + } + + ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams) + s.Require().NoError(err) + s.Require().NotEmpty(ghCred.ID) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.CreateOperation, + Payload: ghCred, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + newDesc := "updated description" + updateParams := params.UpdateGithubCredentialsParams{ + Description: &newDesc, + } + + updatedGhCred, err := s.store.UpdateGithubCredentials(s.ctx, ghCred.ID, updateParams) + s.Require().NoError(err) + s.Require().Equal(newDesc, updatedGhCred.Description) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.UpdateOperation, + Payload: updatedGhCred, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + err = s.store.DeleteGithubCredentials(s.ctx, ghCred.ID) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubCredentialsEntityType, + Operation: common.DeleteOperation, + // We only get the ID and Name of the deleted entity + Payload: params.GithubCredentials{ID: ghCred.ID, Name: ghCred.Name}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } +} + func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { consumer, err := watcher.RegisterConsumer( s.ctx, "gh-ep-test", watcher.WithEntityTypeFilter(common.GithubEndpointEntityType), watcher.WithAny( watcher.WithOperationTypeFilter(common.CreateOperation), - watcher.WithOperationTypeFilter(common.UpdateOperation)), + watcher.WithOperationTypeFilter(common.UpdateOperation), + watcher.WithOperationTypeFilter(common.DeleteOperation)), ) s.Require().NoError(err) s.Require().NotNil(consumer) + s.T().Cleanup(func() { consumer.Close() }) + ghEpParams := params.CreateGithubEndpointParams{ Name: "test", Description: "test endpoint", @@ -70,4 +347,19 @@ func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { case <-time.After(1 * time.Second): s.T().Fatal("expected payload not received") } + + err = s.store.DeleteGithubEndpoint(s.ctx, ghEp.Name) + s.Require().NoError(err) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.DeleteOperation, + // We only get the name of the deleted entity + Payload: params.GithubEndpoint{Name: ghEp.Name}, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } } diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index b44c152e..21d15093 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -163,17 +163,17 @@ func TestWatcherTestSuite(t *testing.T) { } suite.Run(t, watcherSuite) - // These tests run store changes and make sure that the store properly - // triggers watcher notifications. - ctx := context.TODO() + ctx := context.Background() watcher.InitWatcher(ctx) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) if err != nil { t.Fatalf("failed to create db connection: %s", err) } + + adminCtx := garmTesting.ImpersonateAdminContext(ctx, store, t) watcherStoreSuite := &WatcherStoreTestSuite{ - ctx: context.TODO(), + ctx: adminCtx, store: store, } suite.Run(t, watcherStoreSuite)