Merge pull request #263 from gabriel-samfira/add-database-watcher
Add database watcher
This commit is contained in:
commit
c188a6f2c1
27 changed files with 1906 additions and 116 deletions
|
|
@ -12,7 +12,7 @@ GARM supports creating pools on either GitHub itself or on your own deployment o
|
|||
|
||||
Through the use of providers, `GARM` can create runners in a variety of environments using the same `GARM` instance. Whether you want to create pools of runners in your OpenStack cloud, your Azure cloud and your Kubernetes cluster, that is easily achieved by just installing the appropriate providers, configuring them in `GARM` and creating pools that use them. You can create zero-runner pools for instances with high costs (large VMs, GPU enabled instances, etc) and have them spin up on demand, or you can create large pools of k8s backed runners that can be used for your CI/CD pipelines at a moment's notice. You can mix them up and create pools in any combination of providers or resource allocations you want.
|
||||
|
||||
:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4)
|
||||
:warning: **Important note**: The README and documentation in the `main` branch are relevant to the not yet released code that is present in `main`. Following the documentation from the `main` branch for a stable release of GARM, may lead to errors. To view the documentation for the latest stable release, please switch to the appropriate tag. For information about setting up `v0.1.4`, please refer to the [v0.1.4 tag](https://github.com/cloudbase/garm/tree/v0.1.4).
|
||||
|
||||
## Join us on slack
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ import (
|
|||
"github.com/cloudbase/garm/config"
|
||||
"github.com/cloudbase/garm/database"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/metrics"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/runner" //nolint:typecheck
|
||||
|
|
@ -183,6 +184,7 @@ func main() {
|
|||
}
|
||||
ctx, stop := signal.NotifyContext(context.Background(), signals...)
|
||||
defer stop()
|
||||
watcher.InitWatcher(ctx)
|
||||
|
||||
ctx = auth.GetAdminContext(ctx)
|
||||
|
||||
|
|
@ -313,6 +315,7 @@ func main() {
|
|||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer shutdownCancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
|
|
|
|||
12
database/common/errors.go
Normal file
12
database/common/errors.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package common
|
||||
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
ErrProducerClosed = fmt.Errorf("producer is closed")
|
||||
ErrProducerTimeoutErr = fmt.Errorf("producer timeout error")
|
||||
ErrProducerAlreadyRegistered = fmt.Errorf("producer already registered")
|
||||
ErrConsumerAlreadyRegistered = fmt.Errorf("consumer already registered")
|
||||
ErrWatcherAlreadyStarted = fmt.Errorf("watcher already started")
|
||||
ErrWatcherNotInitialized = fmt.Errorf("watcher not initialized")
|
||||
)
|
||||
|
|
@ -119,7 +119,7 @@ type JobsStore interface {
|
|||
DeleteCompletedJobs(ctx context.Context) error
|
||||
}
|
||||
|
||||
type EntityPools interface {
|
||||
type EntityPoolStore interface {
|
||||
CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error)
|
||||
GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error)
|
||||
DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error
|
||||
|
|
@ -144,8 +144,11 @@ type Store interface {
|
|||
UserStore
|
||||
InstanceStore
|
||||
JobsStore
|
||||
EntityPools
|
||||
GithubEndpointStore
|
||||
GithubCredentialsStore
|
||||
ControllerStore
|
||||
EntityPoolStore
|
||||
|
||||
ControllerInfo() (params.ControllerInfo, error)
|
||||
InitController() (params.ControllerInfo, error)
|
||||
}
|
||||
53
database/common/watcher.go
Normal file
53
database/common/watcher.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package common
|
||||
|
||||
import "context"
|
||||
|
||||
type (
|
||||
DatabaseEntityType string
|
||||
OperationType string
|
||||
PayloadFilterFunc func(ChangePayload) bool
|
||||
)
|
||||
|
||||
const (
|
||||
RepositoryEntityType DatabaseEntityType = "repository"
|
||||
OrganizationEntityType DatabaseEntityType = "organization"
|
||||
EnterpriseEntityType DatabaseEntityType = "enterprise"
|
||||
PoolEntityType DatabaseEntityType = "pool"
|
||||
UserEntityType DatabaseEntityType = "user"
|
||||
InstanceEntityType DatabaseEntityType = "instance"
|
||||
JobEntityType DatabaseEntityType = "job"
|
||||
ControllerEntityType DatabaseEntityType = "controller"
|
||||
GithubCredentialsEntityType DatabaseEntityType = "github_credentials" // #nosec G101
|
||||
GithubEndpointEntityType DatabaseEntityType = "github_endpoint"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateOperation OperationType = "create"
|
||||
UpdateOperation OperationType = "update"
|
||||
DeleteOperation OperationType = "delete"
|
||||
)
|
||||
|
||||
type ChangePayload struct {
|
||||
EntityType DatabaseEntityType
|
||||
Operation OperationType
|
||||
Payload interface{}
|
||||
}
|
||||
|
||||
type Consumer interface {
|
||||
Watch() <-chan ChangePayload
|
||||
IsClosed() bool
|
||||
Close()
|
||||
SetFilters(filters ...PayloadFilterFunc)
|
||||
}
|
||||
|
||||
type Producer interface {
|
||||
Notify(ChangePayload) error
|
||||
IsClosed() bool
|
||||
Close()
|
||||
}
|
||||
|
||||
type Watcher interface {
|
||||
RegisterProducer(ctx context.Context, ID string) (Producer, error)
|
||||
RegisterConsumer(ctx context.Context, ID string, filters ...PayloadFilterFunc) (Consumer, error)
|
||||
Close()
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -82,38 +83,49 @@ func (s *sqlDatabase) InitController() (params.ControllerInfo, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (params.ControllerInfo, error) {
|
||||
var dbInfo ControllerInfo
|
||||
q := s.conn.Model(&ControllerInfo{}).First(&dbInfo)
|
||||
if q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return params.ControllerInfo{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info")
|
||||
func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (paramInfo params.ControllerInfo, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.ControllerEntityType, common.UpdateOperation, paramInfo)
|
||||
}
|
||||
return params.ControllerInfo{}, errors.Wrap(q.Error, "fetching controller info")
|
||||
}()
|
||||
var dbInfo ControllerInfo
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
q := tx.Model(&ControllerInfo{}).First(&dbInfo)
|
||||
if q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info")
|
||||
}
|
||||
return errors.Wrap(q.Error, "fetching controller info")
|
||||
}
|
||||
|
||||
if err := info.Validate(); err != nil {
|
||||
return errors.Wrap(err, "validating controller info")
|
||||
}
|
||||
|
||||
if info.MetadataURL != nil {
|
||||
dbInfo.MetadataURL = *info.MetadataURL
|
||||
}
|
||||
|
||||
if info.CallbackURL != nil {
|
||||
dbInfo.CallbackURL = *info.CallbackURL
|
||||
}
|
||||
|
||||
if info.WebhookURL != nil {
|
||||
dbInfo.WebhookBaseURL = *info.WebhookURL
|
||||
}
|
||||
|
||||
q = tx.Save(&dbInfo)
|
||||
if q.Error != nil {
|
||||
return errors.Wrap(q.Error, "saving controller info")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return params.ControllerInfo{}, errors.Wrap(err, "updating controller info")
|
||||
}
|
||||
|
||||
if err := info.Validate(); err != nil {
|
||||
return params.ControllerInfo{}, errors.Wrap(err, "validating controller info")
|
||||
}
|
||||
|
||||
if info.MetadataURL != nil {
|
||||
dbInfo.MetadataURL = *info.MetadataURL
|
||||
}
|
||||
|
||||
if info.CallbackURL != nil {
|
||||
dbInfo.CallbackURL = *info.CallbackURL
|
||||
}
|
||||
|
||||
if info.WebhookURL != nil {
|
||||
dbInfo.WebhookBaseURL = *info.WebhookURL
|
||||
}
|
||||
|
||||
q = s.conn.Save(&dbInfo)
|
||||
if q.Error != nil {
|
||||
return params.ControllerInfo{}, errors.Wrap(q.Error, "saving controller info")
|
||||
}
|
||||
|
||||
paramInfo, err := dbControllerToCommonController(dbInfo)
|
||||
paramInfo, err = dbControllerToCommonController(dbInfo)
|
||||
if err != nil {
|
||||
return params.ControllerInfo{}, errors.Wrap(err, "converting controller info")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package sql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -23,10 +24,11 @@ import (
|
|||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Enterprise, error) {
|
||||
func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (paramEnt params.Enterprise, err error) {
|
||||
if webhookSecret == "" {
|
||||
return params.Enterprise{}, errors.New("creating enterprise: missing secret")
|
||||
}
|
||||
|
|
@ -34,6 +36,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam
|
|||
if err != nil {
|
||||
return params.Enterprise{}, errors.Wrap(err, "encoding secret")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.EnterpriseEntityType, common.CreateOperation, paramEnt)
|
||||
}
|
||||
}()
|
||||
newEnterprise := Enterprise{
|
||||
Name: name,
|
||||
WebhookSecret: secret,
|
||||
|
|
@ -66,12 +74,12 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsNam
|
|||
return params.Enterprise{}, errors.Wrap(err, "creating enterprise")
|
||||
}
|
||||
|
||||
param, err := s.sqlToCommonEnterprise(newEnterprise, true)
|
||||
paramEnt, err = s.sqlToCommonEnterprise(newEnterprise, true)
|
||||
if err != nil {
|
||||
return params.Enterprise{}, errors.Wrap(err, "creating enterprise")
|
||||
}
|
||||
|
||||
return param, nil
|
||||
return paramEnt, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) {
|
||||
|
|
@ -124,11 +132,22 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
|
|||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
|
||||
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID)
|
||||
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching enterprise")
|
||||
}
|
||||
|
||||
defer func(ent Enterprise) {
|
||||
if err == nil {
|
||||
asParams, innerErr := s.sqlToCommonEnterprise(ent, true)
|
||||
if innerErr == nil {
|
||||
s.sendNotify(common.EnterpriseEntityType, common.DeleteOperation, asParams)
|
||||
} else {
|
||||
slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "enterprise", enterpriseID)
|
||||
}
|
||||
}
|
||||
}(enterprise)
|
||||
|
||||
q := s.conn.Unscoped().Delete(&enterprise)
|
||||
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(q.Error, "deleting enterprise")
|
||||
|
|
@ -137,10 +156,15 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) {
|
||||
func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (newParams params.Enterprise, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.EnterpriseEntityType, common.UpdateOperation, newParams)
|
||||
}
|
||||
}()
|
||||
var enterprise Enterprise
|
||||
var creds GithubCredentials
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
enterprise, err = s.getEnterpriseByID(ctx, tx, enterpriseID)
|
||||
if err != nil {
|
||||
|
|
@ -196,7 +220,7 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
|
|||
if err != nil {
|
||||
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
|
||||
}
|
||||
newParams, err := s.sqlToCommonEnterprise(enterprise, true)
|
||||
newParams, err = s.sqlToCommonEnterprise(enterprise, true)
|
||||
if err != nil {
|
||||
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/auth"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -109,9 +110,14 @@ func getUIDFromContext(ctx context.Context) (uuid.UUID, error) {
|
|||
return asUUID, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (params.GithubEndpoint, error) {
|
||||
func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubEndpointEntityType, common.CreateOperation, ghEndpoint)
|
||||
}
|
||||
}()
|
||||
var endpoint GithubEndpoint
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("name = ?", param.Name).First(&endpoint).Error; err == nil {
|
||||
return errors.Wrap(runnerErrors.ErrDuplicateEntity, "github endpoint already exists")
|
||||
}
|
||||
|
|
@ -132,7 +138,11 @@ func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.Creat
|
|||
if err != nil {
|
||||
return params.GithubEndpoint{}, errors.Wrap(err, "creating github endpoint")
|
||||
}
|
||||
return s.sqlToCommonGithubEndpoint(endpoint)
|
||||
ghEndpoint, err = s.sqlToCommonGithubEndpoint(endpoint)
|
||||
if err != nil {
|
||||
return params.GithubEndpoint{}, errors.Wrap(err, "converting github endpoint")
|
||||
}
|
||||
return ghEndpoint, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEndpoint, error) {
|
||||
|
|
@ -153,12 +163,18 @@ func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEnd
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (params.GithubEndpoint, error) {
|
||||
func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) {
|
||||
if name == defaultGithubEndpoint {
|
||||
return params.GithubEndpoint{}, errors.Wrap(runnerErrors.ErrBadRequest, "cannot update default github endpoint")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubEndpointEntityType, common.UpdateOperation, ghEndpoint)
|
||||
}
|
||||
}()
|
||||
var endpoint GithubEndpoint
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(runnerErrors.ErrNotFound, "github endpoint not found")
|
||||
|
|
@ -194,7 +210,11 @@ func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param
|
|||
if err != nil {
|
||||
return params.GithubEndpoint{}, errors.Wrap(err, "updating github endpoint")
|
||||
}
|
||||
return s.sqlToCommonGithubEndpoint(endpoint)
|
||||
ghEndpoint, err = s.sqlToCommonGithubEndpoint(endpoint)
|
||||
if err != nil {
|
||||
return params.GithubEndpoint{}, errors.Wrap(err, "converting github endpoint")
|
||||
}
|
||||
return ghEndpoint, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params.GithubEndpoint, error) {
|
||||
|
|
@ -211,11 +231,17 @@ func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params.
|
|||
return s.sqlToCommonGithubEndpoint(endpoint)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error {
|
||||
func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) (err error) {
|
||||
if name == defaultGithubEndpoint {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "cannot delete default github endpoint")
|
||||
}
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubEndpointEntityType, common.DeleteOperation, params.GithubEndpoint{Name: name})
|
||||
}
|
||||
}()
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var endpoint GithubEndpoint
|
||||
if err := tx.Where("name = ?", name).First(&endpoint).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
|
@ -267,7 +293,7 @@ func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.CreateGithubCredentialsParams) (params.GithubCredentials, error) {
|
||||
func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.CreateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) {
|
||||
userID, err := getUIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials")
|
||||
|
|
@ -275,6 +301,12 @@ func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.
|
|||
if param.Endpoint == "" {
|
||||
return params.GithubCredentials{}, errors.Wrap(runnerErrors.ErrBadRequest, "endpoint name is required")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubCredentialsEntityType, common.CreateOperation, ghCreds)
|
||||
}
|
||||
}()
|
||||
var creds GithubCredentials
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var endpoint GithubEndpoint
|
||||
|
|
@ -323,7 +355,11 @@ func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.
|
|||
if err != nil {
|
||||
return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials")
|
||||
}
|
||||
return s.sqlToCommonGithubCredentials(creds)
|
||||
ghCreds, err = s.sqlToCommonGithubCredentials(creds)
|
||||
if err != nil {
|
||||
return params.GithubCredentials{}, errors.Wrap(err, "converting github credentials")
|
||||
}
|
||||
return ghCreds, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) getGithubCredentialsByName(ctx context.Context, tx *gorm.DB, name string, detailed bool) (GithubCredentials, error) {
|
||||
|
|
@ -420,9 +456,14 @@ func (s *sqlDatabase) ListGithubCredentials(ctx context.Context) ([]params.Githu
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (params.GithubCredentials, error) {
|
||||
func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubCredentialsEntityType, common.UpdateOperation, ghCreds)
|
||||
}
|
||||
}()
|
||||
var creds GithubCredentials
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
q := tx.Preload("Endpoint")
|
||||
if !auth.IsAdmin(ctx) {
|
||||
userID, err := getUIDFromContext(ctx)
|
||||
|
|
@ -486,11 +527,22 @@ func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, para
|
|||
if err != nil {
|
||||
return params.GithubCredentials{}, errors.Wrap(err, "updating github credentials")
|
||||
}
|
||||
return s.sqlToCommonGithubCredentials(creds)
|
||||
|
||||
ghCreds, err = s.sqlToCommonGithubCredentials(creds)
|
||||
if err != nil {
|
||||
return params.GithubCredentials{}, errors.Wrap(err, "converting github credentials")
|
||||
}
|
||||
return ghCreds, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) error {
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) (err error) {
|
||||
var name string
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.GithubCredentialsEntityType, common.DeleteOperation, params.GithubCredentials{ID: id, Name: name})
|
||||
}
|
||||
}()
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
q := tx.Where("id = ?", id).
|
||||
Preload("Repositories").
|
||||
Preload("Organizations").
|
||||
|
|
@ -511,6 +563,8 @@ func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) erro
|
|||
}
|
||||
return errors.Wrap(err, "fetching github credentials")
|
||||
}
|
||||
name = creds.Name
|
||||
|
||||
if len(creds.Repositories) > 0 {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "cannot delete credentials with repositories")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sql
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -25,35 +26,22 @@ import (
|
|||
"gorm.io/gorm/clause"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) {
|
||||
enc, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshalling data")
|
||||
}
|
||||
return util.Seal(enc, []byte(s.cfg.Passphrase))
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error {
|
||||
decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decrypting data")
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, target); err != nil {
|
||||
return errors.Wrap(err, "unmarshalling data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) {
|
||||
func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) {
|
||||
pool, err := s.getPoolByID(s.conn, poolID)
|
||||
if err != nil {
|
||||
return params.Instance{}, errors.Wrap(err, "fetching pool")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance)
|
||||
}
|
||||
}()
|
||||
|
||||
var labels datatypes.JSON
|
||||
if len(param.AditionalLabels) > 0 {
|
||||
labels, err = json.Marshal(param.AditionalLabels)
|
||||
|
|
@ -154,11 +142,30 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string
|
|||
return s.sqlToParamsInstance(instance)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) error {
|
||||
func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) (err error) {
|
||||
instance, err := s.getPoolInstanceByName(poolID, instanceName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "deleting instance")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
var providerID string
|
||||
if instance.ProviderID != nil {
|
||||
providerID = *instance.ProviderID
|
||||
}
|
||||
if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{
|
||||
ID: instance.ID.String(),
|
||||
Name: instance.Name,
|
||||
ProviderID: providerID,
|
||||
AgentID: instance.AgentID,
|
||||
PoolID: instance.PoolID.String(),
|
||||
}); notifyErr != nil {
|
||||
slog.With(slog.Any("error", notifyErr)).Error("failed to send notify")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if q := s.conn.Unscoped().Delete(&instance); q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
|
|
@ -250,8 +257,12 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, p
|
|||
return params.Instance{}, errors.Wrap(err, "updating addresses")
|
||||
}
|
||||
}
|
||||
|
||||
return s.sqlToParamsInstance(instance)
|
||||
inst, err := s.sqlToParamsInstance(instance)
|
||||
if err != nil {
|
||||
return params.Instance{}, errors.Wrap(err, "converting instance")
|
||||
}
|
||||
s.sendNotify(common.InstanceEntityType, common.UpdateOperation, inst)
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]params.Instance, error) {
|
||||
|
|
|
|||
|
|
@ -93,7 +93,14 @@ func (s *sqlDatabase) paramsJobToWorkflowJob(ctx context.Context, job params.Job
|
|||
return workflofJob, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) error {
|
||||
func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if notifyErr := s.sendNotify(common.JobEntityType, common.DeleteOperation, params.Job{ID: jobID}); notifyErr != nil {
|
||||
slog.With(slog.Any("error", notifyErr)).Error("failed to send notify")
|
||||
}
|
||||
}
|
||||
}()
|
||||
q := s.conn.Delete(&WorkflowJob{}, jobID)
|
||||
if q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
|
|
@ -134,10 +141,17 @@ func (s *sqlDatabase) LockJob(_ context.Context, jobID int64, entityID string) e
|
|||
return errors.Wrap(err, "saving job")
|
||||
}
|
||||
|
||||
asParams, err := sqlWorkflowJobToParamsJob(workflowJob)
|
||||
if err == nil {
|
||||
s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams)
|
||||
} else {
|
||||
slog.With(slog.Any("error", err)).Error("failed to convert job to params")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) error {
|
||||
func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) (err error) {
|
||||
var workflowJob WorkflowJob
|
||||
q := s.conn.Clauses(clause.Locking{Strength: "UPDATE"}).Preload("Instance").Where("id = ? and status = ?", jobID, params.JobStatusQueued).First(&workflowJob)
|
||||
|
||||
|
|
@ -157,7 +171,12 @@ func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) error
|
|||
if err := s.conn.Save(&workflowJob).Error; err != nil {
|
||||
return errors.Wrap(err, "saving job")
|
||||
}
|
||||
|
||||
asParams, err := sqlWorkflowJobToParamsJob(workflowJob)
|
||||
if err == nil {
|
||||
s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams)
|
||||
} else {
|
||||
slog.With(slog.Any("error", err)).Error("failed to convert job to params")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -186,6 +205,12 @@ func (s *sqlDatabase) UnlockJob(_ context.Context, jobID int64, entityID string)
|
|||
return errors.Wrap(err, "saving job")
|
||||
}
|
||||
|
||||
asParams, err := sqlWorkflowJobToParamsJob(workflowJob)
|
||||
if err == nil {
|
||||
s.sendNotify(common.JobEntityType, common.UpdateOperation, asParams)
|
||||
} else {
|
||||
slog.With(slog.Any("error", err)).Error("failed to convert job to params")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -198,9 +223,11 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa
|
|||
return params.Job{}, errors.Wrap(q.Error, "fetching job")
|
||||
}
|
||||
}
|
||||
|
||||
var operation common.OperationType
|
||||
if workflowJob.ID != 0 {
|
||||
// Update workflowJob with values from job.
|
||||
operation = common.UpdateOperation
|
||||
|
||||
workflowJob.Status = job.Status
|
||||
workflowJob.Action = job.Action
|
||||
workflowJob.Conclusion = job.Conclusion
|
||||
|
|
@ -238,6 +265,8 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa
|
|||
return params.Job{}, errors.Wrap(err, "saving job")
|
||||
}
|
||||
} else {
|
||||
operation = common.CreateOperation
|
||||
|
||||
workflowJob, err := s.paramsJobToWorkflowJob(ctx, job)
|
||||
if err != nil {
|
||||
return params.Job{}, errors.Wrap(err, "converting job")
|
||||
|
|
@ -247,7 +276,13 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa
|
|||
}
|
||||
}
|
||||
|
||||
return sqlWorkflowJobToParamsJob(workflowJob)
|
||||
asParams, err := sqlWorkflowJobToParamsJob(workflowJob)
|
||||
if err != nil {
|
||||
return params.Job{}, errors.Wrap(err, "converting job")
|
||||
}
|
||||
s.sendNotify(common.JobEntityType, operation, asParams)
|
||||
|
||||
return asParams, nil
|
||||
}
|
||||
|
||||
// ListJobsByStatus lists all jobs for a given status.
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sql
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -24,10 +25,11 @@ import (
|
|||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Organization, error) {
|
||||
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (org params.Organization, err error) {
|
||||
if webhookSecret == "" {
|
||||
return params.Organization{}, errors.New("creating org: missing secret")
|
||||
}
|
||||
|
|
@ -35,6 +37,12 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN
|
|||
if err != nil {
|
||||
return params.Organization{}, errors.Wrap(err, "encoding secret")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.OrganizationEntityType, common.CreateOperation, org)
|
||||
}
|
||||
}()
|
||||
newOrg := Organization{
|
||||
Name: name,
|
||||
WebhookSecret: secret,
|
||||
|
|
@ -68,13 +76,13 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsN
|
|||
return params.Organization{}, errors.Wrap(err, "creating org")
|
||||
}
|
||||
|
||||
param, err := s.sqlToCommonOrganization(newOrg, true)
|
||||
org, err = s.sqlToCommonOrganization(newOrg, true)
|
||||
if err != nil {
|
||||
return params.Organization{}, errors.Wrap(err, "creating org")
|
||||
}
|
||||
param.WebhookSecret = webhookSecret
|
||||
org.WebhookSecret = webhookSecret
|
||||
|
||||
return param, nil
|
||||
return org, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) {
|
||||
|
|
@ -114,12 +122,23 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) error {
|
||||
org, err := s.getOrgByID(ctx, s.conn, orgID)
|
||||
func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) {
|
||||
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching org")
|
||||
}
|
||||
|
||||
defer func(org Organization) {
|
||||
if err == nil {
|
||||
asParam, innerErr := s.sqlToCommonOrganization(org, true)
|
||||
if innerErr == nil {
|
||||
s.sendNotify(common.OrganizationEntityType, common.DeleteOperation, asParam)
|
||||
} else {
|
||||
slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "org", orgID)
|
||||
}
|
||||
}
|
||||
}(org)
|
||||
|
||||
q := s.conn.Unscoped().Delete(&org)
|
||||
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(q.Error, "deleting org")
|
||||
|
|
@ -128,10 +147,15 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) {
|
||||
func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (paramOrg params.Organization, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.OrganizationEntityType, common.UpdateOperation, paramOrg)
|
||||
}
|
||||
}()
|
||||
var org Organization
|
||||
var creds GithubCredentials
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
org, err = s.getOrgByID(ctx, tx, orgID)
|
||||
if err != nil {
|
||||
|
|
@ -188,11 +212,11 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
|
|||
if err != nil {
|
||||
return params.Organization{}, errors.Wrap(err, "updating enterprise")
|
||||
}
|
||||
newParams, err := s.sqlToCommonOrganization(org, true)
|
||||
paramOrg, err = s.sqlToCommonOrganization(org, true)
|
||||
if err != nil {
|
||||
return params.Organization{}, errors.Wrap(err, "saving org")
|
||||
}
|
||||
return newParams, nil
|
||||
return paramOrg, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error) {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -66,12 +67,18 @@ func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool
|
|||
return s.sqlToCommonPool(pool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) error {
|
||||
func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err error) {
|
||||
pool, err := s.getPoolByID(s.conn, poolID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching pool by ID")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.PoolEntityType, common.DeleteOperation, params.Pool{ID: poolID})
|
||||
}
|
||||
}()
|
||||
|
||||
if q := s.conn.Unscoped().Delete(&pool); q.Error != nil {
|
||||
return errors.Wrap(q.Error, "removing pool")
|
||||
}
|
||||
|
|
@ -247,11 +254,17 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par
|
|||
return pools, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) {
|
||||
func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (pool params.Pool, err error) {
|
||||
if len(param.Tags) == 0 {
|
||||
return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.PoolEntityType, common.CreateOperation, pool)
|
||||
}
|
||||
}()
|
||||
|
||||
newPool := Pool{
|
||||
ProviderName: param.ProviderName,
|
||||
MaxRunners: param.MaxRunners,
|
||||
|
|
@ -313,12 +326,12 @@ func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEn
|
|||
return params.Pool{}, err
|
||||
}
|
||||
|
||||
pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
|
||||
dbPool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
|
||||
if err != nil {
|
||||
return params.Pool{}, errors.Wrap(err, "fetching pool")
|
||||
}
|
||||
|
||||
return s.sqlToCommonPool(pool)
|
||||
return s.sqlToCommonPool(dbPool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) {
|
||||
|
|
@ -329,12 +342,21 @@ func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntit
|
|||
return s.sqlToCommonPool(pool)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) error {
|
||||
func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (err error) {
|
||||
entityID, err := uuid.Parse(entity.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
pool := params.Pool{
|
||||
ID: poolID,
|
||||
}
|
||||
s.sendNotify(common.PoolEntityType, common.DeleteOperation, pool)
|
||||
}
|
||||
}()
|
||||
|
||||
poolUUID, err := uuid.Parse(poolID)
|
||||
if err != nil {
|
||||
return errors.Wrap(runnerErrors.ErrBadRequest, "parsing pool id")
|
||||
|
|
@ -357,9 +379,13 @@ func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
|
||||
var updatedPool params.Pool
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (updatedPool params.Pool, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool)
|
||||
}
|
||||
}()
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
pool, err := s.getEntityPool(tx, entity.EntityType, entity.ID, poolID, "Tags", "Instances")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching pool")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sql
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -24,10 +25,17 @@ import (
|
|||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) {
|
||||
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Repository, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.CreateOperation, param)
|
||||
}
|
||||
}()
|
||||
|
||||
if webhookSecret == "" {
|
||||
return params.Repository{}, errors.New("creating repo: missing secret")
|
||||
}
|
||||
|
|
@ -68,7 +76,7 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credent
|
|||
return params.Repository{}, errors.Wrap(err, "creating repository")
|
||||
}
|
||||
|
||||
param, err := s.sqlToCommonRepository(newRepo, true)
|
||||
param, err = s.sqlToCommonRepository(newRepo, true)
|
||||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "creating repository")
|
||||
}
|
||||
|
|
@ -113,12 +121,23 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error {
|
||||
repo, err := s.getRepoByID(ctx, s.conn, repoID)
|
||||
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) {
|
||||
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching repo")
|
||||
}
|
||||
|
||||
defer func(repo Repository) {
|
||||
if err == nil {
|
||||
asParam, innerErr := s.sqlToCommonRepository(repo, true)
|
||||
if innerErr == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.DeleteOperation, asParam)
|
||||
} else {
|
||||
slog.With(slog.Any("error", innerErr)).ErrorContext(ctx, "error sending delete notification", "repo", repoID)
|
||||
}
|
||||
}
|
||||
}(repo)
|
||||
|
||||
q := s.conn.Unscoped().Delete(&repo)
|
||||
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return errors.Wrap(q.Error, "deleting repo")
|
||||
|
|
@ -127,10 +146,15 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) {
|
||||
func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (newParams params.Repository, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.sendNotify(common.RepositoryEntityType, common.UpdateOperation, newParams)
|
||||
}
|
||||
}()
|
||||
var repo Repository
|
||||
var creds GithubCredentials
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
repo, err = s.getRepoByID(ctx, tx, repoID)
|
||||
if err != nil {
|
||||
|
|
@ -186,7 +210,8 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
|
|||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "updating enterprise")
|
||||
}
|
||||
newParams, err := s.sqlToCommonRepository(repo, true)
|
||||
|
||||
newParams, err = s.sqlToCommonRepository(repo, true)
|
||||
if err != nil {
|
||||
return params.Repository{}, errors.Wrap(err, "saving repo")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
|
||||
"github.com/cloudbase/garm/auth"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
|
@ -827,5 +828,11 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() {
|
|||
|
||||
func TestRepoTestSuite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
suite.Run(t, new(RepoTestSuite))
|
||||
}
|
||||
|
||||
func init() {
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/cloudbase/garm/auth"
|
||||
"github.com/cloudbase/garm/config"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/util/appdefaults"
|
||||
)
|
||||
|
|
@ -68,10 +69,15 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "creating DB connection")
|
||||
}
|
||||
producer, err := watcher.RegisterProducer(ctx, "sql")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "registering producer")
|
||||
}
|
||||
db := &sqlDatabase{
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
producer: producer,
|
||||
}
|
||||
|
||||
if err := db.migrateDB(); err != nil {
|
||||
|
|
@ -81,9 +87,10 @@ func NewSQLDatabase(ctx context.Context, cfg config.Database) (common.Store, err
|
|||
}
|
||||
|
||||
type sqlDatabase struct {
|
||||
conn *gorm.DB
|
||||
ctx context.Context
|
||||
cfg config.Database
|
||||
conn *gorm.DB
|
||||
ctx context.Context
|
||||
cfg config.Database
|
||||
producer common.Producer
|
||||
}
|
||||
|
||||
var renameTemplate = `
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
commonParams "github.com/cloudbase/garm-provider-common/params"
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
|
|
@ -467,3 +468,38 @@ func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntit
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) marshalAndSeal(data interface{}) ([]byte, error) {
|
||||
enc, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshalling data")
|
||||
}
|
||||
return util.Seal(enc, []byte(s.cfg.Passphrase))
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error {
|
||||
decrypted, err := util.Unseal(data, []byte(s.cfg.Passphrase))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "decrypting data")
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, target); err != nil {
|
||||
return errors.Wrap(err, "unmarshalling data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) sendNotify(entityType dbCommon.DatabaseEntityType, op dbCommon.OperationType, payload interface{}) error {
|
||||
if s.producer == nil {
|
||||
// no producer was registered. Not sending notifications.
|
||||
return nil
|
||||
}
|
||||
if payload == nil {
|
||||
return errors.New("missing payload")
|
||||
}
|
||||
message := dbCommon.ChangePayload{
|
||||
Operation: op,
|
||||
Payload: payload,
|
||||
EntityType: entityType,
|
||||
}
|
||||
return s.producer.Notify(message)
|
||||
}
|
||||
|
|
|
|||
82
database/watcher/consumer.go
Normal file
82
database/watcher/consumer.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
type consumer struct {
|
||||
messages chan common.ChangePayload
|
||||
filters []common.PayloadFilterFunc
|
||||
id string
|
||||
|
||||
mux sync.Mutex
|
||||
closed bool
|
||||
quit chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *consumer) SetFilters(filters ...common.PayloadFilterFunc) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
w.filters = filters
|
||||
}
|
||||
|
||||
func (w *consumer) Watch() <-chan common.ChangePayload {
|
||||
return w.messages
|
||||
}
|
||||
|
||||
func (w *consumer) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
close(w.messages)
|
||||
close(w.quit)
|
||||
w.closed = true
|
||||
}
|
||||
|
||||
func (w *consumer) IsClosed() bool {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
return w.closed
|
||||
}
|
||||
|
||||
func (w *consumer) Send(payload common.ChangePayload) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if len(w.filters) > 0 {
|
||||
shouldSend := true
|
||||
for _, filter := range w.filters {
|
||||
if !filter(payload) {
|
||||
shouldSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !shouldSend {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
slog.DebugContext(w.ctx, "sending payload")
|
||||
select {
|
||||
case <-w.quit:
|
||||
slog.DebugContext(w.ctx, "consumer is closed")
|
||||
case <-w.ctx.Done():
|
||||
slog.DebugContext(w.ctx, "consumer is closed")
|
||||
case <-time.After(1 * time.Second):
|
||||
slog.DebugContext(w.ctx, "timeout trying to send payload", "payload", payload)
|
||||
case w.messages <- payload:
|
||||
}
|
||||
}
|
||||
141
database/watcher/filters.go
Normal file
141
database/watcher/filters.go
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
type idGetter interface {
|
||||
GetID() string
|
||||
}
|
||||
|
||||
// WithAny returns a filter function that returns true if any of the provided filters return true.
|
||||
// This filter is useful if for example you want to watch for update operations on any of the supplied
|
||||
// entities.
|
||||
// Example:
|
||||
//
|
||||
// // Watch for any update operation on repositories or organizations
|
||||
// consumer.SetFilters(
|
||||
// watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
// watcher.WithAny(
|
||||
// watcher.WithEntityTypeFilter(common.RepositoryEntityType),
|
||||
// watcher.WithEntityTypeFilter(common.OrganizationEntityType),
|
||||
// ))
|
||||
func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
for _, filter := range filters {
|
||||
if filter(payload) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
|
||||
// The filter function returns true if the payload's entity type matches the provided entity type.
|
||||
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
return payload.EntityType == entityType
|
||||
}
|
||||
}
|
||||
|
||||
// WithOperationTypeFilter returns a filter function that filters payloads by operation type.
|
||||
func WithOperationTypeFilter(operationType dbCommon.OperationType) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
return payload.Operation == operationType
|
||||
}
|
||||
}
|
||||
|
||||
// WithEntityPoolFilter returns true if the change payload is a pool that belongs to the
|
||||
// supplied Github entity. This is useful when an entity worker wants to watch for changes
|
||||
// in pools that belong to it.
|
||||
func WithEntityPoolFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
switch payload.EntityType {
|
||||
case dbCommon.PoolEntityType:
|
||||
pool, ok := payload.Payload.(params.Pool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch ghEntity.EntityType {
|
||||
case params.GithubEntityTypeRepository:
|
||||
if pool.RepoID != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
case params.GithubEntityTypeOrganization:
|
||||
if pool.OrgID != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
case params.GithubEntityTypeEnterprise:
|
||||
if pool.EnterpriseID != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithEntityFilter returns a filter function that filters payloads by entity.
|
||||
// Change payloads that match the entity type and ID will return true.
|
||||
func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
if params.GithubEntityType(payload.EntityType) != entity.EntityType {
|
||||
return false
|
||||
}
|
||||
var ent idGetter
|
||||
var ok bool
|
||||
switch payload.EntityType {
|
||||
case dbCommon.RepositoryEntityType:
|
||||
ent, ok = payload.Payload.(params.Repository)
|
||||
case dbCommon.OrganizationEntityType:
|
||||
ent, ok = payload.Payload.(params.Organization)
|
||||
case dbCommon.EnterpriseEntityType:
|
||||
ent, ok = payload.Payload.(params.Enterprise)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return ent.GetID() == entity.ID
|
||||
}
|
||||
}
|
||||
|
||||
func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
switch payload.EntityType {
|
||||
case dbCommon.JobEntityType:
|
||||
job, ok := payload.Payload.(params.Job)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch ghEntity.EntityType {
|
||||
case params.GithubEntityTypeRepository:
|
||||
if job.RepoID != nil && job.RepoID.String() != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
case params.GithubEntityTypeOrganization:
|
||||
if job.OrgID != nil && job.OrgID.String() != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
case params.GithubEntityTypeEnterprise:
|
||||
if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
56
database/watcher/producer.go
Normal file
56
database/watcher/producer.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
type producer struct {
|
||||
closed bool
|
||||
mux sync.Mutex
|
||||
id string
|
||||
|
||||
messages chan common.ChangePayload
|
||||
quit chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *producer) Notify(payload common.ChangePayload) error {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return common.ErrProducerClosed
|
||||
}
|
||||
|
||||
select {
|
||||
case <-w.quit:
|
||||
return common.ErrProducerClosed
|
||||
case <-w.ctx.Done():
|
||||
return common.ErrProducerClosed
|
||||
case <-time.After(1 * time.Second):
|
||||
return common.ErrProducerTimeoutErr
|
||||
case w.messages <- payload:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *producer) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
w.closed = true
|
||||
close(w.messages)
|
||||
close(w.quit)
|
||||
}
|
||||
|
||||
func (w *producer) IsClosed() bool {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
return w.closed
|
||||
}
|
||||
17
database/watcher/test_export.go
Normal file
17
database/watcher/test_export.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package watcher
|
||||
|
||||
import "github.com/cloudbase/garm/database/common"
|
||||
|
||||
// SetWatcher sets the watcher to be used by the database package.
|
||||
// This function is intended for use in tests only.
|
||||
func SetWatcher(w common.Watcher) {
|
||||
databaseWatcher = w
|
||||
}
|
||||
|
||||
// GetWatcher returns the current watcher.
|
||||
func GetWatcher() common.Watcher {
|
||||
return databaseWatcher
|
||||
}
|
||||
181
database/watcher/watcher.go
Normal file
181
database/watcher/watcher.go
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
garmUtil "github.com/cloudbase/garm/util"
|
||||
)
|
||||
|
||||
var databaseWatcher common.Watcher
|
||||
|
||||
func InitWatcher(ctx context.Context) {
|
||||
if databaseWatcher != nil {
|
||||
return
|
||||
}
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database"))
|
||||
w := &watcher{
|
||||
producers: make(map[string]*producer),
|
||||
consumers: make(map[string]*consumer),
|
||||
quit: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
go w.loop()
|
||||
databaseWatcher = w
|
||||
}
|
||||
|
||||
func RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
|
||||
if databaseWatcher == nil {
|
||||
return nil, common.ErrWatcherNotInitialized
|
||||
}
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id))
|
||||
return databaseWatcher.RegisterProducer(ctx, id)
|
||||
}
|
||||
|
||||
func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
if databaseWatcher == nil {
|
||||
return nil, common.ErrWatcherNotInitialized
|
||||
}
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id))
|
||||
return databaseWatcher.RegisterConsumer(ctx, id, filters...)
|
||||
}
|
||||
|
||||
type watcher struct {
|
||||
producers map[string]*producer
|
||||
consumers map[string]*consumer
|
||||
|
||||
mux sync.Mutex
|
||||
closed bool
|
||||
quit chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
|
||||
if _, ok := w.producers[id]; ok {
|
||||
return nil, errors.Wrapf(common.ErrProducerAlreadyRegistered, "producer_id: %s", id)
|
||||
}
|
||||
p := &producer{
|
||||
id: id,
|
||||
messages: make(chan common.ChangePayload, 1),
|
||||
quit: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
}
|
||||
w.producers[id] = p
|
||||
go w.serviceProducer(p)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (w *watcher) serviceProducer(prod *producer) {
|
||||
defer func() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
prod.Close()
|
||||
slog.InfoContext(w.ctx, "removing producer from watcher", "consumer_id", prod.id)
|
||||
delete(w.producers, prod.id)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-w.quit:
|
||||
slog.InfoContext(w.ctx, "shutting down watcher")
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
slog.InfoContext(w.ctx, "shutting down watcher")
|
||||
return
|
||||
case <-prod.quit:
|
||||
slog.InfoContext(w.ctx, "closing producer")
|
||||
return
|
||||
case <-prod.ctx.Done():
|
||||
slog.InfoContext(w.ctx, "closing producer")
|
||||
return
|
||||
case payload := <-prod.messages:
|
||||
w.mux.Lock()
|
||||
for _, c := range w.consumers {
|
||||
go c.Send(payload)
|
||||
}
|
||||
w.mux.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if _, ok := w.consumers[id]; ok {
|
||||
return nil, common.ErrConsumerAlreadyRegistered
|
||||
}
|
||||
c := &consumer{
|
||||
messages: make(chan common.ChangePayload, 1),
|
||||
filters: filters,
|
||||
quit: make(chan struct{}),
|
||||
id: id,
|
||||
ctx: ctx,
|
||||
}
|
||||
w.consumers[id] = c
|
||||
go w.serviceConsumer(c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (w *watcher) serviceConsumer(consumer *consumer) {
|
||||
defer func() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
consumer.Close()
|
||||
slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", consumer.id)
|
||||
delete(w.consumers, consumer.id)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-consumer.quit:
|
||||
return
|
||||
case <-consumer.ctx.Done():
|
||||
return
|
||||
case <-w.quit:
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *watcher) Close() {
|
||||
w.mux.Lock()
|
||||
defer w.mux.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
|
||||
close(w.quit)
|
||||
w.closed = true
|
||||
|
||||
for _, p := range w.producers {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
for _, c := range w.consumers {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
databaseWatcher = nil
|
||||
}
|
||||
|
||||
func (w *watcher) loop() {
|
||||
defer func() {
|
||||
w.Close()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-w.quit:
|
||||
return
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
728
database/watcher/watcher_store_test.go
Normal file
728
database/watcher/watcher_store_test.go
Normal file
|
|
@ -0,0 +1,728 @@
|
|||
package watcher_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
commonParams "github.com/cloudbase/garm-provider-common/params"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
type WatcherStoreTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
store common.Store
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestJobWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "job-test",
|
||||
watcher.WithEntityTypeFilter(common.JobEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
jobParams := params.Job{
|
||||
ID: 1,
|
||||
RunID: 2,
|
||||
Action: "test-action",
|
||||
Conclusion: "started",
|
||||
Status: "in_progress",
|
||||
Name: "test-job",
|
||||
}
|
||||
|
||||
job, err := s.store.CreateOrUpdateJob(s.ctx, jobParams)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.JobEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: job,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
jobParams.Conclusion = "success"
|
||||
updatedJob, err := s.store.CreateOrUpdateJob(s.ctx, jobParams)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.JobEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedJob,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
entityID, err := uuid.NewUUID()
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.store.LockJob(s.ctx, updatedJob.ID, entityID.String())
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(event.Operation, common.UpdateOperation)
|
||||
s.Require().Equal(event.EntityType, common.JobEntityType)
|
||||
|
||||
job, ok := event.Payload.(params.Job)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal(job.ID, updatedJob.ID)
|
||||
s.Require().Equal(job.LockedBy, entityID)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.UnlockJob(s.ctx, updatedJob.ID, entityID.String())
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(event.Operation, common.UpdateOperation)
|
||||
s.Require().Equal(event.EntityType, common.JobEntityType)
|
||||
|
||||
job, ok := event.Payload.(params.Job)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal(job.ID, updatedJob.ID)
|
||||
s.Require().Equal(job.LockedBy, uuid.Nil)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
jobParams.Status = "queued"
|
||||
jobParams.LockedBy = entityID
|
||||
|
||||
updatedJob, err = s.store.CreateOrUpdateJob(s.ctx, jobParams)
|
||||
s.Require().NoError(err)
|
||||
select {
|
||||
case <-consumer.Watch():
|
||||
// throw away event.
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("unexpected payload received")
|
||||
}
|
||||
|
||||
err = s.store.BreakLockJobIsQueued(s.ctx, updatedJob.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(event.Operation, common.UpdateOperation)
|
||||
s.Require().Equal(event.EntityType, common.JobEntityType)
|
||||
|
||||
job, ok := event.Payload.(params.Job)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal(job.ID, updatedJob.ID)
|
||||
s.Require().Equal(job.LockedBy, uuid.Nil)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestInstanceWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "instance-test",
|
||||
watcher.WithEntityTypeFilter(common.InstanceEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T())
|
||||
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep)
|
||||
s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) })
|
||||
|
||||
repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(repo.ID)
|
||||
s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) })
|
||||
|
||||
entity, err := repo.GetEntity()
|
||||
s.Require().NoError(err)
|
||||
|
||||
createPoolParams := params.CreatePoolParams{
|
||||
ProviderName: "test-provider",
|
||||
Image: "test-image",
|
||||
Flavor: "test-flavor",
|
||||
OSType: commonParams.Linux,
|
||||
OSArch: commonParams.Amd64,
|
||||
Tags: []string{"test-tag"},
|
||||
}
|
||||
|
||||
pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(pool.ID)
|
||||
s.T().Cleanup(func() { s.store.DeleteEntityPool(s.ctx, entity, pool.ID) })
|
||||
|
||||
createInstanceParams := params.CreateInstanceParams{
|
||||
Name: "test-instance",
|
||||
OSType: commonParams.Linux,
|
||||
OSArch: commonParams.Amd64,
|
||||
Status: commonParams.InstanceCreating,
|
||||
}
|
||||
instance, err := s.store.CreateInstance(s.ctx, pool.ID, createInstanceParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(instance.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.InstanceEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: instance,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
updateParams := params.UpdateInstanceParams{
|
||||
RunnerStatus: params.RunnerActive,
|
||||
}
|
||||
|
||||
updatedInstance, err := s.store.UpdateInstance(s.ctx, instance.Name, updateParams)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.InstanceEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedInstance,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteInstance(s.ctx, pool.ID, updatedInstance.Name)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.InstanceEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: params.Instance{
|
||||
ID: updatedInstance.ID,
|
||||
Name: updatedInstance.Name,
|
||||
ProviderID: updatedInstance.ProviderID,
|
||||
AgentID: updatedInstance.AgentID,
|
||||
PoolID: updatedInstance.PoolID,
|
||||
},
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestPoolWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "pool-test",
|
||||
watcher.WithEntityTypeFilter(common.PoolEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T())
|
||||
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep)
|
||||
s.T().Cleanup(func() {
|
||||
if err := s.store.DeleteGithubCredentials(s.ctx, creds.ID); err != nil {
|
||||
s.T().Logf("failed to delete Github credentials: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(repo.ID)
|
||||
s.T().Cleanup(func() { s.store.DeleteRepository(s.ctx, repo.ID) })
|
||||
|
||||
entity, err := repo.GetEntity()
|
||||
s.Require().NoError(err)
|
||||
|
||||
createPoolParams := params.CreatePoolParams{
|
||||
ProviderName: "test-provider",
|
||||
Image: "test-image",
|
||||
Flavor: "test-flavor",
|
||||
OSType: commonParams.Linux,
|
||||
OSArch: commonParams.Amd64,
|
||||
Tags: []string{"test-tag"},
|
||||
}
|
||||
pool, err := s.store.CreateEntityPool(s.ctx, entity, createPoolParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(pool.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.PoolEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: pool,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
updateParams := params.UpdatePoolParams{
|
||||
Tags: []string{"updated-tag"},
|
||||
}
|
||||
|
||||
updatedPool, err := s.store.UpdateEntityPool(s.ctx, entity, pool.ID, updateParams)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.PoolEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedPool,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteEntityPool(s.ctx, entity, pool.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.PoolEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: params.Pool{ID: pool.ID},
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
// Also test DeletePoolByID
|
||||
pool, err = s.store.CreateEntityPool(s.ctx, entity, createPoolParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(pool.ID)
|
||||
|
||||
// Consume the create event
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.PoolEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: pool,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeletePoolByID(s.ctx, pool.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.PoolEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: params.Pool{ID: pool.ID},
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestControllerWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "controller-test",
|
||||
watcher.WithEntityTypeFilter(common.ControllerEntityType),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
metadataURL := "http://metadata.example.com"
|
||||
updateParams := params.UpdateControllerParams{
|
||||
MetadataURL: &metadataURL,
|
||||
}
|
||||
|
||||
controller, err := s.store.UpdateController(updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(metadataURL, controller.MetadataURL)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.ControllerEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: controller,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestEnterpriseWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "enterprise-test",
|
||||
watcher.WithEntityTypeFilter(common.EnterpriseEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T())
|
||||
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep)
|
||||
s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) })
|
||||
|
||||
ent, err := s.store.CreateEnterprise(s.ctx, "test-enterprise", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(ent.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.EnterpriseEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: ent,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
updateParams := params.UpdateEntityParams{
|
||||
WebhookSecret: "updated",
|
||||
}
|
||||
|
||||
updatedEnt, err := s.store.UpdateEnterprise(s.ctx, ent.ID, updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("updated", updatedEnt.WebhookSecret)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.EnterpriseEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedEnt,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteEnterprise(s.ctx, ent.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.EnterpriseEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: updatedEnt,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestOrgWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "org-test",
|
||||
watcher.WithEntityTypeFilter(common.OrganizationEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T())
|
||||
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep)
|
||||
s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) })
|
||||
|
||||
org, err := s.store.CreateOrganization(s.ctx, "test-org", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(org.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.OrganizationEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: org,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
updateParams := params.UpdateEntityParams{
|
||||
WebhookSecret: "updated",
|
||||
}
|
||||
|
||||
updatedOrg, err := s.store.UpdateOrganization(s.ctx, org.ID, updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("updated", updatedOrg.WebhookSecret)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.OrganizationEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedOrg,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteOrganization(s.ctx, org.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.OrganizationEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: updatedOrg,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestRepoWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "repo-test",
|
||||
watcher.WithEntityTypeFilter(common.RepositoryEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.store, s.T())
|
||||
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.store, s.T(), ep)
|
||||
s.T().Cleanup(func() { s.store.DeleteGithubCredentials(s.ctx, creds.ID) })
|
||||
|
||||
repo, err := s.store.CreateRepository(s.ctx, "test-owner", "test-repo", creds.Name, "test-secret", params.PoolBalancerTypeRoundRobin)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(repo.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.RepositoryEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: repo,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
newSecret := "updated"
|
||||
updateParams := params.UpdateEntityParams{
|
||||
WebhookSecret: newSecret,
|
||||
}
|
||||
|
||||
updatedRepo, err := s.store.UpdateRepository(s.ctx, repo.ID, updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(newSecret, updatedRepo.WebhookSecret)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.RepositoryEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedRepo,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteRepository(s.ctx, repo.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.RepositoryEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
Payload: updatedRepo,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestGithubCredentialsWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "gh-cred-test",
|
||||
watcher.WithEntityTypeFilter(common.GithubCredentialsEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ghCredParams := params.CreateGithubCredentialsParams{
|
||||
Name: "test-creds",
|
||||
Description: "test credentials",
|
||||
Endpoint: "github.com",
|
||||
AuthType: params.GithubAuthTypePAT,
|
||||
PAT: params.GithubPAT{
|
||||
OAuth2Token: "bogus",
|
||||
},
|
||||
}
|
||||
|
||||
ghCred, err := s.store.CreateGithubCredentials(s.ctx, ghCredParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(ghCred.ID)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubCredentialsEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: ghCred,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
newDesc := "updated description"
|
||||
updateParams := params.UpdateGithubCredentialsParams{
|
||||
Description: &newDesc,
|
||||
}
|
||||
|
||||
updatedGhCred, err := s.store.UpdateGithubCredentials(s.ctx, ghCred.ID, updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(newDesc, updatedGhCred.Description)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubCredentialsEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedGhCred,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteGithubCredentials(s.ctx, ghCred.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubCredentialsEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
// We only get the ID and Name of the deleted entity
|
||||
Payload: params.GithubCredentials{ID: ghCred.ID, Name: ghCred.Name},
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherStoreTestSuite) TestGithubEndpointWatcher() {
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "gh-ep-test",
|
||||
watcher.WithEntityTypeFilter(common.GithubEndpointEntityType),
|
||||
watcher.WithAny(
|
||||
watcher.WithOperationTypeFilter(common.CreateOperation),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation),
|
||||
watcher.WithOperationTypeFilter(common.DeleteOperation)),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
s.T().Cleanup(func() { consumer.Close() })
|
||||
|
||||
ghEpParams := params.CreateGithubEndpointParams{
|
||||
Name: "test",
|
||||
Description: "test endpoint",
|
||||
APIBaseURL: "https://api.ghes.example.com",
|
||||
UploadBaseURL: "https://upload.ghes.example.com",
|
||||
BaseURL: "https://ghes.example.com",
|
||||
}
|
||||
|
||||
ghEp, err := s.store.CreateGithubEndpoint(s.ctx, ghEpParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(ghEp.Name)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubEndpointEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: ghEp,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
newDesc := "updated description"
|
||||
updateParams := params.UpdateGithubEndpointParams{
|
||||
Description: &newDesc,
|
||||
}
|
||||
|
||||
updatedGhEp, err := s.store.UpdateGithubEndpoint(s.ctx, ghEp.Name, updateParams)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(newDesc, updatedGhEp.Description)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubEndpointEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: updatedGhEp,
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
err = s.store.DeleteGithubEndpoint(s.ctx, ghEp.Name)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case event := <-consumer.Watch():
|
||||
s.Require().Equal(common.ChangePayload{
|
||||
EntityType: common.GithubEndpointEntityType,
|
||||
Operation: common.DeleteOperation,
|
||||
// We only get the name of the deleted entity
|
||||
Payload: params.GithubEndpoint{Name: ghEp.Name},
|
||||
}, event)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
}
|
||||
196
database/watcher/watcher_test.go
Normal file
196
database/watcher/watcher_test.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
//go:build testing
|
||||
|
||||
package watcher_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/cloudbase/garm/database"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
)
|
||||
|
||||
type WatcherTestSuite struct {
|
||||
suite.Suite
|
||||
store common.Store
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) SetupTest() {
|
||||
ctx := context.TODO()
|
||||
watcher.InitWatcher(ctx)
|
||||
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(s.T()))
|
||||
if err != nil {
|
||||
s.T().Fatalf("failed to create db connection: %s", err)
|
||||
}
|
||||
s.store = store
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TearDownTest() {
|
||||
s.store = nil
|
||||
currentWatcher := watcher.GetWatcher()
|
||||
if currentWatcher != nil {
|
||||
currentWatcher.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestRegisterConsumerTwiceWillError() {
|
||||
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
|
||||
consumer, err = watcher.RegisterConsumer(s.ctx, "test")
|
||||
s.Require().ErrorIs(err, common.ErrConsumerAlreadyRegistered)
|
||||
s.Require().Nil(consumer)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestRegisterProducerTwiceWillError() {
|
||||
producer, err := watcher.RegisterProducer(s.ctx, "test")
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(producer)
|
||||
|
||||
producer, err = watcher.RegisterProducer(s.ctx, "test")
|
||||
s.Require().ErrorIs(err, common.ErrProducerAlreadyRegistered)
|
||||
s.Require().Nil(producer)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestInitWatcherRanTwiceDoesNotReplaceWatcher() {
|
||||
ctx := context.TODO()
|
||||
currentWatcher := watcher.GetWatcher()
|
||||
s.Require().NotNil(currentWatcher)
|
||||
watcher.InitWatcher(ctx)
|
||||
newWatcher := watcher.GetWatcher()
|
||||
s.Require().Equal(currentWatcher, newWatcher)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestRegisterConsumerFailsIfWatcherIsNotInitialized() {
|
||||
s.store = nil
|
||||
currentWatcher := watcher.GetWatcher()
|
||||
currentWatcher.Close()
|
||||
|
||||
consumer, err := watcher.RegisterConsumer(s.ctx, "test")
|
||||
s.Require().Nil(consumer)
|
||||
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestRegisterProducerFailsIfWatcherIsNotInitialized() {
|
||||
s.store = nil
|
||||
currentWatcher := watcher.GetWatcher()
|
||||
currentWatcher.Close()
|
||||
|
||||
producer, err := watcher.RegisterProducer(s.ctx, "test")
|
||||
s.Require().Nil(producer)
|
||||
s.Require().ErrorIs(err, common.ErrWatcherNotInitialized)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestProducerAndConsumer() {
|
||||
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(producer)
|
||||
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "test-consumer",
|
||||
watcher.WithEntityTypeFilter(common.ControllerEntityType),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation))
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
|
||||
payload := common.ChangePayload{
|
||||
EntityType: common.ControllerEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: "test",
|
||||
}
|
||||
err = producer.Notify(payload)
|
||||
s.Require().NoError(err)
|
||||
|
||||
receivedPayload := <-consumer.Watch()
|
||||
s.Require().Equal(payload, receivedPayload)
|
||||
}
|
||||
|
||||
func (s *WatcherTestSuite) TestConsumetWithFilter() {
|
||||
producer, err := watcher.RegisterProducer(s.ctx, "test-producer")
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(producer)
|
||||
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
s.ctx, "test-consumer",
|
||||
watcher.WithEntityTypeFilter(common.ControllerEntityType),
|
||||
watcher.WithOperationTypeFilter(common.UpdateOperation))
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(consumer)
|
||||
|
||||
payload := common.ChangePayload{
|
||||
EntityType: common.ControllerEntityType,
|
||||
Operation: common.UpdateOperation,
|
||||
Payload: "test",
|
||||
}
|
||||
err = producer.Notify(payload)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case receivedPayload := <-consumer.Watch():
|
||||
s.Require().Equal(payload, receivedPayload)
|
||||
case <-time.After(1 * time.Second):
|
||||
s.T().Fatal("expected payload not received")
|
||||
}
|
||||
|
||||
payload = common.ChangePayload{
|
||||
EntityType: common.ControllerEntityType,
|
||||
Operation: common.CreateOperation,
|
||||
Payload: "test",
|
||||
}
|
||||
err = producer.Notify(payload)
|
||||
s.Require().NoError(err)
|
||||
|
||||
select {
|
||||
case <-consumer.Watch():
|
||||
s.T().Fatal("unexpected payload received")
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func maybeInitController(db common.Store) error {
|
||||
if _, err := db.ControllerInfo(); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := db.InitController(); err != nil {
|
||||
return errors.Wrap(err, "initializing controller")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestWatcherTestSuite(t *testing.T) {
|
||||
// Watcher tests
|
||||
watcherSuite := &WatcherTestSuite{
|
||||
ctx: context.TODO(),
|
||||
}
|
||||
suite.Run(t, watcherSuite)
|
||||
|
||||
ctx := context.Background()
|
||||
watcher.InitWatcher(ctx)
|
||||
|
||||
store, err := database.NewDatabase(ctx, garmTesting.GetTestSqliteDBConfig(t))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db connection: %s", err)
|
||||
}
|
||||
|
||||
err = maybeInitController(store)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init controller: %s", err)
|
||||
}
|
||||
|
||||
adminCtx := garmTesting.ImpersonateAdminContext(ctx, store, t)
|
||||
watcherStoreSuite := &WatcherStoreTestSuite{
|
||||
ctx: adminCtx,
|
||||
store: store,
|
||||
}
|
||||
suite.Run(t, watcherStoreSuite)
|
||||
}
|
||||
52
internal/testing/mock_watcher.go
Normal file
52
internal/testing/mock_watcher.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
)
|
||||
|
||||
type MockWatcher struct{}
|
||||
|
||||
func (w *MockWatcher) RegisterProducer(_ context.Context, _ string) (common.Producer, error) {
|
||||
return &MockProducer{}, nil
|
||||
}
|
||||
|
||||
func (w *MockWatcher) RegisterConsumer(_ context.Context, _ string, _ ...common.PayloadFilterFunc) (common.Consumer, error) {
|
||||
return &MockConsumer{}, nil
|
||||
}
|
||||
|
||||
func (w *MockWatcher) Close() {
|
||||
}
|
||||
|
||||
type MockProducer struct{}
|
||||
|
||||
func (p *MockProducer) Notify(_ common.ChangePayload) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *MockProducer) IsClosed() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *MockProducer) Close() {
|
||||
}
|
||||
|
||||
type MockConsumer struct{}
|
||||
|
||||
func (c *MockConsumer) Watch() <-chan common.ChangePayload {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockConsumer) SetFilters(_ ...common.PayloadFilterFunc) {
|
||||
}
|
||||
|
||||
func (c *MockConsumer) Close() {
|
||||
}
|
||||
|
||||
func (c *MockConsumer) IsClosed() bool {
|
||||
return false
|
||||
}
|
||||
|
|
@ -122,7 +122,7 @@ func CreateTestGithubCredentials(ctx context.Context, credsName string, db commo
|
|||
}
|
||||
newCreds, err := db.CreateGithubCredentials(ctx, newCredsParams)
|
||||
if err != nil {
|
||||
s.Fatalf("failed to create database object (new-creds): %v", err)
|
||||
s.Fatalf("failed to create database object (%s): %v", credsName, err)
|
||||
}
|
||||
return newCreds
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ type urls struct {
|
|||
}
|
||||
|
||||
func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", params.GithubEntityTypeRepository))
|
||||
ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType))
|
||||
ghc, err := garmUtil.GithubClient(ctx, entity, cfgInternal.GithubCredentialsDetails)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting github client")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/database"
|
||||
dbCommon "github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
garmTesting "github.com/cloudbase/garm/internal/testing"
|
||||
"github.com/cloudbase/garm/params"
|
||||
"github.com/cloudbase/garm/runner/common"
|
||||
|
|
@ -51,6 +52,10 @@ type RepoTestFixtures struct {
|
|||
PoolMgrCtrlMock *runnerMocks.PoolManagerController
|
||||
}
|
||||
|
||||
func init() {
|
||||
watcher.SetWatcher(&garmTesting.MockWatcher{})
|
||||
}
|
||||
|
||||
type RepoTestSuite struct {
|
||||
suite.Suite
|
||||
Fixtures *RepoTestFixtures
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue