diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index 86ba594e..ef5a5525 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -5,6 +5,8 @@ import ( "log/slog" "sync" + "github.com/pkg/errors" + "github.com/cloudbase/garm/database/common" garmUtil "github.com/cloudbase/garm/util" ) @@ -58,7 +60,7 @@ func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Produ defer w.mux.Unlock() if _, ok := w.producers[id]; ok { - return nil, common.ErrProducerAlreadyRegistered + return nil, errors.Wrapf(common.ErrProducerAlreadyRegistered, "producer_id: %s", id) } p := &producer{ id: id, @@ -87,15 +89,25 @@ func (w *watcher) serviceProducer(prod *producer) { case <-w.ctx.Done(): slog.InfoContext(w.ctx, "shutting down watcher") return + case <-prod.quit: + slog.InfoContext(w.ctx, "closing producer") + return + case <-prod.ctx.Done(): + slog.InfoContext(w.ctx, "closing producer") + return case payload := <-prod.messages: + w.mux.Lock() for _, c := range w.consumers { go c.Send(payload) } + w.mux.Unlock() } } } func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { + w.mux.Lock() + defer w.mux.Unlock() if _, ok := w.consumers[id]; ok { return nil, common.ErrConsumerAlreadyRegistered } @@ -123,6 +135,8 @@ func (w *watcher) serviceConsumer(consumer *consumer) { select { case <-consumer.quit: return + case <-consumer.ctx.Done(): + return case <-w.quit: return case <-w.ctx.Done(): diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index b5353c03..f7a2e4c3 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -2,13 +2,13 @@ package watcher_test import ( "context" - "testing" + "time" + + "github.com/stretchr/testify/suite" - "github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" - garmTesting "github.com/cloudbase/garm/internal/testing" - "github.com/stretchr/testify/suite" + "github.com/cloudbase/garm/params" ) type WatcherStoreTestSuite struct { @@ -19,27 +19,55 @@ type WatcherStoreTestSuite struct { } func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { - // ghEpParams := params.CreateGithubEndpointParams{ - // Name: "test", - // Description: "test endpoint", - // APIBaseURL: "https://api.ghes.example.com", - // UploadBaseURL: "https://upload.ghes.example.com", - // BaseURL: "https://ghes.example.com", - // } - -} - -func TestWatcherStoreTestSuite(t *testing.T) { - ctx := context.TODO() - watcher.InitWatcher(ctx) - - store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) - if err != nil { - t.Fatalf("failed to create db connection: %s", err) + consumer, err := watcher.RegisterConsumer( + s.ctx, "gh-ep-test", + watcher.WithEntityTypeFilter(common.GithubEndpointEntityType), + watcher.WithAny( + watcher.WithOperationTypeFilter(common.CreateOperation), + watcher.WithOperationTypeFilter(common.UpdateOperation)), + ) + s.Require().NoError(err) + s.Require().NotNil(consumer) + ghEpParams := params.CreateGithubEndpointParams{ + Name: "test", + Description: "test endpoint", + APIBaseURL: "https://api.ghes.example.com", + UploadBaseURL: "https://upload.ghes.example.com", + BaseURL: "https://ghes.example.com", } - watcherSuite := &WatcherStoreTestSuite{ - ctx: context.TODO(), - store: store, + + ghEp, err := s.store.CreateGithubEndpoint(s.ctx, ghEpParams) + s.Require().NoError(err) + s.Require().NotEmpty(ghEp.Name) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.CreateOperation, + Payload: ghEp, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + newDesc := "updated description" + updateParams := params.UpdateGithubEndpointParams{ + Description: &newDesc, + } + + updatedGhEp, err := s.store.UpdateGithubEndpoint(s.ctx, ghEp.Name, updateParams) + s.Require().NoError(err) + s.Require().Equal(newDesc, updatedGhEp.Description) + + select { + case event := <-consumer.Watch(): + s.Require().Equal(common.ChangePayload{ + EntityType: common.GithubEndpointEntityType, + Operation: common.UpdateOperation, + Payload: updatedGhEp, + }, event) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") } - suite.Run(t, watcherSuite) } diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go index 838cdeb0..b44c152e 100644 --- a/database/watcher/watcher_test.go +++ b/database/watcher/watcher_test.go @@ -4,14 +4,16 @@ package watcher_test import ( "context" + "fmt" "testing" "time" + "github.com/stretchr/testify/suite" + "github.com/cloudbase/garm/database" "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/database/watcher" garmTesting "github.com/cloudbase/garm/internal/testing" - "github.com/stretchr/testify/suite" ) type WatcherTestSuite struct { @@ -23,7 +25,7 @@ type WatcherTestSuite struct { func (s *WatcherTestSuite) SetupTest() { ctx := context.TODO() watcher.InitWatcher(ctx) - + fmt.Printf("creating store: %v\n", s.store) store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) if err != nil { s.T().Fatalf("failed to create db connection: %s", err) @@ -39,23 +41,23 @@ func (s *WatcherTestSuite) TearDownTest() { } } -func (s *WatcherTestSuite) TestRegisterConsumer() { +func (s *WatcherTestSuite) TestRegisterConsumerTwiceWillError() { consumer, err := watcher.RegisterConsumer(s.ctx, "test") s.Require().NoError(err) s.Require().NotNil(consumer) consumer, err = watcher.RegisterConsumer(s.ctx, "test") - s.Require().Error(err) + s.Require().ErrorIs(err, common.ErrConsumerAlreadyRegistered) s.Require().Nil(consumer) } -func (s *WatcherTestSuite) TestRegisterProducer() { +func (s *WatcherTestSuite) TestRegisterProducerTwiceWillError() { producer, err := watcher.RegisterProducer(s.ctx, "test") s.Require().NoError(err) s.Require().NotNil(producer) producer, err = watcher.RegisterProducer(s.ctx, "test") - s.Require().Error(err) + s.Require().ErrorIs(err, common.ErrProducerAlreadyRegistered) s.Require().Nil(producer) } @@ -93,7 +95,10 @@ func (s *WatcherTestSuite) TestProducerAndConsumer() { s.Require().NoError(err) s.Require().NotNil(producer) - consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer") + consumer, err := watcher.RegisterConsumer( + s.ctx, "test-consumer", + watcher.WithEntityTypeFilter(common.ControllerEntityType), + watcher.WithOperationTypeFilter(common.UpdateOperation)) s.Require().NoError(err) s.Require().NotNil(consumer) @@ -114,9 +119,10 @@ func (s *WatcherTestSuite) TestConsumetWithFilter() { s.Require().NoError(err) s.Require().NotNil(producer) - consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool { - return payload.Operation == common.UpdateOperation - }) + consumer, err := watcher.RegisterConsumer( + s.ctx, "test-consumer", + watcher.WithEntityTypeFilter(common.ControllerEntityType), + watcher.WithOperationTypeFilter(common.UpdateOperation)) s.Require().NoError(err) s.Require().NotNil(consumer) @@ -148,12 +154,27 @@ func (s *WatcherTestSuite) TestConsumetWithFilter() { s.T().Fatal("unexpected payload received") case <-time.After(1 * time.Second): } - } func TestWatcherTestSuite(t *testing.T) { + // Watcher tests watcherSuite := &WatcherTestSuite{ ctx: context.TODO(), } suite.Run(t, watcherSuite) + + // These tests run store changes and make sure that the store properly + // triggers watcher notifications. + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) + if err != nil { + t.Fatalf("failed to create db connection: %s", err) + } + watcherStoreSuite := &WatcherStoreTestSuite{ + ctx: context.TODO(), + store: store, + } + suite.Run(t, watcherStoreSuite) }