Slight refactor; add creds cache worker
* Split the main function into a few smaller functions
* Add a credentials, entity, pool and scale set cache
* Add a credentials worker that keeps the cache updated

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
parent 3b3095c546
commit 1d093cc336

8 changed files with 554 additions and 101 deletions
73  cache/credentials_cache.go  vendored  Normal file

@@ -0,0 +1,73 @@
package cache

import (
    "sync"

    "github.com/cloudbase/garm/params"
)

var credentialsCache *GithubCredentials

func init() {
    ghCredentialsCache := &GithubCredentials{
        cache: make(map[uint]params.GithubCredentials),
    }
    credentialsCache = ghCredentialsCache
}

type GithubCredentials struct {
    mux sync.Mutex

    cache map[uint]params.GithubCredentials
}

func (g *GithubCredentials) SetCredentials(credentials params.GithubCredentials) {
    g.mux.Lock()
    defer g.mux.Unlock()

    g.cache[credentials.ID] = credentials
}

func (g *GithubCredentials) GetCredentials(id uint) (params.GithubCredentials, bool) {
    g.mux.Lock()
    defer g.mux.Unlock()

    if creds, ok := g.cache[id]; ok {
        return creds, true
    }
    return params.GithubCredentials{}, false
}

func (g *GithubCredentials) DeleteCredentials(id uint) {
    g.mux.Lock()
    defer g.mux.Unlock()

    delete(g.cache, id)
}

func (g *GithubCredentials) GetAllCredentials() []params.GithubCredentials {
    g.mux.Lock()
    defer g.mux.Unlock()

    creds := make([]params.GithubCredentials, 0, len(g.cache))
    for _, cred := range g.cache {
        creds = append(creds, cred)
    }
    return creds
}

func SetGithubCredentials(credentials params.GithubCredentials) {
    credentialsCache.SetCredentials(credentials)
}

func GetGithubCredentials(id uint) (params.GithubCredentials, bool) {
    return credentialsCache.GetCredentials(id)
}

func DeleteGithubCredentials(id uint) {
    credentialsCache.DeleteCredentials(id)
}

func GetAllGithubCredentials() []params.GithubCredentials {
    return credentialsCache.GetAllCredentials()
}
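The package-level helpers above wrap a single process-wide cache, so callers never construct a GithubCredentials cache themselves. A minimal usage sketch (illustrative only, not part of the commit; it assumes params.GithubCredentials exposes ID and Name fields, and the values below are made up):

package main

import (
    "fmt"

    "github.com/cloudbase/garm/cache"
    "github.com/cloudbase/garm/params"
)

func main() {
    // Store a credential in the process-wide cache (illustrative values).
    cache.SetGithubCredentials(params.GithubCredentials{ID: 1, Name: "example-creds"})

    // Look it up again; the boolean reports whether it was found.
    if creds, ok := cache.GetGithubCredentials(1); ok {
        fmt.Println("found credentials:", creds.Name)
    }

    // Drop it once the credential is deleted upstream.
    cache.DeleteGithubCredentials(1)
}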
189  cache/entity_cache.go  vendored  Normal file

@@ -0,0 +1,189 @@
package cache

import (
    "sync"

    "github.com/cloudbase/garm/params"
)

var entityCache *EntityCache

func init() {
    ghEntityCache := &EntityCache{
        entities: make(map[string]EntityItem),
    }
    entityCache = ghEntityCache
}

type EntityItem struct {
    Entity    params.GithubEntity
    Pools     map[string]params.Pool
    ScaleSets map[uint]params.ScaleSet
}

type EntityCache struct {
    mux sync.Mutex
    // entity IDs are UUID4s. It is highly unlikely they will collide (🤞).
    entities map[string]EntityItem
}

func (e *EntityCache) GetEntity(entity params.GithubEntity) (EntityItem, bool) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entity.ID]; ok {
        // Updating specific credential details will not update entity cache which
        // uses those credentials.
        // Entity credentials in the cache are only updated if you swap the creds
        // on the entity. We get the updated credentials from the credentials cache.
        creds, ok := GetGithubCredentials(cache.Entity.Credentials.ID)
        if ok {
            cache.Entity.Credentials = creds
        }
        return cache, true
    }
    return EntityItem{}, false
}

func (e *EntityCache) SetEntity(entity params.GithubEntity) {
    e.mux.Lock()
    defer e.mux.Unlock()

    e.entities[entity.ID] = EntityItem{
        Entity: entity,
    }
}

func (e *EntityCache) ReplaceEntityPools(entityID string, pools map[string]params.Pool) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        cache.Pools = pools
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) ReplaceEntityScaleSets(entityID string, scaleSets map[uint]params.ScaleSet) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        cache.ScaleSets = scaleSets
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) DeleteEntity(entityID string) {
    e.mux.Lock()
    defer e.mux.Unlock()
    delete(e.entities, entityID)
}

func (e *EntityCache) SetEntityPool(entityID string, pool params.Pool) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        cache.Pools[pool.ID] = pool
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) SetEntityScaleSet(entityID string, scaleSet params.ScaleSet) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        cache.ScaleSets[scaleSet.ID] = scaleSet
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) DeleteEntityPool(entityID string, poolID string) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        delete(cache.Pools, poolID)
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) DeleteEntityScaleSet(entityID string, scaleSetID uint) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        delete(cache.ScaleSets, scaleSetID)
        e.entities[entityID] = cache
    }
}

func (e *EntityCache) GetEntityPool(entityID string, poolID string) (params.Pool, bool) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        if pool, ok := cache.Pools[poolID]; ok {
            return pool, true
        }
    }
    return params.Pool{}, false
}

func (e *EntityCache) GetEntityScaleSet(entityID string, scaleSetID uint) (params.ScaleSet, bool) {
    e.mux.Lock()
    defer e.mux.Unlock()

    if cache, ok := e.entities[entityID]; ok {
        if scaleSet, ok := cache.ScaleSets[scaleSetID]; ok {
            return scaleSet, true
        }
    }
    return params.ScaleSet{}, false
}

func GetEntity(entity params.GithubEntity) (EntityItem, bool) {
    return entityCache.GetEntity(entity)
}

func SetEntity(entity params.GithubEntity) {
    entityCache.SetEntity(entity)
}

func ReplaceEntityPools(entityID string, pools map[string]params.Pool) {
    entityCache.ReplaceEntityPools(entityID, pools)
}

func ReplaceEntityScaleSets(entityID string, scaleSets map[uint]params.ScaleSet) {
    entityCache.ReplaceEntityScaleSets(entityID, scaleSets)
}

func DeleteEntity(entityID string) {
    entityCache.DeleteEntity(entityID)
}

func SetEntityPool(entityID string, pool params.Pool) {
    entityCache.SetEntityPool(entityID, pool)
}

func SetEntityScaleSet(entityID string, scaleSet params.ScaleSet) {
    entityCache.SetEntityScaleSet(entityID, scaleSet)
}

func DeleteEntityPool(entityID string, poolID string) {
    entityCache.DeleteEntityPool(entityID, poolID)
}

func DeleteEntityScaleSet(entityID string, scaleSetID uint) {
    entityCache.DeleteEntityScaleSet(entityID, scaleSetID)
}

func GetEntityPool(entityID string, poolID string) (params.Pool, bool) {
    return entityCache.GetEntityPool(entityID, poolID)
}

func GetEntityScaleSet(entityID string, scaleSetID uint) (params.ScaleSet, bool) {
    return entityCache.GetEntityScaleSet(entityID, scaleSetID)
}
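As written, SetEntity stores only the entity itself and leaves the Pools and ScaleSets maps nil, so callers appear to be expected to seed those maps via ReplaceEntityPools / ReplaceEntityScaleSets before using the per-item setters. A minimal usage sketch (illustrative only, not part of the commit; the IDs are invented):

package main

import (
    "fmt"

    "github.com/cloudbase/garm/cache"
    "github.com/cloudbase/garm/params"
)

func main() {
    // Entity IDs are UUID4 strings; this one is made up for the example.
    entity := params.GithubEntity{ID: "0b7c1c5a-0000-4000-8000-000000000000"}
    cache.SetEntity(entity)

    // Seed the pool map first (SetEntity leaves it nil), then update items.
    cache.ReplaceEntityPools(entity.ID, map[string]params.Pool{})
    cache.SetEntityPool(entity.ID, params.Pool{ID: "pool-uuid"})

    if pool, ok := cache.GetEntityPool(entity.ID, "pool-uuid"); ok {
        fmt.Println("cached pool:", pool.ID)
    }
}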
0    cache/cache.go → cache/tools_cache.go  vendored
225  cmd/garm/main.go
@@ -51,6 +51,7 @@ import (
     garmUtil "github.com/cloudbase/garm/util"
     "github.com/cloudbase/garm/util/appdefaults"
     "github.com/cloudbase/garm/websocket"
+    "github.com/cloudbase/garm/workers/credentials"
     "github.com/cloudbase/garm/workers/entity"
     "github.com/cloudbase/garm/workers/provider"
 )
@@ -180,7 +181,127 @@ func maybeUpdateURLsFromConfig(cfg config.Config, store common.Store) error {
     return nil
 }
 
+//gocyclo:ignore
+func configureRouter(ctx context.Context, cfg config.Config, db common.Store, hub *websocket.Hub, runner *runner.Runner) (http.Handler, error) {
+    authenticator := auth.NewAuthenticator(cfg.JWTAuth, db)
+    controller, err := controllers.NewAPIController(runner, authenticator, hub)
+    if err != nil {
+        return nil, fmt.Errorf("creating controller: %w", err)
+    }
+
+    instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth)
+    if err != nil {
+        return nil, fmt.Errorf("creating instance middleware: %w", err)
+    }
+
+    jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth)
+    if err != nil {
+        return nil, fmt.Errorf("creating jwt middleware: %w", err)
+    }
+
+    initMiddleware, err := auth.NewInitRequiredMiddleware(db)
+    if err != nil {
+        return nil, fmt.Errorf("creating init required middleware: %w", err)
+    }
+
+    urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db)
+    if err != nil {
+        return nil, fmt.Errorf("creating urls required middleware: %w", err)
+    }
+
+    metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth)
+    if err != nil {
+        return nil, fmt.Errorf("creating metrics middleware: %w", err)
+    }
+
+    router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement)
+
+    // start the metrics collector
+    if cfg.Metrics.Enable {
+        slog.InfoContext(ctx, "setting up metric routes")
+        router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware)
+
+        slog.InfoContext(ctx, "register metrics")
+        if err := metrics.RegisterMetrics(); err != nil {
+            return nil, fmt.Errorf("registering metrics: %w", err)
+        }
+
+        slog.InfoContext(ctx, "start metrics collection")
+        runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration())
+    }
+
+    if cfg.Default.DebugServer {
+        runtime.SetBlockProfileRate(1)
+        runtime.SetMutexProfileFraction(1)
+        slog.InfoContext(ctx, "setting up debug routes")
+        router = routers.WithDebugServer(router)
+    }
+
+    corsMw := mux.CORSMethodMiddleware(router)
+    router.Use(corsMw)
+
+    allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins)
+    methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"})
+    headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"})
+
+    handler := handlers.CORS(methodsOk, headersOk, allowedOrigins)(router)
+    return handler, nil
+}
+
+func startWorkers(ctx context.Context, cfg config.Config, db common.Store, controllerID string) (func() error, error) {
+    credsWorker, err := credentials.NewWorker(ctx, db)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create credentials worker: %+v", err)
+    }
+
+    if err := credsWorker.Start(); err != nil {
+        return nil, fmt.Errorf("failed to start credentials worker: %+v", err)
+    }
+
+    providers, err := providers.LoadProvidersFromConfig(ctx, cfg, controllerID)
+    if err != nil {
+        return nil, fmt.Errorf("loading providers: %+v", err)
+    }
+
+    entityController, err := entity.NewController(ctx, db, providers)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create entity controller: %+v", err)
+    }
+    if err := entityController.Start(); err != nil {
+        return nil, fmt.Errorf("failed to start entity controller: %+v", err)
+    }
+
+    instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create instance token getter: %+v", err)
+    }
+
+    providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create provider worker: %+v", err)
+    }
+    if err := providerWorker.Start(); err != nil {
+        return nil, fmt.Errorf("failed to start provider worker: %+v", err)
+    }
+
+    return func() error {
+        slog.InfoContext(ctx, "shutting down credentials worker")
+        if err := credsWorker.Stop(); err != nil {
+            slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop credentials worker")
+        }
+
+        slog.InfoContext(ctx, "shutting down entity controller")
+        if err := entityController.Stop(); err != nil {
+            slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller")
+        }
+
+        slog.InfoContext(ctx, "shutting down provider worker")
+        if err := providerWorker.Stop(); err != nil {
+            slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker")
+        }
+        return nil
+    }, nil
+}
+
 func main() {
     flag.Parse()
     if *version {
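The new startWorkers function is the core of the refactor: every long-running worker is started in one place, and a single stop closure is handed back to main for shutdown. A toy sketch of that pattern in isolation (names are invented, not part of the commit):

package main

import (
    "context"
    "fmt"
)

// startAll stands in for startWorkers: start everything, return one closure
// that stops everything in order.
func startAll(ctx context.Context) (func() error, error) {
    fmt.Println("workers started")
    return func() error {
        fmt.Println("workers stopped")
        return nil
    }, nil
}

func main() {
    stop, err := startAll(context.Background())
    if err != nil {
        panic(err)
    }
    // ... serve the API, wait for a signal ...
    if err := stop(); err != nil {
        fmt.Println("failed to stop workers:", err)
    }
}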
@@ -192,7 +313,6 @@ func main() {
     watcher.InitWatcher(ctx)
 
     ctx = auth.GetAdminContext(ctx)
-
     cfg, err := config.NewConfig(*conf)
     if err != nil {
         log.Fatalf("Fetching config: %+v", err) //nolint:gocritic
@@ -237,30 +357,9 @@
         log.Fatal(err)
     }
 
-    providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String())
+    stopWorkersFn, err := startWorkers(ctx, *cfg, db, controllerInfo.ControllerID.String())
     if err != nil {
-        log.Fatalf("loading providers: %+v", err)
-    }
-
-    entityController, err := entity.NewController(ctx, db, providers)
-    if err != nil {
-        log.Fatalf("failed to create entity controller: %+v", err)
-    }
-    if err := entityController.Start(); err != nil {
-        log.Fatalf("failed to start entity controller: %+v", err)
-    }
-
-    instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret)
-    if err != nil {
-        log.Fatalf("failed to create instance token getter: %+v", err)
-    }
-
-    providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter)
-    if err != nil {
-        log.Fatalf("failed to create provider worker: %+v", err)
-    }
-    if err := providerWorker.Start(); err != nil {
-        log.Fatalf("failed to start provider worker: %+v", err)
+        log.Fatalf("failed to start workers: %+v", err)
     }
 
     runner, err := runner.NewRunner(ctx, *cfg, db)
@@ -273,73 +372,17 @@
         log.Fatal(err)
     }
 
-    authenticator := auth.NewAuthenticator(cfg.JWTAuth, db)
-    controller, err := controllers.NewAPIController(runner, authenticator, hub)
+    handler, err := configureRouter(ctx, *cfg, db, hub, runner)
     if err != nil {
-        log.Fatalf("failed to create controller: %+v", err)
+        log.Fatalf("failed to configure router: %+v", err)
     }
 
-    instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    initMiddleware, err := auth.NewInitRequiredMiddleware(db)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement)
-
-    // start the metrics collector
-    if cfg.Metrics.Enable {
-        slog.InfoContext(ctx, "setting up metric routes")
-        router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware)
-
-        slog.InfoContext(ctx, "register metrics")
-        if err := metrics.RegisterMetrics(); err != nil {
-            log.Fatal(err)
-        }
-
-        slog.InfoContext(ctx, "start metrics collection")
-        runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration())
-    }
-
-    if cfg.Default.DebugServer {
-        runtime.SetBlockProfileRate(1)
-        runtime.SetMutexProfileFraction(1)
-        slog.InfoContext(ctx, "setting up debug routes")
-        router = routers.WithDebugServer(router)
-    }
-
-    corsMw := mux.CORSMethodMiddleware(router)
-    router.Use(corsMw)
-
-    allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins)
-    methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"})
-    headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"})
-
     // nolint:golangci-lint,gosec
     // G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server
     srv := &http.Server{
         Addr: cfg.APIServer.BindAddress(),
         // Pass our instance of gorilla/mux in.
-        Handler: handlers.CORS(methodsOk, headersOk, allowedOrigins)(router),
+        Handler: handler,
     }
 
     listener, err := net.Listen("tcp", srv.Addr)
@@ -361,22 +404,16 @@
 
     <-ctx.Done()
 
-    slog.InfoContext(ctx, "shutting down entity controller")
-    if err := entityController.Stop(); err != nil {
-        slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller")
-    }
-
-    slog.InfoContext(ctx, "shutting down provider worker")
-    if err := providerWorker.Stop(); err != nil {
-        slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker")
-    }
-
     shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second)
     defer shutdownCancel()
     if err := srv.Shutdown(shutdownCtx); err != nil {
         slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
     }
 
+    if err := stopWorkersFn(); err != nil {
+        slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop workers")
+    }
+
     slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
     if err := runner.Wait(); err != nil {
         slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")
133  workers/credentials/credentials.go  Normal file

@@ -0,0 +1,133 @@
package credentials

import (
    "context"
    "fmt"
    "log/slog"
    "sync"

    "github.com/cloudbase/garm/cache"
    dbCommon "github.com/cloudbase/garm/database/common"
    "github.com/cloudbase/garm/database/watcher"
    "github.com/cloudbase/garm/params"
    garmUtil "github.com/cloudbase/garm/util"
)

func NewWorker(ctx context.Context, store dbCommon.Store) (*Worker, error) {
    consumerID := "credentials-worker"

    ctx = garmUtil.WithSlogContext(
        ctx,
        slog.Any("worker", consumerID))

    return &Worker{
        ctx:         ctx,
        consumerID:  consumerID,
        store:       store,
        running:     false,
        quit:        make(chan struct{}),
        credentials: make(map[uint]params.GithubCredentials),
    }, nil
}

// Worker is responsible for maintaining the credentials cache.
type Worker struct {
    consumerID string
    ctx        context.Context

    consumer dbCommon.Consumer
    store    dbCommon.Store

    credentials map[uint]params.GithubCredentials

    running bool
    quit    chan struct{}

    mux sync.Mutex
}

func (w *Worker) loadAllCredentials() error {
    creds, err := w.store.ListGithubCredentials(w.ctx)
    if err != nil {
        return err
    }

    for _, cred := range creds {
        w.credentials[cred.ID] = cred
        cache.SetGithubCredentials(cred)
    }

    return nil
}

func (w *Worker) Start() error {
    w.mux.Lock()
    defer w.mux.Unlock()

    if w.running {
        return nil
    }
    slog.DebugContext(w.ctx, "starting credentials worker")
    if err := w.loadAllCredentials(); err != nil {
        return fmt.Errorf("loading credentials: %w", err)
    }

    consumer, err := watcher.RegisterConsumer(
        w.ctx, w.consumerID,
        watcher.WithEntityTypeFilter(dbCommon.GithubCredentialsEntityType),
    )
    if err != nil {
        return fmt.Errorf("failed to create consumer for entity controller: %w", err)
    }
    w.consumer = consumer

    w.running = true
    go w.loop()
    return nil
}

func (w *Worker) Stop() error {
    w.mux.Lock()
    defer w.mux.Unlock()

    if !w.running {
        return nil
    }

    close(w.quit)
    w.running = false

    return nil
}

func (w *Worker) loop() {
    defer w.Stop()

    for {
        select {
        case <-w.quit:
            return
        case event, ok := <-w.consumer.Watch():
            if !ok {
                slog.ErrorContext(w.ctx, "consumer channel closed")
                return
            }
            creds, ok := event.Payload.(params.GithubCredentials)
            if !ok {
                slog.ErrorContext(w.ctx, "invalid payload for entity type", "entity_type", event.EntityType, "payload", event.Payload)
                continue
            }
            w.mux.Lock()
            switch event.Operation {
            case dbCommon.DeleteOperation:
                slog.DebugContext(w.ctx, "got delete operation")
                delete(w.credentials, creds.ID)
                cache.DeleteGithubCredentials(creds.ID)
            default:
                w.credentials[creds.ID] = creds
                cache.SetGithubCredentials(creds)
            }
            w.mux.Unlock()
        }
    }
}
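The worker seeds the cache from the database on Start, then keeps it in sync from watcher events. A sketch of how it is wired in (this mirrors what startWorkers does above; obtaining the dbCommon.Store and the surrounding package name are assumptions for illustration):

package demo

import (
    "context"
    "fmt"
    "log/slog"

    dbCommon "github.com/cloudbase/garm/database/common"
    "github.com/cloudbase/garm/workers/credentials"
)

// runCredentialsWorker shows the lifecycle only; a Store and a running
// watcher are assumed to exist already.
func runCredentialsWorker(ctx context.Context, store dbCommon.Store) error {
    w, err := credentials.NewWorker(ctx, store)
    if err != nil {
        return fmt.Errorf("creating credentials worker: %w", err)
    }
    // Start seeds the cache from the database, registers a watcher consumer
    // filtered on GithubCredentialsEntityType and spawns the event loop.
    if err := w.Start(); err != nil {
        return fmt.Errorf("starting credentials worker: %w", err)
    }

    <-ctx.Done()

    if err := w.Stop(); err != nil {
        slog.ErrorContext(ctx, "stopping credentials worker", "error", err)
    }
    return nil
}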
@@ -7,6 +7,7 @@ import (
     "sync"
 
     "github.com/cloudbase/garm/auth"
+    "github.com/cloudbase/garm/cache"
     dbCommon "github.com/cloudbase/garm/database/common"
     "github.com/cloudbase/garm/database/watcher"
     "github.com/cloudbase/garm/runner/common"

@@ -66,6 +67,9 @@ func (c *Controller) loadAllRepositories() error {
             return fmt.Errorf("starting worker: %w", err)
         }
         c.Entities[entity.ID] = worker
+        // take advantage of the fact that we're loading all entities
+        // and set the cache.
+        cache.SetEntity(entity)
     }
     return nil
 }

@@ -90,6 +94,9 @@ func (c *Controller) loadAllOrganizations() error {
             return fmt.Errorf("starting worker: %w", err)
         }
         c.Entities[entity.ID] = worker
+        // take advantage of the fact that we're loading all entities
+        // and set the cache.
+        cache.SetEntity(entity)
     }
     return nil
 }

@@ -114,6 +121,9 @@ func (c *Controller) loadAllEnterprises() error {
             return fmt.Errorf("starting worker: %w", err)
         }
         c.Entities[entity.ID] = worker
+        // take advantage of the fact that we're loading all entities
+        // and set the cache.
+        cache.SetEntity(entity)
     }
     return nil
 }

@@ -126,14 +136,14 @@ func (c *Controller) Start() error {
     }
     c.mux.Unlock()
 
-    if err := c.loadAllRepositories(); err != nil {
-        return fmt.Errorf("loading repositories: %w", err)
+    if err := c.loadAllEnterprises(); err != nil {
+        return fmt.Errorf("loading enterprises: %w", err)
     }
     if err := c.loadAllOrganizations(); err != nil {
         return fmt.Errorf("loading organizations: %w", err)
     }
-    if err := c.loadAllEnterprises(); err != nil {
-        return fmt.Errorf("loading enterprises: %w", err)
+    if err := c.loadAllRepositories(); err != nil {
+        return fmt.Errorf("loading repositories: %w", err)
     }
 
     consumer, err := watcher.RegisterConsumer(
@@ -3,6 +3,7 @@ package entity
 import (
     "log/slog"
 
+    "github.com/cloudbase/garm/cache"
     dbCommon "github.com/cloudbase/garm/database/common"
     "github.com/cloudbase/garm/params"
 )

@@ -95,4 +96,5 @@ func (c *Controller) handleWatcherDeleteOperation(entityGetter params.EntityGett
         return
     }
     delete(c.Entities, entity.ID)
+    cache.DeleteEntity(entity.ID)
 }
@@ -4,6 +4,7 @@ import (
     "fmt"
     "log/slog"
 
+    "github.com/cloudbase/garm/cache"
     dbCommon "github.com/cloudbase/garm/database/common"
     "github.com/cloudbase/garm/params"
     "github.com/cloudbase/garm/runner/common"

@@ -63,6 +64,7 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet, ghCli c
 
     if _, ok := c.ScaleSets[sSet.ID]; ok {
         slog.DebugContext(c.ctx, "scale set already exists in worker list", "scale_set_id", sSet.ID)
+        cache.SetEntityScaleSet(c.Entity.ID, sSet)
         return nil
     }
 

@@ -88,9 +90,9 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet, ghCli c
     }
     c.ScaleSets[sSet.ID] = &scaleSet{
         scaleSet: sSet,
-        // status: scaleSetStatus{},
-        worker: worker,
+        worker:   worker,
     }
+    cache.SetEntityScaleSet(c.Entity.ID, sSet)
     return nil
 }
 

@@ -109,6 +111,7 @@ func (c *Controller) handleScaleSetDeleteOperation(sSet params.ScaleSet) error {
         return fmt.Errorf("stopping scale set worker: %w", err)
     }
     delete(c.ScaleSets, sSet.ID)
+    cache.DeleteEntityScaleSet(c.Entity.ID, sSet.ID)
     return nil
 }
 

@@ -116,12 +119,16 @@ func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error {
     c.mux.Lock()
     defer c.mux.Unlock()
 
-    if _, ok := c.ScaleSets[sSet.ID]; !ok {
+    set, ok := c.ScaleSets[sSet.ID]
+    if !ok {
         // Some error may have occurred when the scale set was first created, so we
         // attempt to create it after the user updated the scale set, hopefully
         // fixing the reason for the failure.
         return c.handleScaleSetCreateOperation(sSet, c.ghCli)
     }
+    set.scaleSet = sSet
+    c.ScaleSets[sSet.ID] = set
+    cache.SetEntityScaleSet(c.Entity.ID, sSet)
     // We let the watcher in the scale set worker handle the update operation.
     return nil
 }

@@ -139,6 +146,7 @@ func (c *Controller) handleCredentialsEvent(event dbCommon.ChangePayload) {
     c.mux.Lock()
     defer c.mux.Unlock()
 
+    cache.SetGithubCredentials(credentials)
     if c.Entity.Credentials.ID != credentials.ID {
         // stale update event.
         return

@@ -177,6 +185,7 @@ func (c *Controller) handleEntityEvent(event dbCommon.ChangePayload) {
             }
         }
         c.Entity = entity
+        cache.SetEntity(c.Entity)
     default:
         slog.ErrorContext(c.ctx, "invalid operation type", "operation_type", event.Operation)
         return