Add some tests

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-06-17 19:42:50 +00:00
parent 6051629810
commit b51683f1ae
10 changed files with 409 additions and 19 deletions

View file

@ -1,5 +1,7 @@
package common
import "context"
type (
DatabaseEntityType string
OperationType string
@ -45,6 +47,7 @@ type Producer interface {
}
type Watcher interface {
RegisterProducer(ID string) (Producer, error)
RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error)
RegisterProducer(ctx context.Context, ID string) (Producer, error)
RegisterConsumer(ctx context.Context, ID string, filters ...PayloadFilterFunc) (Consumer, error)
Close()
}

View file

@ -69,7 +69,7 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
if err != nil {
return nil, errors.Wrap(err, "creating DB connection")
}
producer, err := watcher.RegisterProducer("sql")
producer, err := watcher.RegisterProducer(ctx, "sql")
if err != nil {
return nil, errors.Wrap(err, "registering producer")
}

View file

@ -1,6 +1,7 @@
package watcher
import (
"context"
"log/slog"
"sync"
"time"
@ -16,6 +17,7 @@ type consumer struct {
mux sync.Mutex
closed bool
quit chan struct{}
ctx context.Context
}
func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) {
@ -54,10 +56,10 @@ func (w *consumer) Send(payload common.ChangePayload) {
}
if len(w.filters) > 0 {
shouldSend := false
shouldSend := true
for _, filter := range w.filters {
if filter(payload) {
shouldSend = true
if !filter(payload) {
shouldSend = false
break
}
}
@ -67,9 +69,14 @@ func (w *consumer) Send(payload common.ChangePayload) {
}
}
slog.Info("Sending payload to consumer", "consumer", w.id)
slog.DebugContext(w.ctx, "sending payload")
select {
case w.messages <- payload:
case <-w.quit:
slog.DebugContext(w.ctx, "consumer is closed")
case <-w.ctx.Done():
slog.DebugContext(w.ctx, "consumer is closed")
case <-time.After(1 * time.Second):
slog.DebugContext(w.ctx, "timeout trying to send payload", "payload", payload)
case w.messages <- payload:
}
}

141
database/watcher/filters.go Normal file
View file

@ -0,0 +1,141 @@
package watcher
import (
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/params"
)
type idGetter interface {
GetID() string
}
// WithAny returns a filter function that returns true if any of the provided filters return true.
// This filter is useful if for example you want to watch for update operations on any of the supplied
// entities.
// Example:
//
// // Watch for any update operation on repositories or organizations
// consumer.SetFilters(
// watcher.WithOperationTypeFilter(common.UpdateOperation),
// watcher.WithAny(
// watcher.WithEntityTypeFilter(common.RepositoryEntityType),
// watcher.WithEntityTypeFilter(common.OrganizationEntityType),
// ))
func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
for _, filter := range filters {
if filter(payload) {
return true
}
}
return false
}
}
// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
// The filter function returns true if the payload's entity type matches the provided entity type.
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
return payload.EntityType == entityType
}
}
// WithOperationTypeFilter returns a filter function that filters payloads by operation type.
func WithOperationTypeFilter(operationType dbCommon.OperationType) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
return payload.Operation == operationType
}
}
// WithEntityPoolFilter returns true if the change payload is a pool that belongs to the
// supplied Github entity. This is useful when an entity worker wants to watch for changes
// in pools that belong to it.
func WithEntityPoolFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
switch payload.EntityType {
case dbCommon.PoolEntityType:
pool, ok := payload.Payload.(params.Pool)
if !ok {
return false
}
switch ghEntity.EntityType {
case params.GithubEntityTypeRepository:
if pool.RepoID != ghEntity.ID {
return false
}
case params.GithubEntityTypeOrganization:
if pool.OrgID != ghEntity.ID {
return false
}
case params.GithubEntityTypeEnterprise:
if pool.EnterpriseID != ghEntity.ID {
return false
}
default:
return false
}
return true
default:
return false
}
}
}
// WithEntityFilter returns a filter function that filters payloads by entity.
// Change payloads that match the entity type and ID will return true.
func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if params.GithubEntityType(payload.EntityType) != entity.EntityType {
return false
}
var ent idGetter
var ok bool
switch payload.EntityType {
case dbCommon.RepositoryEntityType:
ent, ok = payload.Payload.(params.Repository)
case dbCommon.OrganizationEntityType:
ent, ok = payload.Payload.(params.Organization)
case dbCommon.EnterpriseEntityType:
ent, ok = payload.Payload.(params.Enterprise)
default:
return false
}
if !ok {
return false
}
return ent.GetID() == entity.ID
}
}
func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
switch payload.EntityType {
case dbCommon.JobEntityType:
job, ok := payload.Payload.(params.Job)
if !ok {
return false
}
switch ghEntity.EntityType {
case params.GithubEntityTypeRepository:
if job.RepoID != nil && job.RepoID.String() != ghEntity.ID {
return false
}
case params.GithubEntityTypeOrganization:
if job.OrgID != nil && job.OrgID.String() != ghEntity.ID {
return false
}
case params.GithubEntityTypeEnterprise:
if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID {
return false
}
default:
return false
}
return true
default:
return false
}
}
}

