Add some tests
Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
parent
6051629810
commit
b51683f1ae
10 changed files with 409 additions and 19 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
type (
|
type (
|
||||||
DatabaseEntityType string
|
DatabaseEntityType string
|
||||||
OperationType string
|
OperationType string
|
||||||
|
|
@ -45,6 +47,7 @@ type Producer interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Watcher interface {
|
type Watcher interface {
|
||||||
RegisterProducer(ID string) (Producer, error)
|
RegisterProducer(ctx context.Context, ID string) (Producer, error)
|
||||||
RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error)
|
RegisterConsumer(ctx context.Context, ID string, filters ...PayloadFilterFunc) (Consumer, error)
|
||||||
|
Close()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "creating DB connection")
|
return nil, errors.Wrap(err, "creating DB connection")
|
||||||
}
|
}
|
||||||
producer, err := watcher.RegisterProducer("sql")
|
producer, err := watcher.RegisterProducer(ctx, "sql")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "registering producer")
|
return nil, errors.Wrap(err, "registering producer")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package watcher
|
package watcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -16,6 +17,7 @@ type consumer struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) {
|
func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) {
|
||||||
|
|
@ -54,10 +56,10 @@ func (w *consumer) Send(payload common.ChangePayload) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(w.filters) > 0 {
|
if len(w.filters) > 0 {
|
||||||
shouldSend := false
|
shouldSend := true
|
||||||
for _, filter := range w.filters {
|
for _, filter := range w.filters {
|
||||||
if filter(payload) {
|
if !filter(payload) {
|
||||||
shouldSend = true
|
shouldSend = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -67,9 +69,14 @@ func (w *consumer) Send(payload common.ChangePayload) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Sending payload to consumer", "consumer", w.id)
|
slog.DebugContext(w.ctx, "sending payload")
|
||||||
select {
|
select {
|
||||||
case w.messages <- payload:
|
case <-w.quit:
|
||||||
|
slog.DebugContext(w.ctx, "consumer is closed")
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
slog.DebugContext(w.ctx, "consumer is closed")
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
|
slog.DebugContext(w.ctx, "timeout trying to send payload", "payload", payload)
|
||||||
|
case w.messages <- payload:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
141
database/watcher/filters.go
Normal file
141
database/watcher/filters.go
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
dbCommon "github.com/cloudbase/garm/database/common"
|
||||||
|
"github.com/cloudbase/garm/params"
|
||||||
|
)
|
||||||
|
|
||||||
|
type idGetter interface {
|
||||||
|
GetID() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAny returns a filter function that returns true if any of the provided filters return true.
|
||||||
|
// This filter is useful if for example you want to watch for update operations on any of the supplied
|
||||||
|
// entities.
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// // Watch for any update operation on repositories or organizations
|
||||||
|
// consumer.SetFilters(
|
||||||
|
// watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||||
|
// watcher.WithAny(
|
||||||
|
// watcher.WithEntityTypeFilter(common.RepositoryEntityType),
|
||||||
|
// watcher.WithEntityTypeFilter(common.OrganizationEntityType),
|
||||||
|
// ))
|
||||||
|
func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
for _, filter := range filters {
|
||||||
|
if filter(payload) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
|
||||||
|
// The filter function returns true if the payload's entity type matches the provided entity type.
|
||||||
|
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
return payload.EntityType == entityType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOperationTypeFilter returns a filter function that filters payloads by operation type.
|
||||||
|
func WithOperationTypeFilter(operationType dbCommon.OperationType) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
return payload.Operation == operationType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEntityPoolFilter returns true if the change payload is a pool that belongs to the
|
||||||
|
// supplied Github entity. This is useful when an entity worker wants to watch for changes
|
||||||
|
// in pools that belong to it.
|
||||||
|
func WithEntityPoolFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
switch payload.EntityType {
|
||||||
|
case dbCommon.PoolEntityType:
|
||||||
|
pool, ok := payload.Payload.(params.Pool)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch ghEntity.EntityType {
|
||||||
|
case params.GithubEntityTypeRepository:
|
||||||
|
if pool.RepoID != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case params.GithubEntityTypeOrganization:
|
||||||
|
if pool.OrgID != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case params.GithubEntityTypeEnterprise:
|
||||||
|
if pool.EnterpriseID != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEntityFilter returns a filter function that filters payloads by entity.
|
||||||
|
// Change payloads that match the entity type and ID will return true.
|
||||||
|
func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
if params.GithubEntityType(payload.EntityType) != entity.EntityType {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var ent idGetter
|
||||||
|
var ok bool
|
||||||
|
switch payload.EntityType {
|
||||||
|
case dbCommon.RepositoryEntityType:
|
||||||
|
ent, ok = payload.Payload.(params.Repository)
|
||||||
|
case dbCommon.OrganizationEntityType:
|
||||||
|
ent, ok = payload.Payload.(params.Organization)
|
||||||
|
case dbCommon.EnterpriseEntityType:
|
||||||
|
ent, ok = payload.Payload.(params.Enterprise)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ent.GetID() == entity.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||||
|
return func(payload dbCommon.ChangePayload) bool {
|
||||||
|
switch payload.EntityType {
|
||||||
|
case dbCommon.JobEntityType:
|
||||||
|
job, ok := payload.Payload.(params.Job)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ghEntity.EntityType {
|
||||||
|
case params.GithubEntityTypeRepository:
|
||||||
|
if job.RepoID != nil && job.RepoID.String() != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case params.GithubEntityTypeOrganization:
|
||||||
|
if job.OrgID != nil && job.OrgID.String() != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case params.GithubEntityTypeEnterprise:
|
||||||
|
if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
package watcher
|
package watcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudbase/garm/database/common"
|
"github.com/cloudbase/garm/database/common"
|
||||||
)
|
)
|
||||||
|
|
@ -13,6 +15,7 @@ type producer struct {
|
||||||
|
|
||||||
messages chan common.ChangePayload
|
messages chan common.ChangePayload
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *producer) Notify(payload common.ChangePayload) error {
|
func (w *producer) Notify(payload common.ChangePayload) error {
|
||||||
|
|
@ -24,9 +27,13 @@ func (w *producer) Notify(payload common.ChangePayload) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case w.messages <- payload:
|
case <-w.quit:
|
||||||
default:
|
return common.ErrProducerClosed
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
return common.ErrProducerClosed
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
return common.ErrProducerTimeoutErr
|
return common.ErrProducerTimeoutErr
|
||||||
|
case w.messages <- payload:
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,8 @@ import "github.com/cloudbase/garm/database/common"
|
||||||
func SetWatcher(w common.Watcher) {
|
func SetWatcher(w common.Watcher) {
|
||||||
databaseWatcher = w
|
databaseWatcher = w
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWatcher returns the current watcher.
|
||||||
|
func GetWatcher() common.Watcher {
|
||||||
|
return databaseWatcher
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ package watcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cloudbase/garm/database/common"
|
"github.com/cloudbase/garm/database/common"
|
||||||
|
garmUtil "github.com/cloudbase/garm/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var databaseWatcher common.Watcher
|
var databaseWatcher common.Watcher
|
||||||
|
|
@ -13,6 +15,7 @@ func InitWatcher(ctx context.Context) {
|
||||||
if databaseWatcher != nil {
|
if databaseWatcher != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database"))
|
||||||
w := &watcher{
|
w := &watcher{
|
||||||
producers: make(map[string]*producer),
|
producers: make(map[string]*producer),
|
||||||
consumers: make(map[string]*consumer),
|
consumers: make(map[string]*consumer),
|
||||||
|
|
@ -24,18 +27,20 @@ func InitWatcher(ctx context.Context) {
|
||||||
databaseWatcher = w
|
databaseWatcher = w
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterProducer(id string) (common.Producer, error) {
|
func RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
|
||||||
if databaseWatcher == nil {
|
if databaseWatcher == nil {
|
||||||
return nil, common.ErrWatcherNotInitialized
|
return nil, common.ErrWatcherNotInitialized
|
||||||
}
|
}
|
||||||
return databaseWatcher.RegisterProducer(id)
|
ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id))
|
||||||
|
return databaseWatcher.RegisterProducer(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||||
if databaseWatcher == nil {
|
if databaseWatcher == nil {
|
||||||
return nil, common.ErrWatcherNotInitialized
|
return nil, common.ErrWatcherNotInitialized
|
||||||
}
|
}
|
||||||
return databaseWatcher.RegisterConsumer(id, filters...)
|
ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id))
|
||||||
|
return databaseWatcher.RegisterConsumer(ctx, id, filters...)
|
||||||
}
|
}
|
||||||
|
|
||||||
type watcher struct {
|
type watcher struct {
|
||||||
|
|
@ -48,7 +53,10 @@ type watcher struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *watcher) RegisterProducer(id string) (common.Producer, error) {
|
func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
|
||||||
|
w.mux.Lock()
|
||||||
|
defer w.mux.Unlock()
|
||||||
|
|
||||||
if _, ok := w.producers[id]; ok {
|
if _, ok := w.producers[id]; ok {
|
||||||
return nil, common.ErrProducerAlreadyRegistered
|
return nil, common.ErrProducerAlreadyRegistered
|
||||||
}
|
}
|
||||||
|
|
@ -56,6 +64,7 @@ func (w *watcher) RegisterProducer(id string) (common.Producer, error) {
|
||||||
id: id,
|
id: id,
|
||||||
messages: make(chan common.ChangePayload, 1),
|
messages: make(chan common.ChangePayload, 1),
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
w.producers[id] = p
|
w.producers[id] = p
|
||||||
go w.serviceProducer(p)
|
go w.serviceProducer(p)
|
||||||
|
|
@ -67,13 +76,16 @@ func (w *watcher) serviceProducer(prod *producer) {
|
||||||
w.mux.Lock()
|
w.mux.Lock()
|
||||||
defer w.mux.Unlock()
|
defer w.mux.Unlock()
|
||||||
prod.Close()
|
prod.Close()
|
||||||
|
slog.InfoContext(w.ctx, "removing producer from watcher", "consumer_id", prod.id)
|
||||||
delete(w.producers, prod.id)
|
delete(w.producers, prod.id)
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-w.quit:
|
case <-w.quit:
|
||||||
|
slog.InfoContext(w.ctx, "shutting down watcher")
|
||||||
return
|
return
|
||||||
case <-w.ctx.Done():
|
case <-w.ctx.Done():
|
||||||
|
slog.InfoContext(w.ctx, "shutting down watcher")
|
||||||
return
|
return
|
||||||
case payload := <-prod.messages:
|
case payload := <-prod.messages:
|
||||||
for _, c := range w.consumers {
|
for _, c := range w.consumers {
|
||||||
|
|
@ -83,7 +95,7 @@ func (w *watcher) serviceProducer(prod *producer) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||||
if _, ok := w.consumers[id]; ok {
|
if _, ok := w.consumers[id]; ok {
|
||||||
return nil, common.ErrConsumerAlreadyRegistered
|
return nil, common.ErrConsumerAlreadyRegistered
|
||||||
}
|
}
|
||||||
|
|
@ -92,6 +104,7 @@ func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFun
|
||||||
filters: filters,
|
filters: filters,
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
id: id,
|
id: id,
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
w.consumers[id] = c
|
w.consumers[id] = c
|
||||||
go w.serviceConsumer(c)
|
go w.serviceConsumer(c)
|
||||||
|
|
@ -103,6 +116,7 @@ func (w *watcher) serviceConsumer(consumer *consumer) {
|
||||||
w.mux.Lock()
|
w.mux.Lock()
|
||||||
defer w.mux.Unlock()
|
defer w.mux.Unlock()
|
||||||
consumer.Close()
|
consumer.Close()
|
||||||
|
slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", consumer.id)
|
||||||
delete(w.consumers, consumer.id)
|
delete(w.consumers, consumer.id)
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
|
|
@ -134,6 +148,8 @@ func (w *watcher) Close() {
|
||||||
for _, c := range w.consumers {
|
for _, c := range w.consumers {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
databaseWatcher = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *watcher) loop() {
|
func (w *watcher) loop() {
|
||||||
|
|
|
||||||
45
database/watcher/watcher_store_test.go
Normal file
45
database/watcher/watcher_store_test.go
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
package watcher_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudbase/garm/database"
|
||||||
|
"github.com/cloudbase/garm/database/common"
|
||||||
|
"github.com/cloudbase/garm/database/watcher"
|
||||||
|
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WatcherStoreTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
|
||||||
|
store common.Store
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() {
|
||||||
|
// ghEpParams := params.CreateGithubEndpointParams{
|
||||||
|
// Name: "test",
|
||||||
|
// Description: "test endpoint",
|
||||||
|
// APIBaseURL: "https://api.ghes.example.com",
|
||||||
|
// UploadBaseURL: "https://upload.ghes.example.com",
|
||||||
|
// BaseURL: "https://ghes.example.com",
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherStoreTestSuite(t *testing.T) {
|
||||||
|
ctx := context.TODO()
|
||||||
|
watcher.InitWatcher(ctx)
|
||||||
|
|
||||||
|
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create db connection: %s", err)
|
||||||
|
}
|
||||||
|
watcherSuite := &WatcherStoreTestSuite{
|
||||||
|
ctx: context.TODO(),
|
||||||
|
store: store,
|
||||||
|
}
|
||||||
|
suite.Run(t, watcherSuite)
|
||||||
|
}
|
||||||
159
database/watcher/watcher_test.go
Normal file
159
database/watcher/watcher_test.go
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
//go:build testing
|
||||||
|
|
||||||
|
package watcher_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudbase/garm/database"
|
||||||
|
"github.com/cloudbase/garm/database/common"
|
||||||
|
"github.com/cloudbase/garm/database/watcher"
|
||||||
|
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WatcherTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
store common.Store
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) SetupTest() {
|
||||||
|
ctx := context.TODO()
|
||||||
|
watcher.InitWatcher(ctx)
|
||||||
|
|
||||||
|
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T()))
|
||||||
|
if err != nil {
|
||||||
|
s.T().Fatalf("failed to create db connection: %s", err)
|
||||||
|
}
|
||||||
|
s.store = store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TearDownTest() {
|
||||||
|
s.store = nil
|
||||||
|
currentWatcher := watcher.GetWatcher()
|
||||||
|
if currentWatcher != nil {
|
||||||
|
currentWatcher.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestRegisterConsumer() {
|
||||||
|
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(consumer)
|
||||||
|
|
||||||
|
consumer, err = watcher.RegisterConsumer(s.ctx, "test")
|
||||||
|
s.Require().Error(err)
|
||||||
|
s.Require().Nil(consumer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestRegisterProducer() {
|
||||||
|
producer, err := watcher.RegisterProducer(s.ctx, "test")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(producer)
|
||||||
|
|
||||||
|
producer, err = watcher.RegisterProducer(s.ctx, "test")
|
||||||
|
s.Require().Error(err)
|
||||||
|
s.Require().Nil(producer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestInitWatcherRanTwiceDoesNotReplaceWatcher() {
|
||||||
|
ctx := context.TODO()
|
||||||
|
currentWatcher := watcher.GetWatcher()
|
||||||
|
s.Require().NotNil(currentWatcher)
|
||||||
|
watcher.InitWatcher(ctx)
|
||||||
|
newWatcher := watcher.GetWatcher()
|
||||||
|
s.Require().Equal(currentWatcher, newWatcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestRegisterConsumerFailsIfWatcherIsNotInitialized() {
|
||||||
|
s.store = nil
|
||||||
|
currentWatcher := watcher.GetWatcher()
|
||||||
|
currentWatcher.Close()
|
||||||
|
|
||||||
|
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
|
||||||
|
s.Require().Nil(consumer)
|
||||||
|
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestRegisterProducerFailsIfWatcherIsNotInitialized() {
|
||||||
|
s.store = nil
|
||||||
|
currentWatcher := watcher.GetWatcher()
|
||||||
|
currentWatcher.Close()
|
||||||
|
|
||||||
|
producer, err := watcher.RegisterProducer(s.ctx, "test")
|
||||||
|
s.Require().Nil(producer)
|
||||||
|
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestProducerAndConsumer() {
|
||||||
|
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(producer)
|
||||||
|
|
||||||
|
consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(consumer)
|
||||||
|
|
||||||
|
payload := common.ChangePayload{
|
||||||
|
EntityType: common.ControllerEntityType,
|
||||||
|
Operation: common.UpdateOperation,
|
||||||
|
Payload: "test",
|
||||||
|
}
|
||||||
|
err = producer.Notify(payload)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
receivedPayload := <-consumer.Watch()
|
||||||
|
s.Require().Equal(payload, receivedPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatcherTestSuite) TestConsumetWithFilter() {
|
||||||
|
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(producer)
|
||||||
|
|
||||||
|
consumer, err := watcher.RegisterConsumer(s.ctx, "test-consumer", func(payload common.ChangePayload) bool {
|
||||||
|
return payload.Operation == common.UpdateOperation
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(consumer)
|
||||||
|
|
||||||
|
payload := common.ChangePayload{
|
||||||
|
EntityType: common.ControllerEntityType,
|
||||||
|
Operation: common.UpdateOperation,
|
||||||
|
Payload: "test",
|
||||||
|
}
|
||||||
|
err = producer.Notify(payload)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case receivedPayload := <-consumer.Watch():
|
||||||
|
s.Require().Equal(payload, receivedPayload)
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
s.T().Fatal("expected payload not received")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = common.ChangePayload{
|
||||||
|
EntityType: common.ControllerEntityType,
|
||||||
|
Operation: common.CreateOperation,
|
||||||
|
Payload: "test",
|
||||||
|
}
|
||||||
|
err = producer.Notify(payload)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-consumer.Watch():
|
||||||
|
s.T().Fatal("unexpected payload received")
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherTestSuite(t *testing.T) {
|
||||||
|
watcherSuite := &WatcherTestSuite{
|
||||||
|
ctx: context.TODO(),
|
||||||
|
}
|
||||||
|
suite.Run(t, watcherSuite)
|
||||||
|
}
|
||||||
|
|
@ -3,18 +3,25 @@
|
||||||
|
|
||||||
package testing
|
package testing
|
||||||
|
|
||||||
import "github.com/cloudbase/garm/database/common"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudbase/garm/database/common"
|
||||||
|
)
|
||||||
|
|
||||||
type MockWatcher struct{}
|
type MockWatcher struct{}
|
||||||
|
|
||||||
func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) {
|
func (w *MockWatcher) RegisterProducer(_ context.Context, _ string) (common.Producer, error) {
|
||||||
return &MockProducer{}, nil
|
return &MockProducer{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
|
func (w *MockWatcher) RegisterConsumer(_ context.Context, _ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||||
return &MockConsumer{}, nil
|
return &MockConsumer{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *MockWatcher) Close() {
|
||||||
|
}
|
||||||
|
|
||||||
type MockProducer struct{}
|
type MockProducer struct{}
|
||||||
|
|
||||||
func (p *MockProducer) Notify(_ common.ChangePayload) error {
|
func (p *MockProducer) Notify(_ common.ChangePayload) error {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue