Add rudimentary database watcher
Adds a simple database watcher. At this point it's just one process, but the plan is to allow different implementations that inform the local running workers of changes that have occured on entities of interest in the database. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
parent
214cb05072
commit
8d57fc8fa2
18 changed files with 514 additions and 41 deletions
|
|
@ -12,7 +12,7 @@ GARM supports creating pools on either GitHub itself or on your own deployment o
|
|||
|
||||
Through the use of providers, `GARM` can create runners in a variety of environments using the same `GARM` instance. Whether you want to create pools of runners in your OpenStack cloud, your Azure cloud and your Kubernetes cluster, that is easily achieved by just installing the appropriate providers, configuring them in `GARM` and creating pools that use them. You can create zero-runner pools for instances with high costs (large VMs, GPU enabled instances, etc) and have them spin up on demand, or you can create large pools of k8s backed runners that can be used for your CI/CD pipelines at a moment's notice. You can mix them up and create pools in any combination of providers or resource allocations you want.
|
||||
|
||||
:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4)
|
||||
:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4).
|
||||
|
||||
## Join us on slack
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ import (
|
|||
"github.com/cloudbase/garm/config"
|
||||
"github.com/cloudbase/garm/database"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/metrics"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/runner" //nolint:typecheck
|
||||
|
|
@ -183,6 +184,7 @@ func main() {
|
|||
}
|
||||
ctx, stop := signal.NotifyContext(context.Background(), signals...)
|
||||
defer stop()
|
||||
watcher.InitWatcher(ctx)
|
||||
|
||||
ctx = auth.GetAdminContext(ctx)
|
||||
|
||||
|
|
@ -313,6 +315,7 @@ func main() {
|
|||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer shutdownCancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
|
|
|
|||
12
database/common/errors.go
Normal file
12
database/common/errors.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package common
|
||||
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
ErrProducerClosed = fmt.Errorf("producer is closed")
|
||||
ErrProducerTimeoutErr = fmt.Errorf("producer timeout error")
|
||||
ErrProducerAlreadyRegistered = fmt.Errorf("producer already registered")
|
||||
ErrConsumerAlreadyRegistered = fmt.Errorf("consumer already registered")
|
||||
ErrWatcherAlreadyStarted = fmt.Errorf("watcher already started")
|
||||
ErrWatcherNotInitialized = fmt.Errorf("watcher not initialized")
|
||||
)
|
||||
|
|
@ -119,7 +119,7 @@ type JobsStore interface {
|
|||
DeleteCompletedJobs(ctx context.Context) error
|
||||
}
|
||||
|
||||
type EntityPools interface {
|
||||
type EntityPoolStore interface {
|
||||
CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error)
|
||||
GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error)
|
||||
DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error
|
||||
|
|
@ -144,8 +144,11 @@ type Store interface {
|
|||
UserStore
|
||||
InstanceStore
|
||||
JobsStore
|
||||
EntityPools
|
||||
GithubEndpointStore
|
||||
GithubCredentialsStore
|
||||
ControllerStore
|
||||
EntityPoolStore
|
||||
|
||||
ControllerInfo() (params.ControllerInfo, error)
|
||||
InitController() (params.ControllerInfo, error)
|
||||
}
|
||||
50
database/common/watcher.go
Normal file
50
database/common/watcher.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package common
|
||||
|
||||
type (
|
||||
DatabaseEntityType string
|
||||
OperationType string
|
||||
PayloadFilterFunc func(ChangePayload) bool
|
||||
)
|
||||
|
||||
const (
|
||||
RepositoryEntityType DatabaseEntityType = "repository"
|
||||
OrganizationEntityType DatabaseEntityType = "organization"
|
||||
EnterpriseEntityType DatabaseEntityType = "enterprise"
|
||||
PoolEntityType DatabaseEntityType = "pool"
|
||||
UserEntityType DatabaseEntityType = "user"
|
||||
InstanceEntityType DatabaseEntityType = "instance"
|
||||
JobEntityType DatabaseEntityType = "job"
|
||||
ControllerEntityType DatabaseEntityType = "controller"
|
||||
GithubCredentialsEntityType DatabaseEntityType = "github_credentials"
|
||||
GithubEndpointEntityType DatabaseEntityType = "github_endpoint"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateOperation OperationType = "create"
|
||||
UpdateOperation OperationType = "update"
|
||||
DeleteOperation OperationType = "delete"
|
||||
)
|
||||
|
||||
type ChangePayload struct {
|
||||
EntityType DatabaseEntityType
|
||||
Operation OperationType
|
||||
Payload interface{}
|
||||
}
|
||||
|
||||
type Consumer interface {
|
||||
Watch() <-chan ChangePayload
|
||||
IsClosed() bool
|
||||
Close()
|
||||
SetFilters(filters ...PayloadFilterFunc)
|
||||
}
|
||||
|
||||
type Producer interface {
|
||||
Notify(ChangePayload) error
|
||||
IsClosed() bool
|
||||
Close()
|
||||
}
|
||||
|
||||
type Watcher interface {
|
||||
RegisterProducer(ID string) (Producer, error)
|
||||
RegisterConsumer(ID string, filters ...PayloadFilterFunc) (Consumer, error)
|
||||
}
|
||||
|
|
@ -25,29 +25,9 @@ import (
|
|||
"gorm.io/gorm/clause"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) {
|
||||
enc, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshalling data")
|
||||
}
|
||||
return util.Seal(enc, []byte(s.cfg.Passphrase))
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error {
|
||||
decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decrypting data")
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, target); err != nil {
|
||||
return errors.Wrap(err, "unmarshalling data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) {
|
||||
pool, err := s.getPoolByID(s.conn, poolID)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -66,12 +67,18 @@ func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool
|
|||
return s.sqlToCommonPool(pool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) error {
|
||||
func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err error) {
|
||||
pool, err := s.getPoolByID(s.conn, poolID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching pool by ID")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool)
|
||||
}
|
||||
}()
|
||||
|
||||
if q := s.conn.Unscoped().Delete(&pool); q.Error != nil {
|
||||
return errors.Wrap(q.Error, "removing pool")
|
||||
}
|
||||
|
|
@ -247,11 +254,17 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par
|
|||
return pools, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) {
|
||||
func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (pool params.Pool, err error) {
|
||||
if len(param.Tags) == 0 {
|
||||
return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.PoolEntityType, common.CreateOperation, pool)
|
||||
}
|
||||
}()
|
||||
|
||||
newPool := Pool{
|
||||
ProviderName: param.ProviderName,
|
||||
MaxRunners: param.MaxRunners,
|
||||
|
|
@ -313,12 +326,12 @@ func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEn
|
|||
return params.Pool{}, err
|
||||
}
|
||||
|
||||
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
|
||||
dbPool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
|
||||
if err != nil {
|
||||
return params.Pool{}, errors.Wrap(err, "fetching pool")
|
||||
}
|
||||
|
||||
return s.sqlToCommonPool(pool)
|
||||
return s.sqlToCommonPool(dbPool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) {
|
||||
|
|
@ -329,12 +342,21 @@ func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntit
|
|||
return s.sqlToCommonPool(pool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) error {
|
||||
func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (err error) {
|
||||
entityID, err := uuid.Parse(entity.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
pool := params.Pool{
|
||||
ID: poolID,
|
||||
}
|
||||
s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool)
|
||||
}
|
||||
}()
|
||||
|
||||
poolUUID, err := uuid.Parse(poolID)
|
||||
if err != nil {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "parsing pool id")
|
||||
|
|
@ -374,6 +396,7 @@ func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEn
|
|||
if err != nil {
|
||||
return params.Pool{}, err
|
||||
}
|
||||
s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool)
|
||||
return updatedPool, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,10 +24,17 @@ import (
|
|||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) {
|
||||
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Repository, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.CreateOperation, param)
|
||||
}
|
||||
}()
|
||||
|
||||
if webhookSecret == "" {
|
||||
return params.Repository{}, errors.New("creating repo: missing secret")
|
||||
}
|
||||
|
|
@ -68,7 +75,7 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent
|
|||
return params.Repository{}, errors.Wrap(err, "creating repository")
|
||||
}
|
||||
|
||||
param, err := s.sqlToCommonRepository(newRepo, true)
|
||||
param, err = s.sqlToCommonRepository(newRepo, true)
|
||||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "creating repository")
|
||||
}
|
||||
|
|
@ -113,12 +120,21 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error {
|
||||
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) {
|
||||
repo, err := s.getRepoByID(ctx, s.conn, repoID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching repo")
|
||||
}
|
||||
|
||||
defer func(repo Repository) {
|
||||
if err == nil {
|
||||
asParam, innerErr := s.sqlToCommonRepository(repo, true)
|
||||
if innerErr == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.DeleteOperation, asParam)
|
||||
}
|
||||
}
|
||||
}(repo)
|
||||
|
||||
q := s.conn.Unscoped().Delete(&repo)
|
||||
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(q.Error, "deleting repo")
|
||||
|
|
@ -127,10 +143,15 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) {
|
||||
func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (newParams params.Repository, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.UpdateOperation, newParams)
|
||||
}
|
||||
}()
|
||||
var repo Repository
|
||||
var creds GithubCredentials
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
repo, err = s.getRepoByID(ctx, tx, repoID)
|
||||
if err != nil {
|
||||
|
|
@ -186,7 +207,8 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
|
|||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "updating enterprise")
|
||||
}
|
||||
newParams, err := s.sqlToCommonRepository(repo, true)
|
||||
|
||||
newParams, err = s.sqlToCommonRepository(repo, true)
|
||||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "saving repo")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
|
||||
"github.com/cloudbase/garm/auth"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
|
@ -827,5 +828,11 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() {
|
|||
|
||||
func TestRepoTestSuite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
suite.Run(t, new(RepoTestSuite))
|
||||
}
|
||||
|
||||
func init() {
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/cloudbase/garm/auth"
|
||||
"github.com/cloudbase/garm/config"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/util/appdefaults"
|
||||
)
|
||||
|
|
@ -68,10 +69,15 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "creating DB connection")
|
||||
}
|
||||
producer, err := watcher.RegisterProducer("sql")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "registering producer")
|
||||
}
|
||||
db := &sqlDatabase{
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
producer: producer,
|
||||
}
|
||||
|
||||
if err := db.migrateDB(); err != nil {
|
||||
|
|
@ -81,9 +87,10 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
|
|||
}
|
||||
|
||||
type sqlDatabase struct {
|
||||
conn *gorm.DB
|
||||
ctx context.Context
|
||||
cfg config.Database
|
||||
conn *gorm.DB
|
||||
ctx context.Context
|
||||
cfg config.Database
|
||||
producer common.Producer
|
||||
}
|
||||
|
||||
var renameTemplate = `
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
commonParams "github.com/cloudbase/garm-provider-common/params"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -467,3 +468,31 @@ func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntit
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) {
|
||||
enc, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshalling data")
|
||||
}
|
||||
return util.Seal(enc, []byte(s.cfg.Passphrase))
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error {
|
||||
decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decrypting data")
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, target); err != nil {
|
||||
return errors.Wrap(err, "unmarshalling data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) sendNotify(entityType dbCommon.DatabaseEntityType, op dbCommon.OperationType, payload interface{}) {
|
||||
message := dbCommon.ChangePayload{
|
||||
Operation: op,
|
||||
Payload: payload,
|
||||
EntityType: entityType,
|
||||
}
|
||||
s.producer.Notify(message)
|
||||
}
|
||||
|
|
|
|||
75
database/watcher/consumer.go
Normal file
75
database/watcher/consumer.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
type consumer struct {
|
||||
messages chan common.ChangePayload
|
||||
filters []common.PayloadFilterFunc
|
||||
id string
|
||||
|
||||
mux sync.Mutex
|
||||
closed bool
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
w.filters = filters
|
||||
}
|
||||
|
||||
func (w *consumer) Watch() <-chan common.ChangePayload {
|
||||
return w.messages
|
||||
}
|
||||
|
||||
func (w *consumer) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
close(w.messages)
|
||||
close(w.quit)
|
||||
w.closed = true
|
||||
}
|
||||
|
||||
func (w *consumer) IsClosed() bool {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
return w.closed
|
||||
}
|
||||
|
||||
func (w *consumer) Send(payload common.ChangePayload) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if len(w.filters) > 0 {
|
||||
shouldSend := false
|
||||
for _, filter := range w.filters {
|
||||
if filter(payload) {
|
||||
shouldSend = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !shouldSend {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("Sending payload to consumer", "consumer", w.id)
|
||||
select {
|
||||
case w.messages <- payload:
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
49
database/watcher/producer.go
Normal file
49
database/watcher/producer.go
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
type producer struct {
|
||||
closed bool
|
||||
mux sync.Mutex
|
||||
id string
|
||||
|
||||
messages chan common.ChangePayload
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
func (w *producer) Notify(payload common.ChangePayload) error {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return common.ErrProducerClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case w.messages <- payload:
|
||||
default:
|
||||
return common.ErrProducerTimeoutErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *producer) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
w.closed = true
|
||||
close(w.messages)
|
||||
close(w.quit)
|
||||
}
|
||||
|
||||
func (w *producer) IsClosed() bool {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
return w.closed
|
||||
}
|
||||
12
database/watcher/test_export.go
Normal file
12
database/watcher/test_export.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package watcher
|
||||
|
||||
import "github.com/cloudbase/garm/database/common"
|
||||
|
||||
// SetWatcher sets the watcher to be used by the database package.
|
||||
// This function is intended for use in tests only.
|
||||
func SetWatcher(w common.Watcher) {
|
||||
databaseWatcher = w
|
||||
}
|
||||
151
database/watcher/watcher.go
Normal file
151
database/watcher/watcher.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
var databaseWatcher common.Watcher
|
||||
|
||||
func InitWatcher(ctx context.Context) {
|
||||
if databaseWatcher != nil {
|
||||
return
|
||||
}
|
||||
w := &watcher{
|
||||
producers: make(map[string]*producer),
|
||||
consumers: make(map[string]*consumer),
|
||||
quit: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
go w.loop()
|
||||
databaseWatcher = w
|
||||
}
|
||||
|
||||
func RegisterProducer(id string) (common.Producer, error) {
|
||||
if databaseWatcher == nil {
|
||||
return nil, common.ErrWatcherNotInitialized
|
||||
}
|
||||
return databaseWatcher.RegisterProducer(id)
|
||||
}
|
||||
|
||||
func RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
if databaseWatcher == nil {
|
||||
return nil, common.ErrWatcherNotInitialized
|
||||
}
|
||||
return databaseWatcher.RegisterConsumer(id, filters...)
|
||||
}
|
||||
|
||||
type watcher struct {
|
||||
producers map[string]*producer
|
||||
consumers map[string]*consumer
|
||||
|
||||
mux sync.Mutex
|
||||
closed bool
|
||||
quit chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *watcher) RegisterProducer(id string) (common.Producer, error) {
|
||||
if _, ok := w.producers[id]; ok {
|
||||
return nil, common.ErrProducerAlreadyRegistered
|
||||
}
|
||||
p := &producer{
|
||||
id: id,
|
||||
messages: make(chan common.ChangePayload, 1),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
w.producers[id] = p
|
||||
go w.serviceProducer(p)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (w *watcher) serviceProducer(prod *producer) {
|
||||
defer func() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
prod.Close()
|
||||
delete(w.producers, prod.id)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-w.quit:
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case payload := <-prod.messages:
|
||||
for _, c := range w.consumers {
|
||||
go c.Send(payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) RegisterConsumer(id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
if _, ok := w.consumers[id]; ok {
|
||||
return nil, common.ErrConsumerAlreadyRegistered
|
||||
}
|
||||
c := &consumer{
|
||||
messages: make(chan common.ChangePayload, 1),
|
||||
filters: filters,
|
||||
quit: make(chan struct{}),
|
||||
id: id,
|
||||
}
|
||||
w.consumers[id] = c
|
||||
go w.serviceConsumer(c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (w *watcher) serviceConsumer(consumer *consumer) {
|
||||
defer func() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
consumer.Close()
|
||||
delete(w.consumers, consumer.id)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-consumer.quit:
|
||||
return
|
||||
case <-w.quit:
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
|
||||
close(w.quit)
|
||||
w.closed = true
|
||||
|
||||
for _, p := range w.producers {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
for _, c := range w.consumers {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) loop() {
|
||||
defer func() {
|
||||
w.Close()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-w.quit:
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
45
internal/testing/mock_watcher.go
Normal file
45
internal/testing/mock_watcher.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package testing
|
||||
|
||||
import "github.com/cloudbase/garm/database/common"
|
||||
|
||||
type MockWatcher struct{}
|
||||
|
||||
func (w *MockWatcher) RegisterProducer(_ string) (common.Producer, error) {
|
||||
return &MockProducer{}, nil
|
||||
}
|
||||
|
||||
func (w *MockWatcher) RegisterConsumer(_ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
return &MockConsumer{}, nil
|
||||
}
|
||||
|
||||
type MockProducer struct{}
|
||||
|
||||
func (p *MockProducer) Notify(_ common.ChangePayload) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *MockProducer) IsClosed() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *MockProducer) Close() {
|
||||
}
|
||||
|
||||
type MockConsumer struct{}
|
||||
|
||||
func (c *MockConsumer) Watch() <-chan common.ChangePayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockConsumer) SetFilters(_ ...common.PayloadFilterFunc) {
|
||||
}
|
||||
|
||||
func (c *MockConsumer) Close() {
|
||||
}
|
||||
|
||||
func (c *MockConsumer) IsClosed() bool {
|
||||
return false
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ type urls struct {
|
|||
}
|
||||
|
||||
func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", params.GithubEntityTypeRepository))
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType))
|
||||
ghc, err := garmUtil.GithubClient(ctx, entity, cfgInternal.GithubCredentialsDetails)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting github client")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/database"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/runner/common"
|
||||
|
|
@ -51,6 +52,10 @@ type RepoTestFixtures struct {
|
|||
PoolMgrCtrlMock *runnerMocks.PoolManagerController
|
||||
}
|
||||
|
||||
func init() {
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
}
|
||||
|
||||
type RepoTestSuite struct {
|
||||
suite.Suite
|
||||
Fixtures *RepoTestFixtures
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue