diff --git a/database/common/watcher.go b/database/common/watcher.go index 69bf9788..73af32bd 100644 --- a/database/common/watcher.go +++ b/database/common/watcher.go @@ -1,5 +1,7 @@ package common +import "context" + type ( DatabaseEntityType string OperationType string @@ -45,6 +47,7 @@ type Producer interface { } type Watcher interface { - RegisterProducer(ID string) (Producer, error) - RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error) + RegisterProducer(ctx context.Context, ID string) (Producer, error) + RegisterConsumer(ctx context.Context, ID string, filters ...PayloadFilterFunc) (Consumer, error) + Close() } diff --git a/database/sql/sql.go b/database/sql/sql.go index 6ee8a2d9..937ef676 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -69,7 +69,7 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err if err != nil { return nil, errors.Wrap(err, "creating DB connection") } - producer, err := watcher.RegisterProducer("sql") + producer, err := watcher.RegisterProducer(ctx, "sql") if err != nil { return nil, errors.Wrap(err, "registering producer") } diff --git a/database/watcher/consumer.go b/database/watcher/consumer.go index 369344ba..fb36c694 100644 --- a/database/watcher/consumer.go +++ b/database/watcher/consumer.go @@ -1,6 +1,7 @@ package watcher import ( + "context" "log/slog" "sync" "time" @@ -16,6 +17,7 @@ type consumer struct { mux sync.Mutex closed bool quit chan struct{} + ctx context.Context } func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) { @@ -54,10 +56,10 @@ func (w *consumer) Send(payload common.ChangePayload) { } if len(w.filters) > 0 { - shouldSend := false + shouldSend := true for _, filter := range w.filters { - if filter(payload) { - shouldSend = true + if !filter(payload) { + shouldSend = false break } } @@ -67,9 +69,14 @@ func (w *consumer) Send(payload common.ChangePayload) { } } - slog.Info("Sending payload to consumer", "consumer", w.id) + slog.DebugContext(w.ctx, "sending payload") select { - case w.messages <- payload: + case <-w.quit: + slog.DebugContext(w.ctx, "consumer is closed") + case <-w.ctx.Done(): + slog.DebugContext(w.ctx, "consumer is closed") case <-time.After(1 * time.Second): + slog.DebugContext(w.ctx, "timeout trying to send payload", "payload", payload) + case w.messages <- payload: } } diff --git a/database/watcher/filters.go b/database/watcher/filters.go new file mode 100644 index 00000000..9b175d7a --- /dev/null +++ b/database/watcher/filters.go @@ -0,0 +1,141 @@ +package watcher + +import ( + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" +) + +type idGetter interface { + GetID() string +} + +// WithAny returns a filter function that returns true if any of the provided filters return true. +// This filter is useful if for example you want to watch for update operations on any of the supplied +// entities. +// Example: +// +// // Watch for any update operation on repositories or organizations +// consumer.SetFilters( +// watcher.WithOperationTypeFilter(common.UpdateOperation), +// watcher.WithAny( +// watcher.WithEntityTypeFilter(common.RepositoryEntityType), +// watcher.WithEntityTypeFilter(common.OrganizationEntityType), +// )) +func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + for _, filter := range filters { + if filter(payload) { + return true + } + } + return false + } +} + +// WithEntityTypeFilter returns a filter function that filters payloads by entity type. +// The filter function returns true if the payload's entity type matches the provided entity type. +func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + return payload.EntityType == entityType + } +} + +// WithOperationTypeFilter returns a filter function that filters payloads by operation type. +func WithOperationTypeFilter(operationType dbCommon.OperationType) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + return payload.Operation == operationType + } +} + +// WithEntityPoolFilter returns true if the change payload is a pool that belongs to the +// supplied Github entity. This is useful when an entity worker wants to watch for changes +// in pools that belong to it. +func WithEntityPoolFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + switch payload.EntityType { + case dbCommon.PoolEntityType: + pool, ok := payload.Payload.(params.Pool) + if !ok { + return false + } + switch ghEntity.EntityType { + case params.GithubEntityTypeRepository: + if pool.RepoID != ghEntity.ID { + return false + } + case params.GithubEntityTypeOrganization: + if pool.OrgID != ghEntity.ID { + return false + } + case params.GithubEntityTypeEnterprise: + if pool.EnterpriseID != ghEntity.ID { + return false + } + default: + return false + } + return true + default: + return false + } + } +} + +// WithEntityFilter returns a filter function that filters payloads by entity. +// Change payloads that match the entity type and ID will return true. +func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + if params.GithubEntityType(payload.EntityType) != entity.EntityType { + return false + } + var ent idGetter + var ok bool + switch payload.EntityType { + case dbCommon.RepositoryEntityType: + ent, ok = payload.Payload.(params.Repository) + case dbCommon.OrganizationEntityType: + ent, ok = payload.Payload.(params.Organization) + case dbCommon.EnterpriseEntityType: + ent, ok = payload.Payload.(params.Enterprise) + default: + return false + } + if !ok { + return false + } + return ent.GetID() == entity.ID + } +} + +func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + switch payload.EntityType { + case dbCommon.JobEntityType: + job, ok := payload.Payload.(params.Job) + if !ok { + return false + } + + switch ghEntity.EntityType { + case params.GithubEntityTypeRepository: + if job.RepoID != nil && job.RepoID.String() != ghEntity.ID { + return false + } + case params.GithubEntityTypeOrganization: + if job.OrgID != nil && job.OrgID.String() != ghEntity.ID { + return false + } + case params.GithubEntityTypeEnterprise: + if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID { + return false + } + default: + return false + } + + return true + default: + return false + } + } +} diff --git a/database/watcher/producer.go b/database/watcher/producer.go index 70578004..fd61aa16 100644 --- a/database/watcher/producer.go +++ b/database/watcher/producer.go @@ -1,7 +1,9 @@ package watcher import ( + "context" "sync" + "time" "github.com/cloudbase/garm/database/common" ) @@ -13,6 +15,7 @@ type producer struct { messages chan common.ChangePayload quit chan struct{} + ctx context.Context } func (w *producer) Notify(payload common.ChangePayload) error { @@ -24,9 +27,13 @@ func (w *producer) Notify(payload common.ChangePayload) error { } select { - case w.messages <- payload: - default: + case <-w.quit: + return common.ErrProducerClosed + case <-w.ctx.Done(): + return common.ErrProducerClosed + case <-time.After(1 * time.Second): return common.ErrProducerTimeoutErr + case w.messages <- payload: } return nil } diff --git a/database/watcher/test_export.go b/database/watcher/test_export.go index 4c75233e..f9b4ecf1 100644 --- a/database/watcher/test_export.go +++ b/database/watcher/test_export.go @@ -10,3 +10,8 @@ import "github.com/cloudbase/garm/database/common" func SetWatcher(w common.Watcher) { databaseWatcher = w } + +// GetWatcher returns the current watcher. +func GetWatcher() common.Watcher { + return databaseWatcher +} diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index 23400e21..86ba594e 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -2,9 +2,11 @@ package watcher import ( "context" + "log/slog" "sync" "github.com/cloudbase/garm/database/common" + garmUtil "github.com/cloudbase/garm/util" ) var databaseWatcher common.Watcher @@ -13,6 +15,7 @@ func InitWatcher(ctx context.Context) { if databaseWatcher != nil { return } + ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database")) w := &watcher{ producers: make(map[string]*producer), consumers: make(map[string]*consumer), @@ -24,18 +27,20 @@ func InitWatcher(ctx context.Context) { databaseWatcher = w } -func RegisterProducer(id string) (common.Producer, error) { +func RegisterProducer(ctx context.Context, id string) (common.Producer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - return databaseWatcher.RegisterProducer(id) + ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id)) + return databaseWatcher.RegisterProducer(ctx, id) } -func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { +func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - return databaseWatcher.RegisterConsumer(id, filters...) + ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id)) + return databaseWatcher.RegisterConsumer(ctx, id, filters...) } type watcher struct { @@ -48,7 +53,10 @@ type watcher struct { ctx context.Context } -func (w *watcher) RegisterProducer(id string) (common.Producer, error) { +func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) { + w.mux.Lock() + defer w.mux.Unlock() + if _, ok := w.producers[id]; ok { return nil, common.ErrProducerAlreadyRegistered } @@ -56,6 +64,7 @@ func (w *watcher) RegisterProducer(id string) (common.Producer, error) { id: id, messages: make(chan common.ChangePayload, 1), quit: make(chan struct{}), + ctx: ctx, } w.producers[id] = p go w.serviceProducer(p) @@ -67,13 +76,16 @@ func (w *watcher) serviceProducer(prod *producer) { w.mux.Lock() defer w.mux.Unlock() prod.Close() + slog.InfoContext(w.ctx, "removing producer from watcher", "consumer_id", prod.id) delete(w.producers, prod.id) }() for { select { case <-w.quit: + slog.InfoContext(w.ctx, "shutting down watcher") return case <-w.ctx.Done(): + slog.InfoContext(w.ctx, "shutting down watcher") return case payload := <-prod.messages: for _, c := range w.consumers { @@ -83,7 +95,7 @@ func (w *watcher) serviceProducer(prod *producer) { } } -func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { +func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) { if _, ok := w.consumers[id]; ok { return nil, common.ErrConsumerAlreadyRegistered } @@ -92,6 +104,7 @@ func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFun filters: filters, quit: make(chan struct{}), id: id, + ctx: ctx, } w.consumers[id] = c go w.serviceConsumer(c) @@ -103,6 +116,7 @@ func (w *watcher) serviceConsumer(consumer *consumer) { w.mux.Lock() defer w.mux.Unlock() consumer.Close() + slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", consumer.id) delete(w.consumers, consumer.id) }() for { @@ -134,6 +148,8 @@ func (w *watcher) Close() { for _, c := range w.consumers { c.Close() } + + databaseWatcher = nil } func (w *watcher) loop() { diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go new file mode 100644 index 00000000..b5353c03 --- /dev/null +++ b/database/watcher/watcher_store_test.go @@ -0,0 +1,45 @@ +package watcher_test + +import ( + "context" + "testing" + + "github.com/cloudbase/garm/database" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" + "github.com/stretchr/testify/suite" +) + +type WatcherStoreTestSuite struct { + suite.Suite + + store common.Store + ctx context.Context +} + +func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() { + // ghEpParams := params.CreateGithubEndpointParams{ + // Name: "test", + // Description: "test endpoint", + // APIBaseURL: "https://api.ghes.example.com", + // UploadBaseURL: "https://upload.ghes.example.com", + // BaseURL: "https://ghes.example.com", + // } + +} + +func TestWatcherStoreTestSuite(t *testing.T) { + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t)) + if err != nil { + t.Fatalf("failed to create db connection: %s", err) + } + watcherSuite := &WatcherStoreTestSuite{ + ctx: context.TODO(), + store: store, + } + suite.Run(t, watcherSuite) +} diff --git a/database/watcher/watcher_test.go b/database/watcher/watcher_test.go new file mode 100644 index 00000000..838cdeb0 --- /dev/null +++ b/database/watcher/watcher_test.go @@ -0,0 +1,159 @@ +//go:build testing + +package watcher_test + +import ( + "context" + "testing" + "time" + + "github.com/cloudbase/garm/database" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + garmTesting "github.com/cloudbase/garm/internal/testing" + "github.com/stretchr/testify/suite" +) + +type WatcherTestSuite struct { + suite.Suite + store common.Store + ctx context.Context +} + +func (s *WatcherTestSuite) SetupTest() { + ctx := context.TODO() + watcher.InitWatcher(ctx) + + store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T())) + if err != nil { + s.T().Fatalf("failed to create db connection: %s", err) + } + s.store = store +} + +func (s *WatcherTestSuite) TearDownTest() { + s.store = nil + currentWatcher := watcher.GetWatcher() + if currentWatcher != nil { + currentWatcher.Close() + } +} + +func (s *WatcherTestSuite) TestRegisterConsumer() { + consumer, err := watcher.RegisterConsumer(s.ctx, "test") + s.Require().NoError(err) + s.Require().NotNil(consumer) + + consumer, err = watcher.RegisterConsumer(s.ctx, "test") + s.Require().Error(err) + s.Require().Nil(consumer) +} + +func (s *WatcherTestSuite) TestRegisterProducer() { + producer, err := watcher.RegisterProducer(s.ctx, "test") + s.Require().NoError(err) + s.Require().NotNil(producer) + + producer, err = watcher.RegisterProducer(s.ctx, "test") + s.Require().Error(err) + s.Require().Nil(producer) +} + +func (s *WatcherTestSuite) TestInitWatcherRanTwiceDoesNotReplaceWatcher() { + ctx := context.TODO() + currentWatcher := watcher.GetWatcher() + s.Require().NotNil(currentWatcher) + watcher.InitWatcher(ctx) + newWatcher := watcher.GetWatcher() + s.Require().Equal(currentWatcher, newWatcher) +} + +func (s *WatcherTestSuite) TestRegisterConsumerFailsIfWatcherIsNotInitialized() { + s.store = nil + currentWatcher := watcher.GetWatcher() + currentWatcher.Close() + + consumer, err := watcher.RegisterConsumer(s.ctx, "test") + s.Require().Nil(consumer) + s.Require().ErrorIs(err, common.ErrWatcherNotInitialized) +} + +func (s *WatcherTestSuite) TestRegisterProducerFailsIfWatcherIsNotInitialized() { + s.store = nil + currentWatcher := watcher.GetWatcher() + currentWatcher.Close() + + producer, err := watcher.RegisterProducer(s.ctx, "test") + s.Require().Nil(producer) + s.Require().ErrorIs(err, common.ErrWatcherNotInitialized) +} + +func (s *WatcherTestSuite) TestProducerAndConsumer() { + producer, err := watcher.RegisterProducer(s.ctx, "test-producer") + s.Require().NoError(err) + s.Require().NotNil(producer) + + consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer") + s.Require().NoError(err) + s.Require().NotNil(consumer) + + payload := common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.UpdateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + receivedPayload := <-consumer.Watch() + s.Require().Equal(payload, receivedPayload) +} + +func (s *WatcherTestSuite) TestConsumetWithFilter() { + producer, err := watcher.RegisterProducer(s.ctx, "test-producer") + s.Require().NoError(err) + s.Require().NotNil(producer) + + consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool { + return payload.Operation == common.UpdateOperation + }) + s.Require().NoError(err) + s.Require().NotNil(consumer) + + payload := common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.UpdateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + select { + case receivedPayload := <-consumer.Watch(): + s.Require().Equal(payload, receivedPayload) + case <-time.After(1 * time.Second): + s.T().Fatal("expected payload not received") + } + + payload = common.ChangePayload{ + EntityType: common.ControllerEntityType, + Operation: common.CreateOperation, + Payload: "test", + } + err = producer.Notify(payload) + s.Require().NoError(err) + + select { + case <-consumer.Watch(): + s.T().Fatal("unexpected payload received") + case <-time.After(1 * time.Second): + } + +} + +func TestWatcherTestSuite(t *testing.T) { + watcherSuite := &WatcherTestSuite{ + ctx: context.TODO(), + } + suite.Run(t, watcherSuite) +} diff --git a/internal/testing/mock_watcher.go b/internal/testing/mock_watcher.go index 394091bd..67ae5da4 100644 --- a/internal/testing/mock_watcher.go +++ b/internal/testing/mock_watcher.go @@ -3,18 +3,25 @@ package testing -import "github.com/cloudbase/garm/database/common" +import ( + "context" + + "github.com/cloudbase/garm/database/common" +) type MockWatcher struct{} -func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) { +func (w *MockWatcher) RegisterProducer(_ context.Context, _ string) (common.Producer, error) { return &MockProducer{}, nil } -func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) { +func (w *MockWatcher) RegisterConsumer(_ context.Context, _ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) { return &MockConsumer{}, nil } +func (w *MockWatcher) Close() { +} + type MockProducer struct{} func (p *MockProducer) Notify(_ common.ChangePayload) error {