diff --git a/cmd/garm/main.go b/cmd/garm/main.go index 15ba7069..958ea001 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -181,127 +181,7 @@ func maybeUpdateURLsFromConfig(cfg config.Config, store common.Store) error { return nil } -func configureRouter(ctx context.Context, cfg config.Config, db common.Store, hub *websocket.Hub, runner *runner.Runner) (http.Handler, error) { - authenticator := auth.NewAuthenticator(cfg.JWTAuth, db) - controller, err := controllers.NewAPIController(runner, authenticator, hub) - if err != nil { - return nil, fmt.Errorf("creating controller: %w", err) - } - - instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth) - if err != nil { - return nil, fmt.Errorf("creating instance middleware: %w", err) - } - - jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth) - if err != nil { - return nil, fmt.Errorf("creating jwt middleware: %w", err) - } - - initMiddleware, err := auth.NewInitRequiredMiddleware(db) - if err != nil { - return nil, fmt.Errorf("creating init required middleware: %w", err) - } - - urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db) - if err != nil { - return nil, fmt.Errorf("creating urls required middleware: %w", err) - } - - metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth) - if err != nil { - return nil, fmt.Errorf("creating metrics middleware: %w", err) - } - - router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement) - - // start the metrics collector - if cfg.Metrics.Enable { - slog.InfoContext(ctx, "setting up metric routes") - router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware) - - slog.InfoContext(ctx, "register metrics") - if err := metrics.RegisterMetrics(); err != nil { - return nil, fmt.Errorf("registering metrics: %w", err) - } - - slog.InfoContext(ctx, "start metrics collection") - runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration()) - } - - if cfg.Default.DebugServer { - runtime.SetBlockProfileRate(1) - runtime.SetMutexProfileFraction(1) - slog.InfoContext(ctx, "setting up debug routes") - router = routers.WithDebugServer(router) - } - - corsMw := mux.CORSMethodMiddleware(router) - router.Use(corsMw) - - allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins) - methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"}) - headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}) - - handler := handlers.CORS(methodsOk, headersOk, allowedOrigins)(router) - return handler, nil -} - -func startWorkers(ctx context.Context, cfg config.Config, db common.Store, controllerID string) (func() error, error) { - credsWorker, err := credentials.NewWorker(ctx, db) - if err != nil { - return nil, fmt.Errorf("failed to create credentials worker: %+v", err) - } - - if err := credsWorker.Start(); err != nil { - return nil, fmt.Errorf("failed to start credentials worker: %+v", err) - } - - providers, err := providers.LoadProvidersFromConfig(ctx, cfg, controllerID) - if err != nil { - return nil, fmt.Errorf("loading providers: %+v", err) - } - - entityController, err := entity.NewController(ctx, db, providers) - if err != nil { - return nil, fmt.Errorf("failed to create entity controller: %+v", err) - } - if err := entityController.Start(); err != nil { - return nil, fmt.Errorf("failed to start entity controller: %+v", err) - } - - instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret) - if err != nil { - return nil, fmt.Errorf("failed to create instance token getter: %+v", err) - } - - providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter) - if err != nil { - return nil, fmt.Errorf("failed to create provider worker: %+v", err) - } - if err := providerWorker.Start(); err != nil { - return nil, fmt.Errorf("failed to start provider worker: %+v", err) - } - - return func() error { - slog.InfoContext(ctx, "shutting down credentials worker") - if err := credsWorker.Stop(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop credentials worker") - } - - slog.InfoContext(ctx, "shutting down entity controller") - if err := entityController.Stop(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller") - } - - slog.InfoContext(ctx, "shutting down provider worker") - if err := providerWorker.Stop(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker") - } - return nil - }, nil -} - +//gocyclo:ignore func main() { flag.Parse() if *version { @@ -313,6 +193,7 @@ func main() { watcher.InitWatcher(ctx) ctx = auth.GetAdminContext(ctx) + cfg, err := config.NewConfig(*conf) if err != nil { log.Fatalf("Fetching config: %+v", err) //nolint:gocritic @@ -357,9 +238,38 @@ func main() { log.Fatal(err) } - stopWorkersFn, err := startWorkers(ctx, *cfg, db, controllerInfo.ControllerID.String()) + credsWorker, err := credentials.NewWorker(ctx, db) if err != nil { - log.Fatalf("failed to start workers: %+v", err) + log.Fatalf("failed to create credentials worker: %+v", err) + } + if err := credsWorker.Start(); err != nil { + log.Fatalf("failed to start credentials worker: %+v", err) + } + + providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String()) + if err != nil { + log.Fatalf("loading providers: %+v", err) + } + + entityController, err := entity.NewController(ctx, db, providers) + if err != nil { + log.Fatalf("failed to create entity controller: %+v", err) + } + if err := entityController.Start(); err != nil { + log.Fatalf("failed to start entity controller: %+v", err) + } + + instanceTokenGetter, err := auth.NewInstanceTokenGetter(cfg.JWTAuth.Secret) + if err != nil { + log.Fatalf("failed to create instance token getter: %+v", err) + } + + providerWorker, err := provider.NewWorker(ctx, db, providers, instanceTokenGetter) + if err != nil { + log.Fatalf("failed to create provider worker: %+v", err) + } + if err := providerWorker.Start(); err != nil { + log.Fatalf("failed to start provider worker: %+v", err) } runner, err := runner.NewRunner(ctx, *cfg, db) @@ -372,17 +282,73 @@ func main() { log.Fatal(err) } - handler, err := configureRouter(ctx, *cfg, db, hub, runner) + authenticator := auth.NewAuthenticator(cfg.JWTAuth, db) + controller, err := controllers.NewAPIController(runner, authenticator, hub) if err != nil { - log.Fatalf("failed to configure router: %+v", err) + log.Fatalf("failed to create controller: %+v", err) } + instanceMiddleware, err := auth.NewInstanceMiddleware(db, cfg.JWTAuth) + if err != nil { + log.Fatal(err) + } + + jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth) + if err != nil { + log.Fatal(err) + } + + initMiddleware, err := auth.NewInitRequiredMiddleware(db) + if err != nil { + log.Fatal(err) + } + + urlsRequiredMiddleware, err := auth.NewUrlsRequiredMiddleware(db) + if err != nil { + log.Fatal(err) + } + + metricsMiddleware, err := auth.NewMetricsMiddleware(cfg.JWTAuth) + if err != nil { + log.Fatal(err) + } + + router := routers.NewAPIRouter(controller, jwtMiddleware, initMiddleware, urlsRequiredMiddleware, instanceMiddleware, cfg.Default.EnableWebhookManagement) + + // start the metrics collector + if cfg.Metrics.Enable { + slog.InfoContext(ctx, "setting up metric routes") + router = routers.WithMetricsRouter(router, cfg.Metrics.DisableAuth, metricsMiddleware) + + slog.InfoContext(ctx, "register metrics") + if err := metrics.RegisterMetrics(); err != nil { + log.Fatal(err) + } + + slog.InfoContext(ctx, "start metrics collection") + runnerMetrics.CollectObjectMetric(ctx, runner, cfg.Metrics.Duration()) + } + + if cfg.Default.DebugServer { + runtime.SetBlockProfileRate(1) + runtime.SetMutexProfileFraction(1) + slog.InfoContext(ctx, "setting up debug routes") + router = routers.WithDebugServer(router) + } + + corsMw := mux.CORSMethodMiddleware(router) + router.Use(corsMw) + + allowedOrigins := handlers.AllowedOrigins(cfg.APIServer.CORSOrigins) + methodsOk := handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "OPTIONS", "DELETE"}) + headersOk := handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}) + // nolint:golangci-lint,gosec // G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server srv := &http.Server{ Addr: cfg.APIServer.BindAddress(), // Pass our instance of gorilla/mux in. - Handler: handler, + Handler: handlers.CORS(methodsOk, headersOk, allowedOrigins)(router), } listener, err := net.Listen("tcp", srv.Addr) @@ -404,16 +370,26 @@ func main() { <-ctx.Done() + if err := credsWorker.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop credentials worker") + } + + slog.InfoContext(ctx, "shutting down entity controller") + if err := entityController.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop entity controller") + } + + slog.InfoContext(ctx, "shutting down provider worker") + if err := providerWorker.Stop(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop provider worker") + } + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 60*time.Second) defer shutdownCancel() if err := srv.Shutdown(shutdownCtx); err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed") } - if err := stopWorkersFn(); err != nil { - slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to stop workers") - } - slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop") if err := runner.Wait(); err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")