View file

@ -1,7 +1,9 @@
package watcher
import (
"context"
"sync"
"time"
"github.com/cloudbase/garm/database/common"
)
@ -13,6 +15,7 @@ type producer struct {
messages chan common.ChangePayload
quit chan struct{}
ctx context.Context
}
func (w *producer) Notify(payload common.ChangePayload) error {
@ -24,9 +27,13 @@ func (w *producer) Notify(payload common.ChangePayload) error {
}
select {
case w.messages <- payload:
default:
case <-w.quit:
return common.ErrProducerClosed
case <-w.ctx.Done():
return common.ErrProducerClosed
case <-time.After(1 * time.Second):
return common.ErrProducerTimeoutErr
case w.messages <- payload:
}
return nil
}

View file

@ -10,3 +10,8 @@ import "github.com/cloudbase/garm/database/common"
func SetWatcher(w common.Watcher) {
databaseWatcher = w
}
// GetWatcher returns the current watcher.
func GetWatcher() common.Watcher {
return databaseWatcher
}

View file

@ -2,9 +2,11 @@ package watcher
import (
"context"
"log/slog"
"sync"
"github.com/cloudbase/garm/database/common"
garmUtil "github.com/cloudbase/garm/util"
)
var databaseWatcher common.Watcher
@ -13,6 +15,7 @@ func InitWatcher(ctx context.Context) {
if databaseWatcher != nil {
return
}
ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database"))
w := &watcher{
producers: make(map[string]*producer),
consumers: make(map[string]*consumer),
@ -24,18 +27,20 @@ func InitWatcher(ctx context.Context) {
databaseWatcher = w
}
func RegisterProducer(id string) (common.Producer, error) {
func RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
if databaseWatcher == nil {
return nil, common.ErrWatcherNotInitialized
}
return databaseWatcher.RegisterProducer(id)
ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id))
return databaseWatcher.RegisterProducer(ctx, id)
}
func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
if databaseWatcher == nil {
return nil, common.ErrWatcherNotInitialized
}
return databaseWatcher.RegisterConsumer(id, filters...)
ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id))
return databaseWatcher.RegisterConsumer(ctx, id, filters...)
}
type watcher struct {
@ -48,7 +53,10 @@ type watcher struct {
ctx context.Context
}
func (w *watcher) RegisterProducer(id string) (common.Producer, error) {
func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
w.mux.Lock()
defer w.mux.Unlock()
if _, ok := w.producers[id]; ok {
return nil, common.ErrProducerAlreadyRegistered
}
@ -56,6 +64,7 @@ func (w *watcher) RegisterProducer(id string) (common.Producer, error) {
id: id,
messages: make(chan common.ChangePayload, 1),
quit: make(chan struct{}),
ctx: ctx,
}
w.producers[id] = p
go w.serviceProducer(p)
@ -67,13 +76,16 @@ func (w *watcher) serviceProducer(prod *producer) {
w.mux.Lock()
defer w.mux.Unlock()
prod.Close()
slog.InfoContext(w.ctx, "removing producer from watcher", "consumer_id", prod.id)
delete(w.producers, prod.id)
}()
for {
select {
case <-w.quit:
slog.InfoContext(w.ctx, "shutting down watcher")
return
case <-w.ctx.Done():
slog.InfoContext(w.ctx, "shutting down watcher")
return
case payload := <-prod.messages:
for _, c := range w.consumers {
@ -83,7 +95,7 @@ func (w *watcher) serviceProducer(prod *producer) {
}
}
func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
if _, ok := w.consumers[id]; ok {
return nil, common.ErrConsumerAlreadyRegistered
}
@ -92,6 +104,7 @@ func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFun
filters: filters,
quit: make(chan struct{}),
id: id,
ctx: ctx,
}
w.consumers[id] = c
go w.serviceConsumer(c)
@ -103,6 +116,7 @@ func (w *watcher) serviceConsumer(consumer *consumer) {
w.mux.Lock()
defer w.mux.Unlock()
consumer.Close()
slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", consumer.id)
delete(w.consumers, consumer.id)
}()
for {
@ -134,6 +148,8 @@ func (w *watcher) Close() {
for _, c := range w.consumers {
c.Close()
}
databaseWatcher = nil
}
func (w *watcher) loop() {

View file

@ -0,0 +1,45 @@
package watcher_test
import (
"context"
"testing"
"github.com/cloudbase/garm/database"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
garmTesting "github.com/cloudbase/garm/internal/testing"
"github.com/stretchr/testify/suite"
)
type WatcherStoreTestSuite struct {
suite.Suite
store common.Store
ctx context.Context
}
func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() {
// ghEpParams := params.CreateGithubEndpointParams{
// Name: "test",
// Description: "test endpoint",
// APIBaseURL: "https://api.ghes.example.com",
// UploadBaseURL: "https://upload.ghes.example.com",
// BaseURL: "https://ghes.example.com",
// }
}
func TestWatcherStoreTestSuite(t *testing.T) {
ctx := context.TODO()
watcher.InitWatcher(ctx)
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t))
if err != nil {
t.Fatalf("failed to create db connection: %s", err)
}
watcherSuite := &WatcherStoreTestSuite{
ctx: context.TODO(),
store: store,
}
suite.Run(t, watcherSuite)
}

View file

@ -0,0 +1,159 @@
//go:build testing
package watcher_test
import (
"context"
"testing"
"time"
"github.com/cloudbase/garm/database"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
garmTesting "github.com/cloudbase/garm/internal/testing"
"github.com/stretchr/testify/suite"
)
type WatcherTestSuite struct {
suite.Suite
store common.Store
ctx context.Context
}
func (s *WatcherTestSuite) SetupTest() {
ctx := context.TODO()
watcher.InitWatcher(ctx)
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T()))
if err != nil {
s.T().Fatalf("failed to create db connection: %s", err)
}
s.store = store
}
func (s *WatcherTestSuite) TearDownTest() {
s.store = nil
currentWatcher := watcher.GetWatcher()
if currentWatcher != nil {
currentWatcher.Close()
}
}
func (s *WatcherTestSuite) TestRegisterConsumer() {
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
s.Require().NoError(err)
s.Require().NotNil(consumer)
consumer, err = watcher.RegisterConsumer(s.ctx, "test")
s.Require().Error(err)
s.Require().Nil(consumer)
}
func (s *WatcherTestSuite) TestRegisterProducer() {
producer, err := watcher.RegisterProducer(s.ctx, "test")
s.Require().NoError(err)
s.Require().NotNil(producer)
producer, err = watcher.RegisterProducer(s.ctx, "test")
s.Require().Error(err)
s.Require().Nil(producer)
}
func (s *WatcherTestSuite) TestInitWatcherRanTwiceDoesNotReplaceWatcher() {
ctx := context.TODO()
currentWatcher := watcher.GetWatcher()
s.Require().NotNil(currentWatcher)
watcher.InitWatcher(ctx)
newWatcher := watcher.GetWatcher()
s.Require().Equal(currentWatcher, newWatcher)
}
func (s *WatcherTestSuite) TestRegisterConsumerFailsIfWatcherIsNotInitialized() {
s.store = nil
currentWatcher := watcher.GetWatcher()
currentWatcher.Close()
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
s.Require().Nil(consumer)
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
}
func (s *WatcherTestSuite) TestRegisterProducerFailsIfWatcherIsNotInitialized() {
s.store = nil
currentWatcher := watcher.GetWatcher()
currentWatcher.Close()
producer, err := watcher.RegisterProducer(s.ctx, "test")
s.Require().Nil(producer)
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
}
func (s *WatcherTestSuite) TestProducerAndConsumer() {
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
s.Require().NoError(err)
s.Require().NotNil(producer)
consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer")
s.Require().NoError(err)
s.Require().NotNil(consumer)
payload := common.ChangePayload{
EntityType: common.ControllerEntityType,
Operation: common.UpdateOperation,
Payload: "test",
}
err = producer.Notify(payload)
s.Require().NoError(err)
receivedPayload := <-consumer.Watch()
s.Require().Equal(payload, receivedPayload)
}
func (s *WatcherTestSuite) TestConsumetWithFilter() {
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
s.Require().NoError(err)
s.Require().NotNil(producer)
consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool {
return payload.Operation == common.UpdateOperation
})
s.Require().NoError(err)
s.Require().NotNil(consumer)
payload := common.ChangePayload{
EntityType: common.ControllerEntityType,
Operation: common.UpdateOperation,
Payload: "test",
}
err = producer.Notify(payload)
s.Require().NoError(err)
select {
case receivedPayload := <-consumer.Watch():
s.Require().Equal(payload, receivedPayload)
case <-time.After(1 * time.Second):
s.T().Fatal("expected payload not received")
}
payload = common.ChangePayload{
EntityType: common.ControllerEntityType,
Operation: common.CreateOperation,
Payload: "test",
}
err = producer.Notify(payload)
s.Require().NoError(err)
select {
case <-consumer.Watch():
s.T().Fatal("unexpected payload received")
case <-time.After(1 * time.Second):
}
}
func TestWatcherTestSuite(t *testing.T) {
watcherSuite := &WatcherTestSuite{
ctx: context.TODO(),
}
suite.Run(t, watcherSuite)
}

View file

@ -3,18 +3,25 @@
package testing
import "github.com/cloudbase/garm/database/common"
import (
"context"
"github.com/cloudbase/garm/database/common"
)
type MockWatcher struct{}
func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) {
func (w *MockWatcher) RegisterProducer(_ context.Context, _ string) (common.Producer, error) {
return &MockProducer{}, nil
}
func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
func (w *MockWatcher) RegisterConsumer(_ context.Context, _ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
return &MockConsumer{}, nil
}
func (w *MockWatcher) Close() {
}
type MockProducer struct{}
func (p *MockProducer) Notify(_ common.ChangePayload) error {