Slight refactor; add creds cache worker

* Split the main function into a couple of more functions
* Add credentials, entity, pool and scaleset cache
* add credentials worker that updates the cache

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2025-05-05 18:21:57 +00:00
parent 3b3095c546
commit 1d093cc336
8 changed files with 554 additions and 101 deletions

View file

@ -51,6 +51,7 @@ import (
garmUtil "github.com/cloudbase/garm/util"
"github.com/cloudbase/garm/util/appdefaults"
"github.com/cloudbase/garm/websocket"
"github.com/cloudbase/garm/workers/credentials"
"github.com/cloudbase/garm/workers/entity"
"github.com/cloudbase/garm/workers/provider"
)
@ -180,7 +181,127 @@ func maybeUpdateURLsFromConfig(cfg config.Config, store common.Store) error {
return nil
}
//gocyclo:ignore
func configureRouter(ctx context.Context, cfg config.Config, db common.Store, hub *websocket.Hub, runner *runner.Runner) (http.Handler, error) {
authenticator := auth.NewAuthenticator(cfg.JWTAuth, db)
controller, err := controllers.NewAPIController(runner, authenticator, hub)
if err != nil {
return nil, fmt.Errorf("creating controller: %w", err)
}
instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth)
if err != nil {
return nil, fmt.Errorf("creating instance middleware: %w", err)
}
jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth)
if err != nil {
return nil, fmt.Errorf("creating jwt middleware: %w", err)
}
initMiddleware, err := auth.NewInitRequiredMiddleware(db)
if err != nil {
return nil, fmt.Errorf("creating init required middleware: %w", err)
}
urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db)
if err != nil {
return nil, fmt.Errorf("creating urls required middleware: %w", err)
}
metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth)
if err != nil {
return nil, fmt.Errorf("creating metrics middleware: %w", err)
}
router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement)
// start the metrics collector
if cfg.Metrics.Enable {
slog.InfoContext(ctx, "setting up metric routes")
router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware)
slog.InfoContext(ctx, "register metrics")
if err := metrics.RegisterMetrics(); err != nil {
return nil, fmt.Errorf("registering metrics: %w", err)
}
slog.InfoContext(ctx, "start metrics collection")
runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration())
}
if cfg.Default.DebugServer {
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
slog.InfoContext(ctx, "setting up debug routes")
router = routers.WithDebugServer(router)
}
corsMw := mux.CORSMethodMiddleware(router)
router.Use(corsMw)
allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins)
methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"})
headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"})
handler := handlers.CORS(methodsOk, headersOk, allowedOrigins)(router)
return handler, nil
}
func startWorkers(ctx context.Context, cfg config.Config, db common.Store, controllerID string) (func() error, error) {
credsWorker, err := credentials.NewWorker(ctx, db)
if err != nil {
return nil, fmt.Errorf("failed to create credentials worker: %+v", err)
}
if err := credsWorker.Start(); err != nil {
return nil, fmt.Errorf("failed to start credentials worker: %+v", err)
}
providers, err := providers.LoadProvidersFromConfig(ctx, cfg, controllerID)
if err != nil {
return nil, fmt.Errorf("loading providers: %+v", err)
}
entityController, err := entity.NewController(ctx, db, providers)
if err != nil {
return nil, fmt.Errorf("failed to create entity controller: %+v", err)
}
if err := entityController.Start(); err != nil {
return nil, fmt.Errorf("failed to start entity controller: %+v", err)
}
instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret)
if err != nil {
return nil, fmt.Errorf("failed to create instance token getter: %+v", err)
}
providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter)
if err != nil {
return nil, fmt.Errorf("failed to create provider worker: %+v", err)
}
if err := providerWorker.Start(); err != nil {
return nil, fmt.Errorf("failed to start provider worker: %+v", err)
}
return func() error {
slog.InfoContext(ctx, "shutting down credentials worker")
if err := credsWorker.Stop(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop credentials worker")
}
slog.InfoContext(ctx, "shutting down entity controller")
if err := entityController.Stop(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller")
}
slog.InfoContext(ctx, "shutting down provider worker")
if err := providerWorker.Stop(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker")
}
return nil
}, nil
}
func main() {
flag.Parse()
if *version {
@ -192,7 +313,6 @@ func main() {
watcher.InitWatcher(ctx)
ctx = auth.GetAdminContext(ctx)
cfg, err := config.NewConfig(*conf)
if err != nil {
log.Fatalf("Fetching config: %+v", err) //nolint:gocritic
@ -237,30 +357,9 @@ func main() {
log.Fatal(err)
}
providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String())
stopWorkersFn, err := startWorkers(ctx, *cfg, db, controllerInfo.ControllerID.String())
if err != nil {
log.Fatalf("loading providers: %+v", err)
}
entityController, err := entity.NewController(ctx, db, providers)
if err != nil {
log.Fatalf("failed to create entity controller: %+v", err)
}
if err := entityController.Start(); err != nil {
log.Fatalf("failed to start entity controller: %+v", err)
}
instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret)
if err != nil {
log.Fatalf("failed to create instance token getter: %+v", err)
}
providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter)
if err != nil {
log.Fatalf("failed to create provider worker: %+v", err)
}
if err := providerWorker.Start(); err != nil {
log.Fatalf("failed to start provider worker: %+v", err)
log.Fatalf("failed to start workers: %+v", err)
}
runner, err := runner.NewRunner(ctx, *cfg, db)
@ -273,73 +372,17 @@ func main() {
log.Fatal(err)
}
authenticator := auth.NewAuthenticator(cfg.JWTAuth, db)
controller, err := controllers.NewAPIController(runner, authenticator, hub)
handler, err := configureRouter(ctx, *cfg, db, hub, runner)
if err != nil {
log.Fatalf("failed to create controller: %+v", err)
log.Fatalf("failed to configure router: %+v", err)
}
instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth)
if err != nil {
log.Fatal(err)
}
jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth)
if err != nil {
log.Fatal(err)
}
initMiddleware, err := auth.NewInitRequiredMiddleware(db)
if err != nil {
log.Fatal(err)
}
urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db)
if err != nil {
log.Fatal(err)
}
metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth)
if err != nil {
log.Fatal(err)
}
router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement)
// start the metrics collector
if cfg.Metrics.Enable {
slog.InfoContext(ctx, "setting up metric routes")
router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware)
slog.InfoContext(ctx, "register metrics")
if err := metrics.RegisterMetrics(); err != nil {
log.Fatal(err)
}
slog.InfoContext(ctx, "start metrics collection")
runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration())
}
if cfg.Default.DebugServer {
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
slog.InfoContext(ctx, "setting up debug routes")
router = routers.WithDebugServer(router)
}
corsMw := mux.CORSMethodMiddleware(router)
router.Use(corsMw)
allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins)
methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"})
headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"})
// nolint:golangci-lint,gosec
// G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server
srv := &http.Server{
Addr: cfg.APIServer.BindAddress(),
// Pass our instance of gorilla/mux in.
Handler: handlers.CORS(methodsOk, headersOk, allowedOrigins)(router),
Handler: handler,
}
listener, err := net.Listen("tcp", srv.Addr)
@ -361,22 +404,16 @@ func main() {
<-ctx.Done()
slog.InfoContext(ctx, "shutting down entity controller")
if err := entityController.Stop(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller")
}
slog.InfoContext(ctx, "shutting down provider worker")
if err := providerWorker.Stop(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker")
}
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
}
if err := stopWorkersFn(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop workers")
}
slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
if err := runner.Wait(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")