diff --git a/README.md b/README.md index 18c3fc8e..35d525f6 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,13 @@ Garm enables you to create and automatically maintain pools of [self-hosted GitH The goal of ```garm``` is to be simple to set up, simple to configure and simple to use. It is a single binary that can run on any GNU/Linux machine without any other requirements other than the providers it creates the runners in. It is intended to be easy to deploy in any environment and can create runners in any system you can write a provider for. There is no complicated setup process and no extremely complex concepts to understant. Once set up, it's meant to stay out of your way. +Garm supports creating pools on either GitHub itself or on your own deployment of [GitHub Enterprise Server](https://docs.github.com/en/enterprise-server@3.5/admin/overview/about-github-enterprise-server). For instructions on how to use ```garm``` with GHE, see the [credentials](/doc/github_credentials.md) section of the documentation. + ## Installing ## Build from source -You need to have Go install, then run: +You need to have Go installed, then run: ```bash git clone https://github.com/cloudbase/garm diff --git a/apiserver/controllers/instances.go b/apiserver/controllers/instances.go index a8d6fe2f..7ace7fc6 100644 --- a/apiserver/controllers/instances.go +++ b/apiserver/controllers/instances.go @@ -202,3 +202,17 @@ func (a *APIController) InstanceStatusMessageHandler(w http.ResponseWriter, r *h w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) } + +func (a *APIController) InstanceGithubRegistrationTokenHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + token, err := a.r.GetInstanceGithubRegistrationToken(ctx) + if err != nil { + handleError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(token)) +} diff --git a/apiserver/routers/routers.go b/apiserver/routers/routers.go index 87d7d471..c7db9aaa 100644 --- a/apiserver/routers/routers.go +++ b/apiserver/routers/routers.go @@ -18,37 +18,43 @@ import ( "io" "net/http" - gorillaHandlers "github.com/gorilla/handlers" "github.com/gorilla/mux" "garm/apiserver/controllers" "garm/auth" + "garm/util" ) func NewAPIRouter(han *controllers.APIController, logWriter io.Writer, authMiddleware, initMiddleware, instanceMiddleware auth.Middleware) *mux.Router { router := mux.NewRouter() - log := gorillaHandlers.CombinedLoggingHandler + logMiddleware := util.NewLoggingMiddleware(logWriter) + router.Use(logMiddleware) // Handles github webhooks webhookRouter := router.PathPrefix("/webhooks").Subrouter() - webhookRouter.PathPrefix("/").Handler(log(logWriter, http.HandlerFunc(han.CatchAll))) - webhookRouter.PathPrefix("").Handler(log(logWriter, http.HandlerFunc(han.CatchAll))) + webhookRouter.PathPrefix("/").Handler(http.HandlerFunc(han.CatchAll)) + webhookRouter.PathPrefix("").Handler(http.HandlerFunc(han.CatchAll)) // Handles API calls apiSubRouter := router.PathPrefix("/api/v1").Subrouter() // FirstRunHandler firstRunRouter := apiSubRouter.PathPrefix("/first-run").Subrouter() - firstRunRouter.Handle("/", log(logWriter, http.HandlerFunc(han.FirstRunHandler))).Methods("POST", "OPTIONS") + firstRunRouter.Handle("/", http.HandlerFunc(han.FirstRunHandler)).Methods("POST", "OPTIONS") - // Instance callback + // Instance URLs callbackRouter := apiSubRouter.PathPrefix("/callbacks").Subrouter() - callbackRouter.Handle("/status/", log(logWriter, http.HandlerFunc(han.InstanceStatusMessageHandler))).Methods("POST", "OPTIONS") - callbackRouter.Handle("/status", log(logWriter, http.HandlerFunc(han.InstanceStatusMessageHandler))).Methods("POST", "OPTIONS") + callbackRouter.Handle("/status/", http.HandlerFunc(han.InstanceStatusMessageHandler)).Methods("POST", "OPTIONS") + callbackRouter.Handle("/status", http.HandlerFunc(han.InstanceStatusMessageHandler)).Methods("POST", "OPTIONS") callbackRouter.Use(instanceMiddleware.Middleware) + + metadataRouter := apiSubRouter.PathPrefix("/metadata").Subrouter() + metadataRouter.Handle("/runner-registration-token/", http.HandlerFunc(han.InstanceGithubRegistrationTokenHandler)).Methods("GET", "OPTIONS") + metadataRouter.Handle("/runner-registration-token", http.HandlerFunc(han.InstanceGithubRegistrationTokenHandler)).Methods("GET", "OPTIONS") + metadataRouter.Use(instanceMiddleware.Middleware) // Login authRouter := apiSubRouter.PathPrefix("/auth").Subrouter() - authRouter.Handle("/{login:login\\/?}", log(logWriter, http.HandlerFunc(han.LoginHandler))).Methods("POST", "OPTIONS") + authRouter.Handle("/{login:login\\/?}", http.HandlerFunc(han.LoginHandler)).Methods("POST", "OPTIONS") authRouter.Use(initMiddleware.Middleware) apiRouter := apiSubRouter.PathPrefix("").Subrouter() @@ -59,158 +65,158 @@ func NewAPIRouter(han *controllers.APIController, logWriter io.Writer, authMiddl // Pools // /////////// // List all pools - apiRouter.Handle("/pools/", log(logWriter, http.HandlerFunc(han.ListAllPoolsHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/pools", log(logWriter, http.HandlerFunc(han.ListAllPoolsHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools/", http.HandlerFunc(han.ListAllPoolsHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools", http.HandlerFunc(han.ListAllPoolsHandler)).Methods("GET", "OPTIONS") // Get one pool - apiRouter.Handle("/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.GetPoolByIDHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/pools/{poolID}", log(logWriter, http.HandlerFunc(han.GetPoolByIDHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools/{poolID}/", http.HandlerFunc(han.GetPoolByIDHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools/{poolID}", http.HandlerFunc(han.GetPoolByIDHandler)).Methods("GET", "OPTIONS") // Delete one pool - apiRouter.Handle("/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.DeletePoolByIDHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/pools/{poolID}", log(logWriter, http.HandlerFunc(han.DeletePoolByIDHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/pools/{poolID}/", http.HandlerFunc(han.DeletePoolByIDHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/pools/{poolID}", http.HandlerFunc(han.DeletePoolByIDHandler)).Methods("DELETE", "OPTIONS") // Update one pool - apiRouter.Handle("/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.UpdatePoolByIDHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/pools/{poolID}", log(logWriter, http.HandlerFunc(han.UpdatePoolByIDHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/pools/{poolID}/", http.HandlerFunc(han.UpdatePoolByIDHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/pools/{poolID}", http.HandlerFunc(han.UpdatePoolByIDHandler)).Methods("PUT", "OPTIONS") // List pool instances - apiRouter.Handle("/pools/{poolID}/instances/", log(logWriter, http.HandlerFunc(han.ListPoolInstancesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/pools/{poolID}/instances", log(logWriter, http.HandlerFunc(han.ListPoolInstancesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools/{poolID}/instances/", http.HandlerFunc(han.ListPoolInstancesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/pools/{poolID}/instances", http.HandlerFunc(han.ListPoolInstancesHandler)).Methods("GET", "OPTIONS") ///////////// // Runners // ///////////// // Get instance - apiRouter.Handle("/instances/{instanceName}/", log(logWriter, http.HandlerFunc(han.GetInstanceHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/instances/{instanceName}", log(logWriter, http.HandlerFunc(han.GetInstanceHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/instances/{instanceName}/", http.HandlerFunc(han.GetInstanceHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/instances/{instanceName}", http.HandlerFunc(han.GetInstanceHandler)).Methods("GET", "OPTIONS") // Delete runner - apiRouter.Handle("/instances/{instanceName}/", log(logWriter, http.HandlerFunc(han.DeleteInstanceHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/instances/{instanceName}", log(logWriter, http.HandlerFunc(han.DeleteInstanceHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/instances/{instanceName}/", http.HandlerFunc(han.DeleteInstanceHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/instances/{instanceName}", http.HandlerFunc(han.DeleteInstanceHandler)).Methods("DELETE", "OPTIONS") // List runners - apiRouter.Handle("/instances/", log(logWriter, http.HandlerFunc(han.ListAllInstancesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/instances", log(logWriter, http.HandlerFunc(han.ListAllInstancesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/instances/", http.HandlerFunc(han.ListAllInstancesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/instances", http.HandlerFunc(han.ListAllInstancesHandler)).Methods("GET", "OPTIONS") ///////////////////// // Repos and pools // ///////////////////// // Get pool - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.GetRepoPoolHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.GetRepoPoolHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", http.HandlerFunc(han.GetRepoPoolHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", http.HandlerFunc(han.GetRepoPoolHandler)).Methods("GET", "OPTIONS") // Delete pool - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.DeleteRepoPoolHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.DeleteRepoPoolHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", http.HandlerFunc(han.DeleteRepoPoolHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", http.HandlerFunc(han.DeleteRepoPoolHandler)).Methods("DELETE", "OPTIONS") // Update pool - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.UpdateRepoPoolHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.UpdateRepoPoolHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}/", http.HandlerFunc(han.UpdateRepoPoolHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/{poolID}", http.HandlerFunc(han.UpdateRepoPoolHandler)).Methods("PUT", "OPTIONS") // List pools - apiRouter.Handle("/repositories/{repoID}/pools/", log(logWriter, http.HandlerFunc(han.ListRepoPoolsHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/pools", log(logWriter, http.HandlerFunc(han.ListRepoPoolsHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/", http.HandlerFunc(han.ListRepoPoolsHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools", http.HandlerFunc(han.ListRepoPoolsHandler)).Methods("GET", "OPTIONS") // Create pool - apiRouter.Handle("/repositories/{repoID}/pools/", log(logWriter, http.HandlerFunc(han.CreateRepoPoolHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/pools", log(logWriter, http.HandlerFunc(han.CreateRepoPoolHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools/", http.HandlerFunc(han.CreateRepoPoolHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools", http.HandlerFunc(han.CreateRepoPoolHandler)).Methods("POST", "OPTIONS") // Repo instances list - apiRouter.Handle("/repositories/{repoID}/instances/", log(logWriter, http.HandlerFunc(han.ListRepoInstancesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}/instances", log(logWriter, http.HandlerFunc(han.ListRepoInstancesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/instances/", http.HandlerFunc(han.ListRepoInstancesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/instances", http.HandlerFunc(han.ListRepoInstancesHandler)).Methods("GET", "OPTIONS") // Get repo - apiRouter.Handle("/repositories/{repoID}/", log(logWriter, http.HandlerFunc(han.GetRepoByIDHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}", log(logWriter, http.HandlerFunc(han.GetRepoByIDHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/", http.HandlerFunc(han.GetRepoByIDHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}", http.HandlerFunc(han.GetRepoByIDHandler)).Methods("GET", "OPTIONS") // Update repo - apiRouter.Handle("/repositories/{repoID}/", log(logWriter, http.HandlerFunc(han.UpdateRepoHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}", log(logWriter, http.HandlerFunc(han.UpdateRepoHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/", http.HandlerFunc(han.UpdateRepoHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}", http.HandlerFunc(han.UpdateRepoHandler)).Methods("PUT", "OPTIONS") // Delete repo - apiRouter.Handle("/repositories/{repoID}/", log(logWriter, http.HandlerFunc(han.DeleteRepoHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/repositories/{repoID}", log(logWriter, http.HandlerFunc(han.DeleteRepoHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/", http.HandlerFunc(han.DeleteRepoHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}", http.HandlerFunc(han.DeleteRepoHandler)).Methods("DELETE", "OPTIONS") // List repos - apiRouter.Handle("/repositories/", log(logWriter, http.HandlerFunc(han.ListReposHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/repositories", log(logWriter, http.HandlerFunc(han.ListReposHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/", http.HandlerFunc(han.ListReposHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories", http.HandlerFunc(han.ListReposHandler)).Methods("GET", "OPTIONS") // Create repo - apiRouter.Handle("/repositories/", log(logWriter, http.HandlerFunc(han.CreateRepoHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/repositories", log(logWriter, http.HandlerFunc(han.CreateRepoHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories/", http.HandlerFunc(han.CreateRepoHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories", http.HandlerFunc(han.CreateRepoHandler)).Methods("POST", "OPTIONS") ///////////////////////////// // Organizations and pools // ///////////////////////////// // Get pool - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.GetOrgPoolHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.GetOrgPoolHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", http.HandlerFunc(han.GetOrgPoolHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", http.HandlerFunc(han.GetOrgPoolHandler)).Methods("GET", "OPTIONS") // Delete pool - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.DeleteOrgPoolHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.DeleteOrgPoolHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", http.HandlerFunc(han.DeleteOrgPoolHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", http.HandlerFunc(han.DeleteOrgPoolHandler)).Methods("DELETE", "OPTIONS") // Update pool - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.UpdateOrgPoolHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.UpdateOrgPoolHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}/", http.HandlerFunc(han.UpdateOrgPoolHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/{poolID}", http.HandlerFunc(han.UpdateOrgPoolHandler)).Methods("PUT", "OPTIONS") // List pools - apiRouter.Handle("/organizations/{orgID}/pools/", log(logWriter, http.HandlerFunc(han.ListOrgPoolsHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/pools", log(logWriter, http.HandlerFunc(han.ListOrgPoolsHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/", http.HandlerFunc(han.ListOrgPoolsHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools", http.HandlerFunc(han.ListOrgPoolsHandler)).Methods("GET", "OPTIONS") // Create pool - apiRouter.Handle("/organizations/{orgID}/pools/", log(logWriter, http.HandlerFunc(han.CreateOrgPoolHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/pools", log(logWriter, http.HandlerFunc(han.CreateOrgPoolHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools/", http.HandlerFunc(han.CreateOrgPoolHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/pools", http.HandlerFunc(han.CreateOrgPoolHandler)).Methods("POST", "OPTIONS") // Repo instances list - apiRouter.Handle("/organizations/{orgID}/instances/", log(logWriter, http.HandlerFunc(han.ListOrgInstancesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}/instances", log(logWriter, http.HandlerFunc(han.ListOrgInstancesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/instances/", http.HandlerFunc(han.ListOrgInstancesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/instances", http.HandlerFunc(han.ListOrgInstancesHandler)).Methods("GET", "OPTIONS") // Get org - apiRouter.Handle("/organizations/{orgID}/", log(logWriter, http.HandlerFunc(han.GetOrgByIDHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}", log(logWriter, http.HandlerFunc(han.GetOrgByIDHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/", http.HandlerFunc(han.GetOrgByIDHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}", http.HandlerFunc(han.GetOrgByIDHandler)).Methods("GET", "OPTIONS") // Update org - apiRouter.Handle("/organizations/{orgID}/", log(logWriter, http.HandlerFunc(han.UpdateOrgHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}", log(logWriter, http.HandlerFunc(han.UpdateOrgHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/", http.HandlerFunc(han.UpdateOrgHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}", http.HandlerFunc(han.UpdateOrgHandler)).Methods("PUT", "OPTIONS") // Delete org - apiRouter.Handle("/organizations/{orgID}/", log(logWriter, http.HandlerFunc(han.DeleteOrgHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/organizations/{orgID}", log(logWriter, http.HandlerFunc(han.DeleteOrgHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}/", http.HandlerFunc(han.DeleteOrgHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/organizations/{orgID}", http.HandlerFunc(han.DeleteOrgHandler)).Methods("DELETE", "OPTIONS") // List orgs - apiRouter.Handle("/organizations/", log(logWriter, http.HandlerFunc(han.ListOrgsHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/organizations", log(logWriter, http.HandlerFunc(han.ListOrgsHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/", http.HandlerFunc(han.ListOrgsHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations", http.HandlerFunc(han.ListOrgsHandler)).Methods("GET", "OPTIONS") // Create org - apiRouter.Handle("/organizations/", log(logWriter, http.HandlerFunc(han.CreateOrgHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/organizations", log(logWriter, http.HandlerFunc(han.CreateOrgHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations/", http.HandlerFunc(han.CreateOrgHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations", http.HandlerFunc(han.CreateOrgHandler)).Methods("POST", "OPTIONS") ///////////////////////////// // Enterprises and pools // ///////////////////////////// // Get pool - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.GetEnterprisePoolHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.GetEnterprisePoolHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", http.HandlerFunc(han.GetEnterprisePoolHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", http.HandlerFunc(han.GetEnterprisePoolHandler)).Methods("GET", "OPTIONS") // Delete pool - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.DeleteEnterprisePoolHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.DeleteEnterprisePoolHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", http.HandlerFunc(han.DeleteEnterprisePoolHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", http.HandlerFunc(han.DeleteEnterprisePoolHandler)).Methods("DELETE", "OPTIONS") // Update pool - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(logWriter, http.HandlerFunc(han.UpdateEnterprisePoolHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(logWriter, http.HandlerFunc(han.UpdateEnterprisePoolHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", http.HandlerFunc(han.UpdateEnterprisePoolHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", http.HandlerFunc(han.UpdateEnterprisePoolHandler)).Methods("PUT", "OPTIONS") // List pools - apiRouter.Handle("/enterprises/{enterpriseID}/pools/", log(logWriter, http.HandlerFunc(han.ListEnterprisePoolsHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/pools", log(logWriter, http.HandlerFunc(han.ListEnterprisePoolsHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/", http.HandlerFunc(han.ListEnterprisePoolsHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools", http.HandlerFunc(han.ListEnterprisePoolsHandler)).Methods("GET", "OPTIONS") // Create pool - apiRouter.Handle("/enterprises/{enterpriseID}/pools/", log(logWriter, http.HandlerFunc(han.CreateEnterprisePoolHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/pools", log(logWriter, http.HandlerFunc(han.CreateEnterprisePoolHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools/", http.HandlerFunc(han.CreateEnterprisePoolHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/pools", http.HandlerFunc(han.CreateEnterprisePoolHandler)).Methods("POST", "OPTIONS") // Repo instances list - apiRouter.Handle("/enterprises/{enterpriseID}/instances/", log(logWriter, http.HandlerFunc(han.ListEnterpriseInstancesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}/instances", log(logWriter, http.HandlerFunc(han.ListEnterpriseInstancesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/instances/", http.HandlerFunc(han.ListEnterpriseInstancesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/instances", http.HandlerFunc(han.ListEnterpriseInstancesHandler)).Methods("GET", "OPTIONS") // Get org - apiRouter.Handle("/enterprises/{enterpriseID}/", log(logWriter, http.HandlerFunc(han.GetEnterpriseByIDHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}", log(logWriter, http.HandlerFunc(han.GetEnterpriseByIDHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/", http.HandlerFunc(han.GetEnterpriseByIDHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}", http.HandlerFunc(han.GetEnterpriseByIDHandler)).Methods("GET", "OPTIONS") // Update org - apiRouter.Handle("/enterprises/{enterpriseID}/", log(logWriter, http.HandlerFunc(han.UpdateEnterpriseHandler))).Methods("PUT", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}", log(logWriter, http.HandlerFunc(han.UpdateEnterpriseHandler))).Methods("PUT", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/", http.HandlerFunc(han.UpdateEnterpriseHandler)).Methods("PUT", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}", http.HandlerFunc(han.UpdateEnterpriseHandler)).Methods("PUT", "OPTIONS") // Delete org - apiRouter.Handle("/enterprises/{enterpriseID}/", log(logWriter, http.HandlerFunc(han.DeleteEnterpriseHandler))).Methods("DELETE", "OPTIONS") - apiRouter.Handle("/enterprises/{enterpriseID}", log(logWriter, http.HandlerFunc(han.DeleteEnterpriseHandler))).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}/", http.HandlerFunc(han.DeleteEnterpriseHandler)).Methods("DELETE", "OPTIONS") + apiRouter.Handle("/enterprises/{enterpriseID}", http.HandlerFunc(han.DeleteEnterpriseHandler)).Methods("DELETE", "OPTIONS") // List orgs - apiRouter.Handle("/enterprises/", log(logWriter, http.HandlerFunc(han.ListEnterprisesHandler))).Methods("GET", "OPTIONS") - apiRouter.Handle("/enterprises", log(logWriter, http.HandlerFunc(han.ListEnterprisesHandler))).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises/", http.HandlerFunc(han.ListEnterprisesHandler)).Methods("GET", "OPTIONS") + apiRouter.Handle("/enterprises", http.HandlerFunc(han.ListEnterprisesHandler)).Methods("GET", "OPTIONS") // Create org - apiRouter.Handle("/enterprises/", log(logWriter, http.HandlerFunc(han.CreateEnterpriseHandler))).Methods("POST", "OPTIONS") - apiRouter.Handle("/enterprises", log(logWriter, http.HandlerFunc(han.CreateEnterpriseHandler))).Methods("POST", "OPTIONS") + apiRouter.Handle("/enterprises/", http.HandlerFunc(han.CreateEnterpriseHandler)).Methods("POST", "OPTIONS") + apiRouter.Handle("/enterprises", http.HandlerFunc(han.CreateEnterpriseHandler)).Methods("POST", "OPTIONS") // Credentials and providers - apiRouter.Handle("/credentials/", log(logWriter, http.HandlerFunc(han.ListCredentials))).Methods("GET", "OPTIONS") - apiRouter.Handle("/credentials", log(logWriter, http.HandlerFunc(han.ListCredentials))).Methods("GET", "OPTIONS") - apiRouter.Handle("/providers/", log(logWriter, http.HandlerFunc(han.ListProviders))).Methods("GET", "OPTIONS") - apiRouter.Handle("/providers", log(logWriter, http.HandlerFunc(han.ListProviders))).Methods("GET", "OPTIONS") + apiRouter.Handle("/credentials/", http.HandlerFunc(han.ListCredentials)).Methods("GET", "OPTIONS") + apiRouter.Handle("/credentials", http.HandlerFunc(han.ListCredentials)).Methods("GET", "OPTIONS") + apiRouter.Handle("/providers/", http.HandlerFunc(han.ListProviders)).Methods("GET", "OPTIONS") + apiRouter.Handle("/providers", http.HandlerFunc(han.ListProviders)).Methods("GET", "OPTIONS") // Websocket log writer - apiRouter.Handle("/{ws:ws\\/?}", log(logWriter, http.HandlerFunc(han.WSHandler))).Methods("GET") + apiRouter.Handle("/{ws:ws\\/?}", http.HandlerFunc(han.WSHandler)).Methods("GET") return router } diff --git a/auth/context.go b/auth/context.go index 6d86b168..ba1bf9cb 100644 --- a/auth/context.go +++ b/auth/context.go @@ -18,25 +18,11 @@ import ( "context" "garm/params" + "garm/runner/providers/common" ) type contextFlags string -/* -// InstanceJWTClaims holds JWT claims -type InstanceJWTClaims struct { - ID string `json:"id"` - Name string `json:"name"` - PoolID string `json:"provider_id"` - // Scope is either repository or organization - Scope common.PoolType `json:"scope"` - // Entity is the repo or org name - Entity string `json:"entity"` - jwt.StandardClaims -} - -*/ - const ( isAdminKey contextFlags = "is_admin" fullNameKey contextFlags = "full_name" @@ -45,11 +31,13 @@ const ( isEnabledFlag contextFlags = "is_enabled" jwtTokenFlag contextFlags = "jwt_token" - instanceIDKey contextFlags = "id" - instanceNameKey contextFlags = "name" - instancePoolIDKey contextFlags = "pool_id" - instancePoolTypeKey contextFlags = "scope" - instanceEntityKey contextFlags = "entity" + instanceIDKey contextFlags = "id" + instanceNameKey contextFlags = "name" + instancePoolIDKey contextFlags = "pool_id" + instancePoolTypeKey contextFlags = "scope" + instanceEntityKey contextFlags = "entity" + instanceRunnerStatus contextFlags = "status" + instanceTokenFetched contextFlags = "tokenFetched" ) func SetInstanceID(ctx context.Context, id string) context.Context { @@ -64,6 +52,30 @@ func InstanceID(ctx context.Context) string { return elem.(string) } +func SetInstanceTokenFetched(ctx context.Context, fetched bool) context.Context { + return context.WithValue(ctx, instanceTokenFetched, fetched) +} + +func InstanceTokenFetched(ctx context.Context) bool { + elem := ctx.Value(instanceTokenFetched) + if elem == nil { + return false + } + return elem.(bool) +} + +func SetInstanceRunnerStatus(ctx context.Context, val common.RunnerStatus) context.Context { + return context.WithValue(ctx, instanceRunnerStatus, val) +} + +func InstanceRunnerStatus(ctx context.Context) common.RunnerStatus { + elem := ctx.Value(instanceRunnerStatus) + if elem == nil { + return common.RunnerPending + } + return elem.(common.RunnerStatus) +} + func SetInstanceName(ctx context.Context, val string) context.Context { return context.WithValue(ctx, instanceNameKey, val) } @@ -116,6 +128,8 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont ctx = SetInstanceID(ctx, instance.ID) ctx = SetInstanceName(ctx, instance.Name) ctx = SetInstancePoolID(ctx, instance.PoolID) + ctx = SetInstanceRunnerStatus(ctx, instance.RunnerStatus) + ctx = SetInstanceTokenFetched(ctx, instance.TokenFetched) return ctx } diff --git a/auth/instance_middleware.go b/auth/instance_middleware.go index 3229927b..c8cdc63e 100644 --- a/auth/instance_middleware.go +++ b/auth/instance_middleware.go @@ -26,6 +26,7 @@ import ( runnerErrors "garm/errors" "garm/params" "garm/runner/common" + providerCommon "garm/runner/providers/common" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -145,6 +146,14 @@ func (amw *instanceMiddleware) Middleware(next http.Handler) http.Handler { if InstanceID(ctx) == "" { invalidAuthResponse(w) + return + } + + runnerStatus := InstanceRunnerStatus(ctx) + if runnerStatus != providerCommon.RunnerInstalling && runnerStatus != providerCommon.RunnerPending { + // Instances that have finished installing can no longer authenticate to the API + invalidAuthResponse(w) + return } // ctx = SetJWTClaim(ctx, *claims) diff --git a/auth/jwt.go b/auth/jwt.go index e43571cd..835b5cc7 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -77,8 +77,7 @@ func invalidAuthResponse(w http.ResponseWriter) { w.Header().Add("Content-Type", "application/json") json.NewEncoder(w).Encode( apiParams.APIErrorResponse{ - Error: "Authentication failed", - Details: "Invalid authentication token", + Error: "Authentication failed", }) } diff --git a/cloudconfig/templates.go b/cloudconfig/templates.go index c68c88ef..475081d9 100644 --- a/cloudconfig/templates.go +++ b/cloudconfig/templates.go @@ -23,15 +23,22 @@ import ( var CloudConfigTemplate = `#!/bin/bash -set -ex +set -e set -o pipefail CALLBACK_URL="{{ .CallbackURL }}" +METADATA_URL="{{ .MetadataURL }}" BEARER_TOKEN="{{ .CallbackToken }}" +if [ -z "$METADATA_URL" ];then + echo "no token is available and METADATA_URL is not set" + exit 1 +fi +GITHUB_TOKEN=$(curl --fail -s -X GET -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${METADATA_URL}/runner-registration-token/") + function call() { PAYLOAD="$1" - curl -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" + curl --fail -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" } function sendStatus() { @@ -55,13 +62,11 @@ sendStatus "downloading tools from {{ .DownloadURL }}" TEMP_TOKEN="" - - if [ ! -z "{{ .TempDownloadToken }}" ]; then TEMP_TOKEN="Authorization: Bearer {{ .TempDownloadToken }}" fi -curl -L -H "${TEMP_TOKEN}" -o "/home/runner/{{ .FileName }}" "{{ .DownloadURL }}" || fail "failed to download tools" +curl -L -H "${TEMP_TOKEN}" -o "/home/{{ .RunnerUsername }}/{{ .FileName }}" "{{ .DownloadURL }}" || fail "failed to download tools" mkdir -p /home/runner/actions-runner || fail "failed to create actions-runner folder" @@ -74,7 +79,7 @@ cd /home/{{ .RunnerUsername }}/actions-runner sudo ./bin/installdependencies.sh || fail "failed to install dependencies" sendStatus "configuring runner" -sudo -u {{ .RunnerUsername }} -- ./config.sh --unattended --url "{{ .RepoURL }}" --token "{{ .GithubToken }}" --name "{{ .RunnerName }}" --labels "{{ .RunnerLabels }}" --ephemeral || fail "failed to configure runner" +sudo -u {{ .RunnerUsername }} -- ./config.sh --unattended --url "{{ .RepoURL }}" --token "$GITHUB_TOKEN" --name "{{ .RunnerName }}" --labels "{{ .RunnerLabels }}" --ephemeral || fail "failed to configure runner" sendStatus "installing runner service" ./svc.sh install {{ .RunnerUsername }} || fail "failed to install service" @@ -98,7 +103,7 @@ type InstallRunnerParams struct { RunnerUsername string RunnerGroup string RepoURL string - GithubToken string + MetadataURL string RunnerName string RunnerLabels string CallbackURL string diff --git a/config/config.go b/config/config.go index d02daea5..f6b8329d 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,7 @@ import ( "fmt" "log" "net" + "net/url" "os" "path/filepath" "time" @@ -169,8 +170,12 @@ type Default struct { // ConfigDir is the folder where the runner may save any aditional files // or configurations it may need. Things like auto-generated SSH keys that // may be used to access the runner instances. - ConfigDir string `toml:"config_dir,omitempty" json:"config-dir,omitempty"` + ConfigDir string `toml:"config_dir,omitempty" json:"config-dir,omitempty"` + // CallbackURL is the URL where the instances can send back status reports. CallbackURL string `toml:"callback_url" json:"callback-url"` + // MetadataURL is the URL where instances can fetch information they may need + // to set themselves up. + MetadataURL string `toml:"metadata_url" json:"metadata-url"` // LogFile is the location of the log file. LogFile string `toml:"log_file,omitempty" json:"log-file"` EnableLogStreamer bool `toml:"enable_log_streamer"` @@ -180,6 +185,17 @@ func (d *Default) Validate() error { if d.CallbackURL == "" { return fmt.Errorf("missing callback_url") } + _, err := url.Parse(d.CallbackURL) + if err != nil { + return errors.Wrap(err, "validating callback_url") + } + + if d.MetadataURL == "" { + return fmt.Errorf("missing metadata-url") + } + if _, err := url.Parse(d.MetadataURL); err != nil { + return errors.Wrap(err, "validating metadata_url") + } if d.ConfigDir == "" { return fmt.Errorf("config_dir cannot be empty") diff --git a/config/config_test.go b/config/config_test.go index d8a473a6..d8910fc9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -15,7 +15,6 @@ package config import ( - "io/ioutil" "os" "path/filepath" "testing" @@ -33,6 +32,7 @@ func getDefaultSectionConfig(configDir string) Default { return Default{ ConfigDir: configDir, CallbackURL: "https://garm.example.com/", + MetadataURL: "https://garm.example.com/api/v1/metadata", LogFile: filepath.Join(configDir, "garm.log"), } } @@ -105,7 +105,7 @@ func getDefaultJWTCofig() JWTAuth { } func getDefaultConfig(t *testing.T) Config { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } @@ -129,7 +129,7 @@ func TestConfig(t *testing.T) { } func TestDefaultSectionConfig(t *testing.T) { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } @@ -150,14 +150,25 @@ func TestDefaultSectionConfig(t *testing.T) { name: "CallbackURL cannot be empty", cfg: Default{ CallbackURL: "", + MetadataURL: cfg.MetadataURL, ConfigDir: cfg.ConfigDir, }, errString: "missing callback_url", }, + { + name: "MetadataURL cannot be empty", + cfg: Default{ + CallbackURL: cfg.CallbackURL, + MetadataURL: "", + ConfigDir: cfg.ConfigDir, + }, + errString: "missing metadata-url", + }, { name: "ConfigDir cannot be empty", cfg: Default{ CallbackURL: cfg.CallbackURL, + MetadataURL: cfg.MetadataURL, ConfigDir: "", }, errString: "config_dir cannot be empty", @@ -166,6 +177,7 @@ func TestDefaultSectionConfig(t *testing.T) { name: "config_dir must exist and be accessible", cfg: Default{ CallbackURL: cfg.CallbackURL, + MetadataURL: cfg.MetadataURL, ConfigDir: "/i/do/not/exist", }, errString: "accessing config dir: stat /i/do/not/exist:.*", @@ -306,14 +318,14 @@ func TestAPITLSconfig(t *testing.T) { } func TestTLSConfig(t *testing.T) { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } t.Cleanup(func() { os.RemoveAll(dir) }) invalidCert := filepath.Join(dir, "invalid_cert.pem") - err = ioutil.WriteFile(invalidCert, []byte("bogus content"), 0755) + err = os.WriteFile(invalidCert, []byte("bogus content"), 0755) if err != nil { t.Fatalf("failed to write file: %s", err) } @@ -396,7 +408,7 @@ func TestTLSConfig(t *testing.T) { } func TestDatabaseConfig(t *testing.T) { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } @@ -503,7 +515,7 @@ func TestDatabaseConfig(t *testing.T) { } func TestGormParams(t *testing.T) { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } @@ -527,7 +539,7 @@ func TestGormParams(t *testing.T) { } func TestSQLiteConfig(t *testing.T) { - dir, err := ioutil.TempDir("", "garm-config-test") + dir, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } @@ -676,7 +688,7 @@ func TestNewConfig(t *testing.T) { } func TestNewConfigEmptyConfigDir(t *testing.T) { - dirPath, err := ioutil.TempDir("", "garm-config-test") + dirPath, err := os.MkdirTemp("", "garm-config-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } diff --git a/config/external_test.go b/config/external_test.go index a537b570..1da36d33 100644 --- a/config/external_test.go +++ b/config/external_test.go @@ -16,7 +16,6 @@ package config import ( "fmt" - "io/ioutil" "os" "path/filepath" "testing" @@ -25,13 +24,13 @@ import ( ) func getDefaultExternalConfig(t *testing.T) External { - dir, err := ioutil.TempDir("", "garm-test") + dir, err := os.MkdirTemp("", "garm-test") if err != nil { t.Fatalf("failed to create temporary directory: %s", err) } t.Cleanup(func() { os.RemoveAll(dir) }) - err = ioutil.WriteFile(filepath.Join(dir, "garm-external-provider"), []byte{}, 0755) + err = os.WriteFile(filepath.Join(dir, "garm-external-provider"), []byte{}, 0755) if err != nil { t.Fatalf("failed to write file: %s", err) } diff --git a/config/testdata/test-empty-config-dir.toml b/config/testdata/test-empty-config-dir.toml index 30f36e7a..f86750a6 100644 --- a/config/testdata/test-empty-config-dir.toml +++ b/config/testdata/test-empty-config-dir.toml @@ -1,5 +1,6 @@ [default] callback_url = "https://garm.example.com/" + metadata_url = "https://garm.example.com/" config_dir = "" [apiserver] diff --git a/config/testdata/test-valid-config.toml b/config/testdata/test-valid-config.toml index 681ce22b..4fee2597 100644 --- a/config/testdata/test-valid-config.toml +++ b/config/testdata/test-valid-config.toml @@ -1,5 +1,6 @@ [default] callback_url = "https://garm.example.com/" + metadata_url = "https://garm.example.com/" config_dir = "./testdata" [apiserver] diff --git a/contrib/providers.d/azure/cloudconfig/install_runner.tpl b/contrib/providers.d/azure/cloudconfig/install_runner.tpl index a13f78c3..910d8eac 100644 --- a/contrib/providers.d/azure/cloudconfig/install_runner.tpl +++ b/contrib/providers.d/azure/cloudconfig/install_runner.tpl @@ -1,20 +1,28 @@ #!/bin/bash -set -ex +set -e set -o pipefail +METADATA_URL="GARM_METADATA_URL" CALLBACK_URL="GARM_CALLBACK_URL" BEARER_TOKEN="GARM_CALLBACK_TOKEN" DOWNLOAD_URL="GH_DOWNLOAD_URL" +DOWNLOAD_TOKEN="GH_TEMP_DOWNLOAD_TOKEN" FILENAME="GH_FILENAME" TARGET_URL="GH_TARGET_URL" -RUNNER_TOKEN="GH_RUNNER_TOKEN" RUNNER_NAME="GH_RUNNER_NAME" RUNNER_LABELS="GH_RUNNER_LABELS" +TEMP_TOKEN="" + + +if [ -z "$METADATA_URL" ];then + echo "no token is available and METADATA_URL is not set" + exit 1 +fi function call() { PAYLOAD="$1" - curl -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" + curl --fail -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" } function sendStatus() { @@ -34,10 +42,12 @@ function fail() { exit 1 } - +if [ ! -z "$DOWNLOAD_TOKEN" ]; then + TEMP_TOKEN="Authorization: Bearer $DOWNLOAD_TOKEN" +fi sendStatus "downloading tools from ${DOWNLOAD_URL}" -curl -L -o "/home/runner/${FILENAME}" "${DOWNLOAD_URL}" || fail "failed to download tools" +curl --fail -L -H "${TEMP_TOKEN}" -o "/home/runner/${FILENAME}" "${DOWNLOAD_URL}" || fail "failed to download tools" mkdir -p /home/runner/actions-runner || fail "failed to create actions-runner folder" @@ -49,8 +59,11 @@ sendStatus "installing dependencies" cd /home/runner/actions-runner sudo ./bin/installdependencies.sh || fail "failed to install dependencies" +sendStatus "fetching runner registration token" +GITHUB_TOKEN=$(curl --fail -s -X GET -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${METADATA_URL}" || fail "failed to get runner registration token") + sendStatus "configuring runner" -sudo -u runner -- ./config.sh --unattended --url "${TARGET_URL}" --token "${RUNNER_TOKEN}" --name "${RUNNER_NAME}" --labels "${RUNNER_LABELS}" --ephemeral || fail "failed to configure runner" +sudo -u runner -- ./config.sh --unattended --url "${TARGET_URL}" --token "${GITHUB_TOKEN}" --name "${RUNNER_NAME}" --labels "${RUNNER_LABELS}" --ephemeral || fail "failed to configure runner" sendStatus "installing runner service" ./svc.sh install runner || fail "failed to install service" @@ -65,4 +78,4 @@ if [ $? -ne 0 ];then fi set -e -success "runner successfully installed" $AGENT_ID +success "runner successfully installed" $AGENT_ID \ No newline at end of file diff --git a/contrib/providers.d/azure/garm-external-provider b/contrib/providers.d/azure/garm-external-provider index 98d6435a..1e3ebff9 100755 --- a/contrib/providers.d/azure/garm-external-provider +++ b/contrib/providers.d/azure/garm-external-provider @@ -74,6 +74,19 @@ function downloadURL() { echo "${URL}" } +function tempDownloadToken() { + [ -z "$1" -o -z "$2" ] && return 1 + GH_ARCH="${GARM_TO_GH_ARCH_MAP[$2]}" + TOKEN=$(echo "$INPUT" | jq -c -r --arg OS "$1" --arg ARCH "$GH_ARCH" '(.tools[] | select( .os == $OS and .architecture == $ARCH)).temp_download_token') + echo "${TOKEN}" +} + +function runnerTokenURL() { + METADATA_URL=$(echo "$INPUT" | jq -c -r '."metadata-url"') + checkValNotNull "${METADATA_URL}" "metadata-url" || return $? + echo "${METADATA_URL}/runner-registration-token/" +} + function downloadFilename() { [ -z "$1" -o -z "$2" ] && return 1 GH_OS="${AZURE_OS_TO_GH_OS_MAP[$1]}" @@ -113,12 +126,6 @@ function repoURL() { echo "${REPO}" } -function ghAccessToken() { - TOKEN=$(echo "$INPUT" | jq -c -r '.github_runner_access_token') - checkValNotNull "${TOKEN}" "github_runner_access_token" || return $? - echo "${TOKEN}" -} - function callbackURL() { CB_URL=$(echo "$INPUT" | jq -c -r '."callback-url"') checkValNotNull "${CB_URL}" "callback-url" || return $? @@ -177,6 +184,7 @@ function getCloudConfig() { ARCH=$(requestedArch) DW_URL=$(downloadURL "${OS_TYPE}" "${ARCH}") + DW_TOKEN=$(tempDownloadToken "${OS_TYPE}" "${ARCH}") DW_FILENAME=$(downloadFilename "${OS_TYPE}" "${ARCH}") LABELS=$(labels) @@ -190,8 +198,9 @@ function getCloudConfig() { -e "s|GH_DOWNLOAD_URL|${DW_URL}|g" \ -e "s|GH_FILENAME|${DW_FILENAME}|g" \ -e "s|GH_TARGET_URL|$(repoURL)|g" \ - -e "s|GH_RUNNER_TOKEN|$(ghAccessToken)|g" \ + -e "s|GARM_METADATA_URL|$(runnerTokenURL)|g" \ -e "s|GH_RUNNER_NAME|$(instanceName)|g" \ + -e "s|GH_TEMP_DOWNLOAD_TOKEN|${DW_TOKEN}|g" \ -e "s|GH_RUNNER_LABELS|${LABELS}|g" > ${TMP_SCRIPT} AS_B64=$(base64 -w0 ${TMP_SCRIPT}) diff --git a/contrib/providers.d/openstack/cloudconfig/install_runner.tpl b/contrib/providers.d/openstack/cloudconfig/install_runner.tpl index a13f78c3..910d8eac 100644 --- a/contrib/providers.d/openstack/cloudconfig/install_runner.tpl +++ b/contrib/providers.d/openstack/cloudconfig/install_runner.tpl @@ -1,20 +1,28 @@ #!/bin/bash -set -ex +set -e set -o pipefail +METADATA_URL="GARM_METADATA_URL" CALLBACK_URL="GARM_CALLBACK_URL" BEARER_TOKEN="GARM_CALLBACK_TOKEN" DOWNLOAD_URL="GH_DOWNLOAD_URL" +DOWNLOAD_TOKEN="GH_TEMP_DOWNLOAD_TOKEN" FILENAME="GH_FILENAME" TARGET_URL="GH_TARGET_URL" -RUNNER_TOKEN="GH_RUNNER_TOKEN" RUNNER_NAME="GH_RUNNER_NAME" RUNNER_LABELS="GH_RUNNER_LABELS" +TEMP_TOKEN="" + + +if [ -z "$METADATA_URL" ];then + echo "no token is available and METADATA_URL is not set" + exit 1 +fi function call() { PAYLOAD="$1" - curl -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" + curl --fail -s -X POST -d "${PAYLOAD}" -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${CALLBACK_URL}" || echo "failed to call home: exit code ($?)" } function sendStatus() { @@ -34,10 +42,12 @@ function fail() { exit 1 } - +if [ ! -z "$DOWNLOAD_TOKEN" ]; then + TEMP_TOKEN="Authorization: Bearer $DOWNLOAD_TOKEN" +fi sendStatus "downloading tools from ${DOWNLOAD_URL}" -curl -L -o "/home/runner/${FILENAME}" "${DOWNLOAD_URL}" || fail "failed to download tools" +curl --fail -L -H "${TEMP_TOKEN}" -o "/home/runner/${FILENAME}" "${DOWNLOAD_URL}" || fail "failed to download tools" mkdir -p /home/runner/actions-runner || fail "failed to create actions-runner folder" @@ -49,8 +59,11 @@ sendStatus "installing dependencies" cd /home/runner/actions-runner sudo ./bin/installdependencies.sh || fail "failed to install dependencies" +sendStatus "fetching runner registration token" +GITHUB_TOKEN=$(curl --fail -s -X GET -H 'Accept: application/json' -H "Authorization: Bearer ${BEARER_TOKEN}" "${METADATA_URL}" || fail "failed to get runner registration token") + sendStatus "configuring runner" -sudo -u runner -- ./config.sh --unattended --url "${TARGET_URL}" --token "${RUNNER_TOKEN}" --name "${RUNNER_NAME}" --labels "${RUNNER_LABELS}" --ephemeral || fail "failed to configure runner" +sudo -u runner -- ./config.sh --unattended --url "${TARGET_URL}" --token "${GITHUB_TOKEN}" --name "${RUNNER_NAME}" --labels "${RUNNER_LABELS}" --ephemeral || fail "failed to configure runner" sendStatus "installing runner service" ./svc.sh install runner || fail "failed to install service" @@ -65,4 +78,4 @@ if [ $? -ne 0 ];then fi set -e -success "runner successfully installed" $AGENT_ID +success "runner successfully installed" $AGENT_ID \ No newline at end of file diff --git a/contrib/providers.d/openstack/garm-external-provider b/contrib/providers.d/openstack/garm-external-provider index da0876d1..910fec9d 100755 --- a/contrib/providers.d/openstack/garm-external-provider +++ b/contrib/providers.d/openstack/garm-external-provider @@ -145,6 +145,19 @@ function downloadURL() { echo "${URL}" } +function tempDownloadToken() { + [ -z "$1" -o -z "$2" ] && return 1 + GH_ARCH="${GARM_TO_GH_ARCH_MAP[$2]}" + TOKEN=$(echo "$INPUT" | jq -c -r --arg OS "$1" --arg ARCH "$GH_ARCH" '(.tools[] | select( .os == $OS and .architecture == $ARCH)).temp_download_token') + echo "${TOKEN}" +} + +function runnerTokenURL() { + METADATA_URL=$(echo "$INPUT" | jq -c -r '."metadata-url"') + checkValNotNull "${METADATA_URL}" "metadata-url" || return $? + echo "${METADATA_URL}/runner-registration-token/" +} + function downloadFilename() { [ -z "$1" -o -z "$2" ] && return 1 GH_ARCH="${GARM_TO_GH_ARCH_MAP[$2]}" @@ -177,12 +190,6 @@ function repoURL() { echo "${REPO}" } -function ghAccessToken() { - TOKEN=$(echo "$INPUT" | jq -c -r '.github_runner_access_token') - checkValNotNull "${TOKEN}" "github_runner_access_token" || return $? - echo "${TOKEN}" -} - function callbackURL() { CB_URL=$(echo "$INPUT" | jq -c -r '."callback-url"') checkValNotNull "${CB_URL}" "callback-url" || return $? @@ -215,6 +222,7 @@ function getCloudConfig() { ARCH=$(requestedArch) DW_URL=$(downloadURL "${OS_TYPE}" "${ARCH}") + DW_TOKEN=$(tempDownloadToken "${OS_TYPE}" "${ARCH}") DW_FILENAME=$(downloadFilename "${OS_TYPE}" "${ARCH}") LABELS=$(labels) @@ -228,8 +236,9 @@ function getCloudConfig() { -e "s|GH_DOWNLOAD_URL|${DW_URL}|g" \ -e "s|GH_FILENAME|${DW_FILENAME}|g" \ -e "s|GH_TARGET_URL|$(repoURL)|g" \ - -e "s|GH_RUNNER_TOKEN|$(ghAccessToken)|g" \ + -e "s|GARM_METADATA_URL|$(runnerTokenURL)|g" \ -e "s|GH_RUNNER_NAME|$(instanceName)|g" \ + -e "s|GH_TEMP_DOWNLOAD_TOKEN|${DW_TOKEN}|g" \ -e "s|GH_RUNNER_LABELS|${LABELS}|g" > ${TMP_SCRIPT} AS_B64=$(base64 -w0 ${TMP_SCRIPT}) @@ -306,7 +315,7 @@ function CreateInstance() { if [ $? -ne 0 ];then CODE=$? # cleanup - rm -f "${CC_FILE}" || true + rm -f "${CC_FILE}" || true openstack server delete "${INSTANCE_NAME}" || true openstack volume delete "${INSTANCE_NAME}" || true set -e diff --git a/database/common/common.go b/database/common/common.go index d5955f54..1b4153d4 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -106,7 +106,8 @@ type InstanceStore interface { ListAllInstances(ctx context.Context) ([]params.Instance, error) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) - AddInstanceStatusMessage(ctx context.Context, instanceID string, statusMessage string) error + AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error + ListInstanceEvents(ctx context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error) } //go:generate mockery --name=Store diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 27d7f86e..31ed2617 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks @@ -14,13 +14,13 @@ type Store struct { mock.Mock } -// AddInstanceStatusMessage provides a mock function with given fields: ctx, instanceID, statusMessage -func (_m *Store) AddInstanceStatusMessage(ctx context.Context, instanceID string, statusMessage string) error { - ret := _m.Called(ctx, instanceID, statusMessage) +// AddInstanceEvent provides a mock function with given fields: ctx, instanceID, event, statusMessage +func (_m *Store) AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, statusMessage string) error { + ret := _m.Called(ctx, instanceID, event, statusMessage) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, instanceID, statusMessage) + if rf, ok := ret.Get(0).(func(context.Context, string, params.EventType, string) error); ok { + r0 = rf(ctx, instanceID, event, statusMessage) } else { r0 = ret.Error(0) } diff --git a/database/sql/instances.go b/database/sql/instances.go index 34f4c720..e813b11a 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -30,6 +30,7 @@ func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param p if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } + newInstance := Instance{ Pool: pool, Name: param.Name, @@ -38,6 +39,7 @@ func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param p OSType: param.OSType, OSArch: param.OSArch, CallbackURL: param.CallbackURL, + MetadataURL: param.MetadataURL, } q := s.conn.Create(&newInstance) if q.Error != nil { @@ -112,6 +114,7 @@ func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } + return s.sqlToParamsInstance(instance), nil } @@ -120,6 +123,7 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } + return s.sqlToParamsInstance(instance), nil } @@ -137,14 +141,42 @@ func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanc return nil } -func (s *sqlDatabase) AddInstanceStatusMessage(ctx context.Context, instanceID string, statusMessage string) error { +func (s *sqlDatabase) ListInstanceEvents(ctx context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error) { + var events []InstanceStatusUpdate + query := s.conn.Model(&InstanceStatusUpdate{}).Where("instance_id = ?", instanceID) + if eventLevel != "" { + query = query.Where("event_level = ?", eventLevel) + } + + if eventType != "" { + query = query.Where("event_type = ?", eventType) + } + + if result := query.Find(&events); result.Error != nil { + return nil, errors.Wrap(result.Error, "fetching events") + } + + eventParams := make([]params.StatusMessage, len(events)) + for idx, val := range events { + eventParams[idx] = params.StatusMessage{ + Message: val.Message, + EventType: val.EventType, + EventLevel: val.EventLevel, + } + } + return eventParams, nil +} + +func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error { instance, err := s.getInstanceByID(ctx, instanceID) if err != nil { return errors.Wrap(err, "updating instance") } msg := InstanceStatusUpdate{ - Message: statusMessage, + Message: statusMessage, + EventType: event, + EventLevel: eventLevel, } if err := s.conn.Model(&instance).Association("StatusMessages").Append(&msg); err != nil { @@ -186,6 +218,10 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceID string, par instance.CreateAttempt = param.CreateAttempt } + if param.TokenFetched != nil { + instance.TokenFetched = *param.TokenFetched + } + instance.ProviderFault = param.ProviderFault q := s.conn.Save(&instance) @@ -205,17 +241,25 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceID string, par return params.Instance{}, errors.Wrap(err, "updating addresses") } } + return s.sqlToParamsInstance(instance), nil } func (s *sqlDatabase) ListPoolInstances(ctx context.Context, poolID string) ([]params.Instance, error) { - pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances") + u, err := uuid.FromString(poolID) if err != nil { - return nil, errors.Wrap(err, "fetching pool") + return nil, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } - ret := make([]params.Instance, len(pool.Instances)) - for idx, inst := range pool.Instances { + var instances []Instance + query := s.conn.Model(&Instance{}).Where("pool_id = ?", u) + + if err := query.Find(&instances); err.Error != nil { + return nil, errors.Wrap(err.Error, "fetching instances") + } + + ret := make([]params.Instance, len(instances)) + for idx, inst := range instances { ret[idx] = s.sqlToParamsInstance(inst) } return ret, nil diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 88278774..d1831b31 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -343,11 +343,11 @@ func (s *InstancesTestSuite) TestDeleteInstanceDBDeleteErr() { s.Require().Equal("deleting instance: mocked delete instance error", err.Error()) } -func (s *InstancesTestSuite) TestAddInstanceStatusMessage() { +func (s *InstancesTestSuite) TestAddInstanceEvent() { storeInstance := s.Fixtures.Instances[0] statusMsg := "test-status-message" - err := s.Store.AddInstanceStatusMessage(context.Background(), storeInstance.ID, statusMsg) + err := s.Store.AddInstanceEvent(context.Background(), storeInstance.ID, params.StatusEvent, params.EventInfo, statusMsg) s.Require().Nil(err) instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name) @@ -358,13 +358,13 @@ func (s *InstancesTestSuite) TestAddInstanceStatusMessage() { s.Require().Equal(statusMsg, instance.StatusMessages[0].Message) } -func (s *InstancesTestSuite) TestAddInstanceStatusMessageInvalidPoolID() { - err := s.Store.AddInstanceStatusMessage(context.Background(), "dummy-id", "dummy-message") +func (s *InstancesTestSuite) TestAddInstanceEventInvalidPoolID() { + err := s.Store.AddInstanceEvent(context.Background(), "dummy-id", params.StatusEvent, params.EventInfo, "dummy-message") s.Require().Equal("updating instance: parsing id: invalid request", err.Error()) } -func (s *InstancesTestSuite) TestAddInstanceStatusMessageDBUpdateErr() { +func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() { instance := s.Fixtures.Instances[0] statusMsg := "test-status-message" @@ -390,7 +390,7 @@ func (s *InstancesTestSuite) TestAddInstanceStatusMessageDBUpdateErr() { WillReturnError(fmt.Errorf("mocked add status message error")) s.Fixtures.SQLMock.ExpectRollback() - err := s.StoreSQLMocked.AddInstanceStatusMessage(context.Background(), instance.ID, statusMsg) + err := s.StoreSQLMocked.AddInstanceEvent(context.Background(), instance.ID, params.StatusEvent, params.EventInfo, statusMsg) s.assertSQLMockExpectations() s.Require().NotNil(err) @@ -496,7 +496,7 @@ func (s *InstancesTestSuite) TestListPoolInstances() { func (s *InstancesTestSuite) TestListPoolInstancesInvalidPoolID() { _, err := s.Store.ListPoolInstances(context.Background(), "dummy-pool-id") - s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *InstancesTestSuite) TestListAllInstances() { diff --git a/database/sql/models.go b/database/sql/models.go index 523c258b..f47045a1 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -16,6 +16,7 @@ package sql import ( "garm/config" + "garm/params" "garm/runner/providers/common" "time" @@ -118,7 +119,9 @@ type Address struct { type InstanceStatusUpdate struct { Base - Message string `gorm:"type:text"` + EventType params.EventType `gorm:"index:eventType"` + EventLevel params.EventLevel + Message string `gorm:"type:text"` InstanceID uuid.UUID Instance Instance `gorm:"foreignKey:InstanceID"` @@ -138,8 +141,10 @@ type Instance struct { Status common.InstanceStatus RunnerStatus common.RunnerStatus CallbackURL string + MetadataURL string ProviderFault []byte `gorm:"type:longblob"` CreateAttempt int + TokenFetched bool PoolID uuid.UUID Pool Pool `gorm:"foreignKey:PoolID"` diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index 2f431875..8b2af3b5 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -51,28 +51,6 @@ type OrgTestSuite struct { Fixtures *OrgTestFixtures } -func (s *OrgTestSuite) equalOrgsByName(expected, actual []params.Organization) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].Name > expected[j].Name }) - sort.Slice(actual, func(i, j int) bool { return actual[i].Name > actual[j].Name }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].Name, actual[i].Name) - } -} - -func (s *OrgTestSuite) equalPoolsByID(expected, actual []params.Pool) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - func (s *OrgTestSuite) equalInstancesByName(expected, actual []params.Instance) { s.Require().Equal(len(expected), len(actual)) @@ -276,7 +254,7 @@ func (s *OrgTestSuite) TestListOrganizations() { orgs, err := s.Store.ListOrganizations(context.Background()) s.Require().Nil(err) - s.equalOrgsByName(s.Fixtures.Orgs, orgs) + garmTesting.EqualDBEntityByName(s.T(), s.Fixtures.Orgs, orgs) } func (s *OrgTestSuite) TestListOrganizationsDBFetchErr() { @@ -689,7 +667,7 @@ func (s *OrgTestSuite) TestListOrgPools() { pools, err := s.Store.ListOrgPools(context.Background(), s.Fixtures.Orgs[0].ID) s.Require().Nil(err) - s.equalPoolsByID(orgPools, pools) + garmTesting.EqualDBEntityID(s.T(), orgPools, pools) } func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index bc8f755c..17058a5a 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -22,7 +22,6 @@ import ( garmTesting "garm/internal/testing" "garm/params" "regexp" - "sort" "testing" "github.com/stretchr/testify/suite" @@ -45,17 +44,6 @@ type PoolsTestSuite struct { Fixtures *PoolsTestFixtures } -func (s *PoolsTestSuite) equalPoolsByID(expected, actual []params.Pool) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - func (s *PoolsTestSuite) assertSQLMockExpectations() { err := s.Fixtures.SQLMock.ExpectationsWereMet() if err != nil { @@ -134,7 +122,7 @@ func (s *PoolsTestSuite) TestListAllPools() { pools, err := s.Store.ListAllPools(context.Background()) s.Require().Nil(err) - s.equalPoolsByID(s.Fixtures.Pools, pools) + garmTesting.EqualDBEntityID(s.T(), s.Fixtures.Pools, pools) } func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 9223c990..95ca90a4 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -60,17 +60,6 @@ func (s *RepoTestSuite) equalReposByName(expected, actual []params.Repository) { } } -func (s *RepoTestSuite) equalPoolsByID(expected, actual []params.Pool) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - func (s *RepoTestSuite) equalInstancesByID(expected, actual []params.Instance) { s.Require().Equal(len(expected), len(actual)) @@ -714,7 +703,7 @@ func (s *RepoTestSuite) TestListRepoPools() { pools, err := s.Store.ListRepoPools(context.Background(), s.Fixtures.Repos[0].ID) s.Require().Nil(err) - s.equalPoolsByID(repoPools, pools) + garmTesting.EqualDBEntityID(s.T(), repoPools, pools) } func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { diff --git a/database/sql/util.go b/database/sql/util.go index e8ff2f54..5a39d5bb 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -41,9 +41,11 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) params.Instance { RunnerStatus: instance.RunnerStatus, PoolID: instance.PoolID.String(), CallbackURL: instance.CallbackURL, + MetadataURL: instance.MetadataURL, StatusMessages: []params.StatusMessage{}, CreateAttempt: instance.CreateAttempt, UpdatedAt: instance.UpdatedAt, + TokenFetched: instance.TokenFetched, } if len(instance.ProviderFault) > 0 { @@ -56,8 +58,10 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) params.Instance { for _, msg := range instance.StatusMessages { ret.StatusMessages = append(ret.StatusMessages, params.StatusMessage{ - CreatedAt: msg.CreatedAt, - Message: msg.Message, + CreatedAt: msg.CreatedAt, + Message: msg.Message, + EventType: msg.EventType, + EventLevel: msg.EventLevel, }) } return ret diff --git a/doc/running_garm.md b/doc/running_garm.md index dbfe60a8..3fa8f281 100644 --- a/doc/running_garm.md +++ b/doc/running_garm.md @@ -81,7 +81,7 @@ Global Flags: Use "garm-cli completion [command] --help" for more information about a command. ``` -## Adding a repository/organization +## Adding a repository/organization/enterprise To add a repository, we need credentials. Let's list the available credentials currently configured. These credentials are added to ```garm``` using the config file (see above), but we need to reference them by name when creating a repo. @@ -128,6 +128,22 @@ ubuntu@experiments:~$ garm-cli organization create \ +-------------+--------------------------------------+ ``` +To add an enterprise, use the following command: + +```bash +ubuntu@experiments:~$ garm-cli enterprise create \ + --credentials=gabriel \ + --name=gsamfira \ + --webhook-secret="$SECRET" ++-------------+--------------------------------------+ +| FIELD | VALUE | ++-------------+--------------------------------------+ +| ID | 0925033b-049f-4334-a460-c26f979d2356 | +| Name | gsamfira | +| Credentials | gabriel | ++-------------+--------------------------------------+ +``` + ## Creating a pool Pools are objects that define one type of worker and rules by which that pool of workers will be maintained. You can have multiple pools of different types of instances. Each pool can have different images, be on different providers and have different tags. @@ -356,18 +372,22 @@ Usage: Available Commands: completion Generate the autocompletion script for the specified shell credentials List configured credentials + debug-log Stream garm log + enterprise Manage enterprise help Help about any command init Initialize a newly installed garm - login Log into a manager organization Manage organizations pool List pools + profile Add, delete or update profiles provider Interacts with the providers API resource. repository Manage repositories runner List runners in a pool + version Print version and exit Flags: --debug Enable debug on all API calls -h, --help help for garm-cli Use "garm-cli [command] --help" for more information about a command. + ``` diff --git a/doc/webhooks_and_callbacks.md b/doc/webhooks_and_callbacks.md index a8502525..4768d553 100644 --- a/doc/webhooks_and_callbacks.md +++ b/doc/webhooks_and_callbacks.md @@ -68,7 +68,7 @@ garm-cli runner show garm-f5227755-129d-4e2d-b306-377a8f3a5dfe +-----------------+--------------------------------------------------------------------------------------------------------------------------------------------------+ ``` -This URL if set, must be accessible by the instance. If you wish to restrict access to it, a reverse proxy can be configured to accept requests only from networks in which the runners ```garm``` manages will be spun up. This URL doesn't need to be globally accessible, it just needs to be accessible by the instances. +This URL must be set and must be accessible by the instance. If you wish to restrict access to it, a reverse proxy can be configured to accept requests only from networks in which the runners ```garm``` manages will be spun up. This URL doesn't need to be globally accessible, it just needs to be accessible by the instances. For example, in a scenario where you expose the API endpoint directly, this setting could look like the following: @@ -76,6 +76,20 @@ For example, in a scenario where you expose the API endpoint directly, this sett callback_url = "https://garm.example.com/api/v1/callbacks/status" ``` -Authentication is done using a short-lived JWT token, that gets generated for a particular instance that we are spinning up. That JWT token only has access to update it's own status. No other API endpoints will work with that JWT token. The validity of the token is equal to the pool bootstrap timeout value (default 20 minutes) plus the garm polling interval (5 minutes). +Authentication is done using a short-lived JWT token, that gets generated for a particular instance that we are spinning up. That JWT token grants access to the instance to only update it's own status and to fetch metadata for itself. No other API endpoints will work with that JWT token. The validity of the token is equal to the pool bootstrap timeout value (default 20 minutes) plus the garm polling interval (5 minutes). -There is a sample ```nginx``` config [in the testdata folder](/testdata/nginx-server.conf). Feel free to customize it whichever way you see fit. \ No newline at end of file +There is a sample ```nginx``` config [in the testdata folder](/testdata/nginx-server.conf). Feel free to customize it whichever way you see fit. + +## The metadata_url option + +The metadata URL is the base URL for any information an instance may need to fetch in order to finish setting itself up. As this URL may be placed behind a reverse proxy, you'll need to configure it in the ```garm``` config file. Ultimately this URL will need to point to the following ```garm``` API endpoint: + +```bash +GET /api/v1/metadata +``` + +This URL needs to be accessible only by the instances ```garm``` sets up. This URL will not be used by anyone else. To configure it in ```garm``` add the following line in the ```[default]``` section of your ```garm``` config: + +```toml +metadata_url = "https://garm.example.com/api/v1/metadata" +``` diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 7ae23b2d..3fd0e458 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -21,7 +21,10 @@ import ( "garm/config" "os" "path/filepath" + "sort" "testing" + + "github.com/stretchr/testify/require" ) var ( @@ -44,3 +47,42 @@ func GetTestSqliteDBConfig(t *testing.T) config.Database { }, } } + +type IDDBEntity interface { + GetID() string +} + +type NameAndIDDBEntity interface { + IDDBEntity + GetName() string +} + +func EqualDBEntityByName[T NameAndIDDBEntity](t *testing.T, expected, actual []T) { + require.Equal(t, len(expected), len(actual)) + + sort.Slice(expected, func(i, j int) bool { return expected[i].GetName() > expected[j].GetName() }) + sort.Slice(actual, func(i, j int) bool { return actual[i].GetName() > actual[j].GetName() }) + + for i := 0; i < len(expected); i++ { + require.Equal(t, expected[i].GetName(), actual[i].GetName()) + } +} + +func EqualDBEntityID[T IDDBEntity](t *testing.T, expected, actual []T) { + require.Equal(t, len(expected), len(actual)) + + sort.Slice(expected, func(i, j int) bool { return expected[i].GetID() > expected[j].GetID() }) + sort.Slice(actual, func(i, j int) bool { return actual[i].GetID() > actual[j].GetID() }) + + for i := 0; i < len(expected); i++ { + require.Equal(t, expected[i].GetID(), actual[i].GetID()) + } +} + +func DBEntityMapToSlice[T NameAndIDDBEntity](orgs map[string]T) []T { + orgsSlice := []T{} + for _, value := range orgs { + orgsSlice = append(orgsSlice, value) + } + return orgsSlice +} diff --git a/params/params.go b/params/params.go index b0a07450..decd041f 100644 --- a/params/params.go +++ b/params/params.go @@ -24,20 +24,35 @@ import ( ) type AddressType string +type EventType string +type EventLevel string const ( PublicAddress AddressType = "public" PrivateAddress AddressType = "private" ) +const ( + StatusEvent EventType = "status" + FetchTokenEvent EventType = "fetchToken" +) + +const ( + EventInfo EventLevel = "info" + EventWarning EventLevel = "warning" + EventError EventLevel = "error" +) + type Address struct { Address string `json:"address"` Type AddressType `json:"type"` } type StatusMessage struct { - CreatedAt time.Time `json:"created_at"` - Message string `json:"message"` + CreatedAt time.Time `json:"created_at"` + Message string `json:"message"` + EventType EventType `json:"event_type"` + EventLevel EventLevel `json:"event_level"` } type Instance struct { @@ -73,11 +88,21 @@ type Instance struct { ProviderFault []byte `json:"provider_fault,omitempty"` StatusMessages []StatusMessage `json:"status_messages,omitempty"` + UpdatedAt time.Time `json:"updated_at"` // Do not serialize sensitive info. - CallbackURL string `json:"-"` - CreateAttempt int `json:"-"` - UpdatedAt time.Time `json:"updated_at"` + CallbackURL string `json:"-"` + MetadataURL string `json:"-"` + CreateAttempt int `json:"-"` + TokenFetched bool `json:"-"` +} + +func (i Instance) GetName() string { + return i.Name +} + +func (i Instance) GetID() string { + return i.ID } type BootstrapInstance struct { @@ -85,12 +110,11 @@ type BootstrapInstance struct { Tools []*github.RunnerApplicationDownload `json:"tools"` // RepoURL is the URL the github runner agent needs to configure itself. RepoURL string `json:"repo_url"` - // GithubRunnerAccessToken is the token we fetch from github to allow the runner to - // register itself. - GithubRunnerAccessToken string `json:"github_runner_access_token"` // CallbackUrl is the URL where the instance can send a post, signaling // progress or status. CallbackURL string `json:"callback-url"` + // MetadataURL is the URL where instances can fetch information needed to set themselves up. + MetadataURL string `json:"metadata-url"` // InstanceToken is the token that needs to be set by the instance in the headers // in order to send updated back to the garm via CallbackURL. InstanceToken string `json:"instance-token"` @@ -133,6 +157,10 @@ type Pool struct { RunnerBootstrapTimeout uint `json:"runner_bootstrap_timeout"` } +func (p Pool) GetID() string { + return p.ID +} + func (p *Pool) RunnerTimeout() uint { if p.RunnerBootstrapTimeout == 0 { return config.DefaultRunnerBootstrapTimeout @@ -144,6 +172,7 @@ type Internal struct { OAuth2Token string `json:"oauth2"` ControllerID string `json:"controller_id"` InstanceCallbackURL string `json:"instance_callback_url"` + InstanceMetadataURL string `json:"instance_metadata_url"` JWTSecret string `json:"jwt_secret"` // GithubCredentialsDetails contains all info about the credentials, except the // token, which is added above. @@ -161,6 +190,14 @@ type Repository struct { WebhookSecret string `json:"-"` } +func (r Repository) GetName() string { + return r.Name +} + +func (r Repository) GetID() string { + return r.ID +} + type Organization struct { ID string `json:"id"` Name string `json:"name"` @@ -171,6 +208,14 @@ type Organization struct { WebhookSecret string `json:"-"` } +func (o Organization) GetName() string { + return o.Name +} + +func (o Organization) GetID() string { + return o.ID +} + type Enterprise struct { ID string `json:"id"` Name string `json:"name"` @@ -181,6 +226,14 @@ type Enterprise struct { WebhookSecret string `json:"-"` } +func (e Enterprise) GetName() string { + return e.Name +} + +func (e Enterprise) GetID() string { + return e.ID +} + // Users holds information about a particular user type User struct { ID string `json:"id"` @@ -227,3 +280,8 @@ type PoolManagerStatus struct { IsRunning bool `json:"running"` FailureReason string `json:"failure_reason,omitempty"` } + +type RunnerInfo struct { + Name string + Labels []string +} diff --git a/params/requests.go b/params/requests.go index e2ded969..a383badd 100644 --- a/params/requests.go +++ b/params/requests.go @@ -74,7 +74,7 @@ type CreateEnterpriseParams struct { func (c *CreateEnterpriseParams) Validate() error { if c.Name == "" { - return errors.NewBadRequestError("missing org name") + return errors.NewBadRequestError("missing enterprise name") } if c.CredentialsName == "" { @@ -113,6 +113,7 @@ type CreateInstanceParams struct { Status common.InstanceStatus RunnerStatus common.RunnerStatus CallbackURL string + MetadataURL string CreateAttempt int `json:"-"` } @@ -172,6 +173,7 @@ type UpdateInstanceParams struct { ProviderFault []byte `json:"provider_fault,omitempty"` AgentID int64 `json:"-"` CreateAttempt int `json:"-"` + TokenFetched *bool `json:"-"` } type UpdateUserParams struct { diff --git a/runner/common/mocks/GithubClient.go b/runner/common/mocks/GithubClient.go index b8db134a..a5b08140 100644 --- a/runner/common/mocks/GithubClient.go +++ b/runner/common/mocks/GithubClient.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks diff --git a/runner/common/mocks/GithubEnterpriseClient.go b/runner/common/mocks/GithubEnterpriseClient.go index 741d139c..e9c3acc9 100644 --- a/runner/common/mocks/GithubEnterpriseClient.go +++ b/runner/common/mocks/GithubEnterpriseClient.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks diff --git a/runner/common/mocks/PoolManager.go b/runner/common/mocks/PoolManager.go index 4d55c339..9f327279 100644 --- a/runner/common/mocks/PoolManager.go +++ b/runner/common/mocks/PoolManager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks @@ -27,6 +27,27 @@ func (_m *PoolManager) ForceDeleteRunner(runner params.Instance) error { return r0 } +// GithubRunnerRegistrationToken provides a mock function with given fields: +func (_m *PoolManager) GithubRunnerRegistrationToken() (string, error) { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // HandleWorkflowJob provides a mock function with given fields: job func (_m *PoolManager) HandleWorkflowJob(job params.WorkflowJob) error { ret := _m.Called(job) diff --git a/runner/common/mocks/Provider.go b/runner/common/mocks/Provider.go index fa6b5f2b..35206df6 100644 --- a/runner/common/mocks/Provider.go +++ b/runner/common/mocks/Provider.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks diff --git a/runner/common/pool.go b/runner/common/pool.go index 113125e9..a075eb61 100644 --- a/runner/common/pool.go +++ b/runner/common/pool.go @@ -43,6 +43,7 @@ const ( type PoolManager interface { ID() string WebhookSecret() string + GithubRunnerRegistrationToken() (string, error) HandleWorkflowJob(job params.WorkflowJob) error RefreshState(param params.UpdatePoolStateParams) error ForceDeleteRunner(runner params.Instance) error diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go new file mode 100644 index 00000000..86591e77 --- /dev/null +++ b/runner/enterprises_test.go @@ -0,0 +1,548 @@ +// Copyright 2022 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package runner + +import ( + "context" + "fmt" + "garm/auth" + "garm/config" + "garm/database" + dbCommon "garm/database/common" + runnerErrors "garm/errors" + garmTesting "garm/internal/testing" + "garm/params" + "garm/runner/common" + runnerCommonMocks "garm/runner/common/mocks" + runnerMocks "garm/runner/mocks" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +type EnterpriseTestFixtures struct { + AdminContext context.Context + DBFile string + Store dbCommon.Store + StoreEnterprises map[string]params.Enterprise + Providers map[string]common.Provider + Credentials map[string]config.Github + CreateEnterpriseParams params.CreateEnterpriseParams + CreatePoolParams params.CreatePoolParams + CreateInstanceParams params.CreateInstanceParams + UpdateRepoParams params.UpdateRepositoryParams + UpdatePoolParams params.UpdatePoolParams + UpdatePoolStateParams params.UpdatePoolStateParams + ErrMock error + ProviderMock *runnerCommonMocks.Provider + PoolMgrMock *runnerCommonMocks.PoolManager + PoolMgrCtrlMock *runnerMocks.PoolManagerController +} + +type EnterpriseTestSuite struct { + suite.Suite + Fixtures *EnterpriseTestFixtures + Runner *Runner +} + +func (s *EnterpriseTestSuite) SetupTest() { + adminCtx := auth.GetAdminContext() + + // create testing sqlite database + dbCfg := garmTesting.GetTestSqliteDBConfig(s.T()) + db, err := database.NewDatabase(adminCtx, dbCfg) + if err != nil { + s.FailNow(fmt.Sprintf("failed to create db connection: %s", err)) + } + + // create some organization objects in the database, for testing purposes + enterprises := map[string]params.Enterprise{} + for i := 1; i <= 3; i++ { + name := fmt.Sprintf("test-enterprise-%v", i) + enterprise, err := db.CreateEnterprise( + adminCtx, + name, + fmt.Sprintf("test-creds-%v", i), + fmt.Sprintf("test-webhook-secret-%v", i), + ) + if err != nil { + s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%v)", i)) + } + enterprises[name] = enterprise + } + + // setup test fixtures + var maxRunners uint = 40 + var minIdleRunners uint = 20 + providerMock := runnerCommonMocks.NewProvider(s.T()) + fixtures := &EnterpriseTestFixtures{ + AdminContext: adminCtx, + DBFile: dbCfg.SQLite.DBFile, + Store: db, + StoreEnterprises: enterprises, + Providers: map[string]common.Provider{ + "test-provider": providerMock, + }, + Credentials: map[string]config.Github{ + "test-creds": { + Name: "test-creds-name", + Description: "test-creds-description", + OAuth2Token: "test-creds-oauth2-token", + }, + }, + CreateEnterpriseParams: params.CreateEnterpriseParams{ + Name: "test-enterprise-create", + CredentialsName: "test-creds", + }, + CreatePoolParams: params.CreatePoolParams{ + ProviderName: "test-provider", + MaxRunners: 4, + MinIdleRunners: 2, + Image: "test", + Flavor: "test", + OSType: "linux", + OSArch: "arm64", + Tags: []string{"self-hosted", "arm64", "linux"}, + RunnerBootstrapTimeout: 0, + }, + CreateInstanceParams: params.CreateInstanceParams{ + Name: "test-instance-name", + OSType: "linux", + }, + UpdateRepoParams: params.UpdateRepositoryParams{ + CredentialsName: "test-creds", + WebhookSecret: "test-update-repo-webhook-secret", + }, + UpdatePoolParams: params.UpdatePoolParams{ + MaxRunners: &maxRunners, + MinIdleRunners: &minIdleRunners, + Image: "test-images-updated", + Flavor: "test-flavor-updated", + }, + UpdatePoolStateParams: params.UpdatePoolStateParams{ + WebhookSecret: "test-update-repo-webhook-secret", + }, + ErrMock: fmt.Errorf("mock error"), + ProviderMock: providerMock, + PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()), + PoolMgrCtrlMock: runnerMocks.NewPoolManagerController(s.T()), + } + s.Fixtures = fixtures + + // setup test runner + runner := &Runner{ + providers: fixtures.Providers, + credentials: fixtures.Credentials, + ctx: fixtures.AdminContext, + store: fixtures.Store, + poolManagerCtrl: fixtures.PoolMgrCtrlMock, + } + s.Runner = runner +} + +func (s *EnterpriseTestSuite) TestCreateEnterprise() { + // setup mocks expectations + s.Fixtures.PoolMgrMock.On("Start").Return(nil) + s.Fixtures.PoolMgrCtrlMock.On("CreateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise"), s.Fixtures.Providers, s.Fixtures.Store).Return(s.Fixtures.PoolMgrMock, nil) + + // call tested function + enterprise, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) + + // assertions + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) + s.Require().Equal(s.Fixtures.CreateEnterpriseParams.Name, enterprise.Name) + s.Require().Equal(s.Fixtures.Credentials[s.Fixtures.CreateEnterpriseParams.CredentialsName].Name, enterprise.CredentialsName) +} + +func (s *EnterpriseTestSuite) TestCreateEnterpriseErrUnauthorized() { + _, err := s.Runner.CreateEnterprise(context.Background(), s.Fixtures.CreateEnterpriseParams) + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestCreateEnterpriseEmptyParams() { + _, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, params.CreateEnterpriseParams{}) + + s.Require().Regexp("validating params: missing enterprise name", err.Error()) +} + +func (s *EnterpriseTestSuite) TestCreateEnterpriseMissingCredentials() { + s.Fixtures.CreateEnterpriseParams.CredentialsName = "not-existent-creds-name" + + _, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) + + s.Require().Equal(runnerErrors.NewBadRequestError("credentials %s not defined", s.Fixtures.CreateEnterpriseParams.CredentialsName), err) +} + +func (s *EnterpriseTestSuite) TestCreateEnterpriseAlreadyExists() { + s.Fixtures.CreateEnterpriseParams.Name = "test-enterprise-1" // this is already created in `SetupTest()` + + _, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) + + s.Require().Equal(runnerErrors.NewConflictError("enterprise %s already exists", s.Fixtures.CreateEnterpriseParams.Name), err) +} + +func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMgrFailed() { + s.Fixtures.PoolMgrCtrlMock.On("CreateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise"), s.Fixtures.Providers, s.Fixtures.Store).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + + _, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(fmt.Sprintf("creating enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) +} + +func (s *EnterpriseTestSuite) TestCreateEnterpriseStartPoolMgrFailed() { + s.Fixtures.PoolMgrMock.On("Start").Return(s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("CreateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise"), s.Fixtures.Providers, s.Fixtures.Store).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("DeleteEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.ErrMock) + + _, err := s.Runner.CreateEnterprise(s.Fixtures.AdminContext, s.Fixtures.CreateEnterpriseParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(fmt.Sprintf("starting enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) +} + +func (s *EnterpriseTestSuite) TestListEnterprises() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) + orgs, err := s.Runner.ListEnterprises(s.Fixtures.AdminContext) + + s.Require().Nil(err) + garmTesting.EqualDBEntityByName(s.T(), garmTesting.DBEntityMapToSlice(s.Fixtures.StoreEnterprises), orgs) +} + +func (s *EnterpriseTestSuite) TestListEnterprisesErrUnauthorized() { + _, err := s.Runner.ListEnterprises(context.Background()) + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestGetEnterpriseByID() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) + org, err := s.Runner.GetEnterpriseByID(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + + s.Require().Nil(err) + s.Require().Equal(s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, org.ID) +} + +func (s *EnterpriseTestSuite) TestGetEnterpriseByIDErrUnauthorized() { + _, err := s.Runner.GetEnterpriseByID(context.Background(), "dummy-org-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprise() { + s.Fixtures.PoolMgrCtrlMock.On("DeleteEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(nil) + + err := s.Runner.DeleteEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-3"].ID) + + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) + + _, err = s.Fixtures.Store.GetEnterpriseByID(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-3"].ID) + s.Require().Equal("fetching enterprise: not found", err.Error()) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterpriseErrUnauthorized() { + err := s.Runner.DeleteEnterprise(context.Background(), "dummy-enterprise-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDefinedFailed() { + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create store enterprises pool: %v", err)) + } + + err = s.Runner.DeleteEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + + s.Require().Equal(runnerErrors.NewBadRequestError("enterprise has pools defined (%s)", pool.ID), err) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolMgrFailed() { + s.Fixtures.PoolMgrCtrlMock.On("DeleteEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.ErrMock) + + err := s.Runner.DeleteEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(fmt.Sprintf("deleting enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterprise() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("CreateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise"), s.Fixtures.Providers, s.Fixtures.Store).Return(s.Fixtures.PoolMgrMock, nil) + + org, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) + s.Require().Equal(s.Fixtures.UpdateRepoParams.CredentialsName, org.CredentialsName) + s.Require().Equal(s.Fixtures.UpdateRepoParams.WebhookSecret, org.WebhookSecret) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterpriseErrUnauthorized() { + _, err := s.Runner.UpdateEnterprise(context.Background(), "dummy-enterprise-id", s.Fixtures.UpdateRepoParams) + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() { + s.Fixtures.UpdateRepoParams.CredentialsName = "invalid-creds-name" + + _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) + + s.Require().Equal(runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", s.Fixtures.UpdateRepoParams.CredentialsName, s.Fixtures.StoreEnterprises["test-enterprise-1"].Name), err) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrMock.On("RefreshState", s.Fixtures.UpdatePoolStateParams).Return(s.Fixtures.ErrMock) + + _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(fmt.Sprintf("updating enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("CreateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise"), s.Fixtures.Providers, s.Fixtures.Store).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + + _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(fmt.Sprintf("creating enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) +} + +func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + + pool, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) + + enterprise, err := s.Fixtures.Store.GetEnterpriseByID(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + if err != nil { + s.FailNow(fmt.Sprintf("cannot get enterprise by ID: %v", err)) + } + s.Require().Equal(1, len(enterprise.Pools)) + s.Require().Equal(pool.ID, enterprise.Pools[0].ID) + s.Require().Equal(s.Fixtures.CreatePoolParams.ProviderName, enterprise.Pools[0].ProviderName) + s.Require().Equal(s.Fixtures.CreatePoolParams.MaxRunners, enterprise.Pools[0].MaxRunners) + s.Require().Equal(s.Fixtures.CreatePoolParams.MinIdleRunners, enterprise.Pools[0].MinIdleRunners) +} + +func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrUnauthorized() { + _, err := s.Runner.CreateEnterprisePool(context.Background(), "dummy-enterprise-id", s.Fixtures.CreatePoolParams) + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrNotFound() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, runnerErrors.ErrNotFound) + + _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Equal(runnerErrors.ErrNotFound, err) +} + +func (s *EnterpriseTestSuite) TestCreateEnterprisePoolFetchPoolParamsFailed() { + s.Fixtures.CreatePoolParams.ProviderName = "not-existent-provider-name" + + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + + _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Regexp("fetching pool params: no such provider", err.Error()) +} + +func (s *EnterpriseTestSuite) TestGetEnterprisePoolByID() { + enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) + } + + pool, err := s.Runner.GetEnterprisePoolByID(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, enterprisePool.ID) + + s.Require().Nil(err) + s.Require().Equal(enterprisePool.ID, pool.ID) +} + +func (s *EnterpriseTestSuite) TestGetEnterprisePoolByIDErrUnauthorized() { + _, err := s.Runner.GetEnterprisePoolByID(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) + } + + err = s.Runner.DeleteEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) + + s.Require().Nil(err) + + _, err = s.Fixtures.Store.GetEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) + s.Require().Equal("fetching pool: not found", err.Error()) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolErrUnauthorized() { + err := s.Runner.DeleteEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolRunnersFailed() { + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) + } + instance, err := s.Fixtures.Store.CreateInstance(s.Fixtures.AdminContext, pool.ID, s.Fixtures.CreateInstanceParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) + } + + err = s.Runner.DeleteEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) + + s.Require().Equal(runnerErrors.NewBadRequestError("pool has runners: %s", instance.ID), err) +} + +func (s *EnterpriseTestSuite) TestListEnterprisePools() { + enterprisePools := []params.Pool{} + for i := 1; i <= 2; i++ { + s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-enterprise-%v", i) + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) + } + enterprisePools = append(enterprisePools, pool) + } + + pools, err := s.Runner.ListEnterprisePools(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + + s.Require().Nil(err) + garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools) +} + +func (s *EnterpriseTestSuite) TestListOrgPoolsErrUnauthorized() { + _, err := s.Runner.ListOrgPools(context.Background(), "dummy-org-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { + enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) + } + + pool, err := s.Runner.UpdateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, enterprisePool.ID, s.Fixtures.UpdatePoolParams) + + s.Require().Nil(err) + s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) + s.Require().Equal(*s.Fixtures.UpdatePoolParams.MinIdleRunners, pool.MinIdleRunners) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolErrUnauthorized() { + _, err := s.Runner.UpdateEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMinIdleGreaterThanMax() { + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) + } + var maxRunners uint = 10 + var minIdleRunners uint = 11 + s.Fixtures.UpdatePoolParams.MaxRunners = &maxRunners + s.Fixtures.UpdatePoolParams.MinIdleRunners = &minIdleRunners + + _, err = s.Runner.UpdateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID, s.Fixtures.UpdatePoolParams) + + s.Require().Equal(runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners"), err) +} + +func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { + pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) + } + poolInstances := []params.Instance{} + for i := 1; i <= 3; i++ { + s.Fixtures.CreateInstanceParams.Name = fmt.Sprintf("test-enterprise-%v", i) + instance, err := s.Fixtures.Store.CreateInstance(s.Fixtures.AdminContext, pool.ID, s.Fixtures.CreateInstanceParams) + if err != nil { + s.FailNow(fmt.Sprintf("cannot create instance: %s", err)) + } + poolInstances = append(poolInstances, instance) + } + + instances, err := s.Runner.ListEnterpriseInstances(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID) + + s.Require().Nil(err) + garmTesting.EqualDBEntityID(s.T(), poolInstances, instances) +} + +func (s *EnterpriseTestSuite) TestListEnterpriseInstancesErrUnauthorized() { + _, err := s.Runner.ListEnterpriseInstances(context.Background(), "dummy-enterprise-id") + + s.Require().Equal(runnerErrors.ErrUnauthorized, err) +} + +func (s *EnterpriseTestSuite) TestFindEnterprisePoolManager() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + + poolManager, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Nil(err) + s.Require().Equal(s.Fixtures.PoolMgrMock, poolManager) +} + +func (s *EnterpriseTestSuite) TestFindEnterprisePoolManagerFetchPoolMgrFailed() { + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + + _, err := s.Runner.findEnterprisePoolManager(s.Fixtures.StoreEnterprises["test-enterprise-1"].Name) + + s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) + s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) + s.Require().Regexp("fetching pool manager for enterprise", err.Error()) +} + +func TestEnterpriseTestSuite(t *testing.T) { + t.Parallel() + suite.Run(t, new(EnterpriseTestSuite)) +} diff --git a/runner/mocks/PoolManagerController.go b/runner/mocks/PoolManagerController.go index a310f483..ce708fcd 100644 --- a/runner/mocks/PoolManagerController.go +++ b/runner/mocks/PoolManagerController.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.15.0. DO NOT EDIT. package mocks diff --git a/runner/organizations_test.go b/runner/organizations_test.go index e87cc255..b8a7ef48 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -27,7 +27,6 @@ import ( "garm/runner/common" runnerCommonMocks "garm/runner/common/mocks" runnerMocks "garm/runner/mocks" - "sort" "testing" "github.com/stretchr/testify/mock" @@ -59,47 +58,6 @@ type OrgTestSuite struct { Runner *Runner } -func (s *OrgTestSuite) orgsMapValues(orgs map[string]params.Organization) []params.Organization { - orgsSlice := []params.Organization{} - for _, value := range orgs { - orgsSlice = append(orgsSlice, value) - } - return orgsSlice -} - -func (s *OrgTestSuite) equalOrgsByName(expected, actual []params.Organization) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].Name > expected[j].Name }) - sort.Slice(actual, func(i, j int) bool { return actual[i].Name > actual[j].Name }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].Name, actual[i].Name) - } -} - -func (s *OrgTestSuite) equalPoolsByID(expected, actual []params.Pool) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - -func (s *OrgTestSuite) equalInstancesByName(expected, actual []params.Instance) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].Name > expected[j].Name }) - sort.Slice(actual, func(i, j int) bool { return actual[i].Name > actual[j].Name }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].Name, actual[i].Name) - } -} - func (s *OrgTestSuite) SetupTest() { adminCtx := auth.GetAdminContext() @@ -267,7 +225,7 @@ func (s *OrgTestSuite) TestListOrganizations() { orgs, err := s.Runner.ListOrganizations(s.Fixtures.AdminContext) s.Require().Nil(err) - s.equalOrgsByName(s.orgsMapValues(s.Fixtures.StoreOrgs), orgs) + garmTesting.EqualDBEntityByName(s.T(), garmTesting.DBEntityMapToSlice(s.Fixtures.StoreOrgs), orgs) } func (s *OrgTestSuite) TestListOrganizationsErrUnauthorized() { @@ -493,7 +451,7 @@ func (s *OrgTestSuite) TestListOrgPools() { pools, err := s.Runner.ListOrgPools(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID) s.Require().Nil(err) - s.equalPoolsByID(orgPools, pools) + garmTesting.EqualDBEntityID(s.T(), orgPools, pools) } func (s *OrgTestSuite) TestListOrgPoolsErrUnauthorized() { @@ -554,7 +512,7 @@ func (s *OrgTestSuite) TestListOrgInstances() { instances, err := s.Runner.ListOrgInstances(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID) s.Require().Nil(err) - s.equalInstancesByName(poolInstances, instances) + garmTesting.EqualDBEntityID(s.T(), poolInstances, instances) } func (s *OrgTestSuite) TestListOrgInstancesErrUnauthorized() { @@ -585,5 +543,6 @@ func (s *OrgTestSuite) TestFindOrgPoolManagerFetchPoolMgrFailed() { } func TestOrgTestSuite(t *testing.T) { + t.Parallel() suite.Run(t, new(OrgTestSuite)) } diff --git a/runner/pool/enterprise.go b/runner/pool/enterprise.go index 0cc99d97..45931842 100644 --- a/runner/pool/enterprise.go +++ b/runner/pool/enterprise.go @@ -18,7 +18,7 @@ import ( ) // test that we implement PoolManager -var _ poolHelper = &organization{} +var _ poolHelper = &enterprise{} func NewEnterprisePoolManager(ctx context.Context, cfg params.Enterprise, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) { ghc, ghEnterpriseClient, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails) @@ -61,18 +61,25 @@ type enterprise struct { mux sync.Mutex } -func (r *enterprise) GetRunnerNameFromWorkflow(job params.WorkflowJob) (string, error) { +func (r *enterprise) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (params.RunnerInfo, error) { + if err := r.ValidateOwner(job); err != nil { + return params.RunnerInfo{}, errors.Wrap(err, "validating owner") + } workflow, ghResp, err := r.ghcli.GetWorkflowJobByID(r.ctx, job.Repository.Owner.Login, job.Repository.Name, job.WorkflowJob.ID) if err != nil { if ghResp.StatusCode == http.StatusUnauthorized { - return "", errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners") + return params.RunnerInfo{}, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching workflow info") } - return "", errors.Wrap(err, "fetching workflow info") + return params.RunnerInfo{}, errors.Wrap(err, "fetching workflow info") } + if workflow.RunnerName != nil { - return *workflow.RunnerName, nil + return params.RunnerInfo{ + Name: *workflow.RunnerName, + Labels: workflow.Labels, + }, nil } - return "", fmt.Errorf("failed to find runner name from workflow") + return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } func (r *enterprise) UpdateState(param params.UpdatePoolStateParams) error { @@ -179,6 +186,10 @@ func (r *enterprise) GetCallbackURL() string { return r.cfgInternal.InstanceCallbackURL } +func (r *enterprise) GetMetadataURL() string { + return r.cfgInternal.InstanceMetadataURL +} + func (r *enterprise) FindPoolByTags(labels []string) (params.Pool, error) { pool, err := r.store.FindEnterprisePoolByTags(r.ctx, r.id, labels) if err != nil { diff --git a/runner/pool/interfaces.go b/runner/pool/interfaces.go index 9f1f7b85..9e69dec3 100644 --- a/runner/pool/interfaces.go +++ b/runner/pool/interfaces.go @@ -24,7 +24,7 @@ type poolHelper interface { GetGithubToken() string GetGithubRunners() ([]*github.Runner, error) GetGithubRegistrationToken() (string, error) - GetRunnerNameFromWorkflow(job params.WorkflowJob) (string, error) + GetRunnerInfoFromWorkflow(job params.WorkflowJob) (params.RunnerInfo, error) RemoveGithubRunner(runnerID int64) (*github.Response, error) FetchTools() ([]*github.RunnerApplicationDownload, error) @@ -34,6 +34,7 @@ type poolHelper interface { JwtToken() string String() string GetCallbackURL() string + GetMetadataURL() string FindPoolByTags(labels []string) (params.Pool, error) GetPoolByID(poolID string) (params.Pool, error) ValidateOwner(job params.WorkflowJob) error diff --git a/runner/pool/organization.go b/runner/pool/organization.go index b8efba72..9371b4d5 100644 --- a/runner/pool/organization.go +++ b/runner/pool/organization.go @@ -73,18 +73,25 @@ type organization struct { mux sync.Mutex } -func (r *organization) GetRunnerNameFromWorkflow(job params.WorkflowJob) (string, error) { +func (r *organization) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (params.RunnerInfo, error) { + if err := r.ValidateOwner(job); err != nil { + return params.RunnerInfo{}, errors.Wrap(err, "validating owner") + } workflow, ghResp, err := r.ghcli.GetWorkflowJobByID(r.ctx, job.Organization.Login, job.Repository.Name, job.WorkflowJob.ID) if err != nil { if ghResp.StatusCode == http.StatusUnauthorized { - return "", errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runner name") + return params.RunnerInfo{}, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching workflow info") } - return "", errors.Wrap(err, "fetching workflow info") + return params.RunnerInfo{}, errors.Wrap(err, "fetching workflow info") } + if workflow.RunnerName != nil { - return *workflow.RunnerName, nil + return params.RunnerInfo{ + Name: *workflow.RunnerName, + Labels: workflow.Labels, + }, nil } - return "", fmt.Errorf("failed to find runner name from workflow") + return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } func (r *organization) UpdateState(param params.UpdatePoolStateParams) error { @@ -192,6 +199,10 @@ func (r *organization) GetCallbackURL() string { return r.cfgInternal.InstanceCallbackURL } +func (r *organization) GetMetadataURL() string { + return r.cfgInternal.InstanceMetadataURL +} + func (r *organization) FindPoolByTags(labels []string) (params.Pool, error) { pool, err := r.store.FindOrganizationPoolByTags(r.ctx, r.id, labels) if err != nil { diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 1eac96c2..a6d4f7c4 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -67,16 +67,37 @@ type basePoolManager struct { mux sync.Mutex } -func controllerIDFromLabels(labels []*github.RunnerLabels) string { +func controllerIDFromLabels(labels []string) string { for _, lbl := range labels { - if lbl.Name != nil && strings.HasPrefix(*lbl.Name, controllerLabelPrefix) { - labelName := *lbl.Name - return labelName[len(controllerLabelPrefix):] + if strings.HasPrefix(lbl, controllerLabelPrefix) { + return lbl[len(controllerLabelPrefix):] } } return "" } +func labelsFromRunner(runner *github.Runner) []string { + if runner == nil || runner.Labels == nil { + return []string{} + } + + var labels []string + for _, val := range runner.Labels { + if val == nil { + continue + } + labels = append(labels, val.GetName()) + } + return labels +} + +// isManagedRunner returns true if labels indicate the runner belongs to a pool +// this manager is responsible for. +func (r *basePoolManager) isManagedRunner(labels []string) bool { + runnerControllerID := controllerIDFromLabels(labels) + return runnerControllerID == r.controllerID +} + // cleanupOrphanedProviderRunners compares runners in github with local runners and removes // any local runners that are not present in Github. Runners that are "idle" in our // provider, but do not exist in github, will be removed. This can happen if the @@ -92,6 +113,10 @@ func (r *basePoolManager) cleanupOrphanedProviderRunners(runners []*github.Runne runnerNames := map[string]bool{} for _, run := range runners { + if !r.isManagedRunner(labelsFromRunner(run)) { + log.Printf("runner %s is not managed by a pool belonging to %s", *run.Name, r.helper.String()) + continue + } runnerNames[*run.Name] = true } @@ -127,24 +152,26 @@ func (r *basePoolManager) reapTimedOutRunners(runners []*github.Runner) error { runnerNames := map[string]bool{} for _, run := range runners { + if !r.isManagedRunner(labelsFromRunner(run)) { + log.Printf("runner %s is not managed by a pool belonging to %s", *run.Name, r.helper.String()) + continue + } runnerNames[*run.Name] = true } for _, instance := range dbInstances { if ok := runnerNames[instance.Name]; !ok { - if instance.Status == providerCommon.InstanceRunning { - pool, err := r.store.GetPoolByID(r.ctx, instance.PoolID) - if err != nil { - return errors.Wrap(err, "fetching instance pool info") - } - if time.Since(instance.UpdatedAt).Minutes() < float64(pool.RunnerTimeout()) { - continue - } - log.Printf("reaping instance %s due to timeout", instance.Name) - if err := r.setInstanceStatus(instance.Name, providerCommon.InstancePendingDelete, nil); err != nil { - log.Printf("failed to update runner %s status", instance.Name) - return errors.Wrap(err, "updating runner") - } + pool, err := r.store.GetPoolByID(r.ctx, instance.PoolID) + if err != nil { + return errors.Wrap(err, "fetching instance pool info") + } + if time.Since(instance.UpdatedAt).Minutes() < float64(pool.RunnerTimeout()) { + continue + } + log.Printf("reaping instance %s due to timeout", instance.Name) + if err := r.setInstanceStatus(instance.Name, providerCommon.InstancePendingDelete, nil); err != nil { + log.Printf("failed to update runner %s status", instance.Name) + return errors.Wrap(err, "updating runner") } } } @@ -157,9 +184,8 @@ func (r *basePoolManager) reapTimedOutRunners(runners []*github.Runner) error { // first remove the instance from github, and then from our database. func (r *basePoolManager) cleanupOrphanedGithubRunners(runners []*github.Runner) error { for _, runner := range runners { - runnerControllerID := controllerIDFromLabels(runner.Labels) - if runnerControllerID != r.controllerID { - // Not a runner we manage. Do not remove foreign runner. + if !r.isManagedRunner(labelsFromRunner(runner)) { + log.Printf("runner %s is not managed by a pool belonging to %s", *runner.Name, r.helper.String()) continue } @@ -265,6 +291,11 @@ func (r *basePoolManager) fetchInstance(runnerName string) (params.Instance, err return params.Instance{}, errors.Wrap(err, "fetching instance") } + _, err = r.helper.GetPoolByID(runner.PoolID) + if err != nil { + return params.Instance{}, errors.Wrap(err, "fetching pool") + } + return runner, nil } @@ -312,6 +343,9 @@ func (r *basePoolManager) acquireNewInstance(job params.WorkflowJob) error { pool, err := r.helper.FindPoolByTags(requestedLabels) if err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + return nil + } return errors.Wrap(err, "fetching suitable pool") } log.Printf("adding new runner with requested tags %s in pool %s", strings.Join(job.WorkflowJob.Labels, ", "), pool.ID) @@ -353,6 +387,7 @@ func (r *basePoolManager) AddRunner(ctx context.Context, poolID string) error { OSArch: pool.OSArch, OSType: pool.OSType, CallbackURL: r.helper.GetCallbackURL(), + MetadataURL: r.helper.GetMetadataURL(), CreateAttempt: 1, } @@ -512,16 +547,6 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error labels = append(labels, r.controllerLabel()) labels = append(labels, r.poolLabel(pool.ID)) - tk, err := r.helper.GetGithubRegistrationToken() - if err != nil { - if errors.Is(err, runnerErrors.ErrUnauthorized) { - failureReason := fmt.Sprintf("failed to fetch registration token: %q", err) - r.setPoolRunningState(false, failureReason) - log.Print(failureReason) - } - return errors.Wrap(err, "fetching registration token") - } - jwtValidity := pool.RunnerTimeout() var poolType common.PoolType = common.RepositoryPool if pool.OrgID != "" { @@ -535,18 +560,18 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error } bootstrapArgs := params.BootstrapInstance{ - Name: instance.Name, - Tools: r.tools, - RepoURL: r.helper.GithubURL(), - GithubRunnerAccessToken: tk, - CallbackURL: instance.CallbackURL, - InstanceToken: jwtToken, - OSArch: pool.OSArch, - Flavor: pool.Flavor, - Image: pool.Image, - Labels: labels, - PoolID: instance.PoolID, - CACertBundle: r.credsDetails.CABundle, + Name: instance.Name, + Tools: r.tools, + RepoURL: r.helper.GithubURL(), + MetadataURL: instance.MetadataURL, + CallbackURL: instance.CallbackURL, + InstanceToken: jwtToken, + OSArch: pool.OSArch, + Flavor: pool.Flavor, + Image: pool.Image, + Labels: labels, + PoolID: instance.PoolID, + CACertBundle: r.credsDetails.CABundle, } var instanceIDToDelete string @@ -581,25 +606,39 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error return nil } -func (r *basePoolManager) getRunnerNameFromJob(job params.WorkflowJob) (string, error) { - if job.WorkflowJob.RunnerName != "" { - return job.WorkflowJob.RunnerName, nil +func (r *basePoolManager) getRunnerDetailsFromJob(job params.WorkflowJob) (params.RunnerInfo, error) { + runnerInfo := params.RunnerInfo{ + Name: job.WorkflowJob.RunnerName, + Labels: job.WorkflowJob.Labels, } - // Runner name was not set in WorkflowJob by github. We can still attempt to - // fetch the info we need, using the workflow run ID, from the API. - log.Printf("runner name not found in workflow job, attempting to fetch from API") - runnerName, err := r.helper.GetRunnerNameFromWorkflow(job) - if err != nil { - if errors.Is(err, runnerErrors.ErrUnauthorized) { - failureReason := fmt.Sprintf("failed to fetch runner name from API: %q", err) - r.setPoolRunningState(false, failureReason) - log.Print(failureReason) + var err error + if job.WorkflowJob.RunnerName == "" { + // Runner name was not set in WorkflowJob by github. We can still attempt to + // fetch the info we need, using the workflow run ID, from the API. + log.Printf("runner name not found in workflow job, attempting to fetch from API") + runnerInfo, err = r.helper.GetRunnerInfoFromWorkflow(job) + if err != nil { + if errors.Is(err, runnerErrors.ErrUnauthorized) { + failureReason := fmt.Sprintf("failed to fetch runner name from API: %q", err) + r.setPoolRunningState(false, failureReason) + log.Print(failureReason) + } + return params.RunnerInfo{}, errors.Wrap(err, "fetching runner name from API") } - return "", errors.Wrap(err, "fetching runner name from API") } - return runnerName, nil + runnerDetails, err := r.store.GetInstanceByName(context.Background(), runnerInfo.Name) + if err != nil { + log.Printf("could not find runner details for %s", runnerInfo.Name) + return params.RunnerInfo{}, errors.Wrap(err, "fetching runner details") + } + + if _, err := r.helper.GetPoolByID(runnerDetails.PoolID); err != nil { + log.Printf("runner %s (pool ID: %s) does not belong to any pool we manage: %s", runnerDetails.Name, runnerDetails.PoolID, err) + return params.RunnerInfo{}, errors.Wrap(err, "fetching pool for instance") + } + return runnerInfo, nil } func (r *basePoolManager) HandleWorkflowJob(job params.WorkflowJob) error { @@ -621,34 +660,52 @@ func (r *basePoolManager) HandleWorkflowJob(job params.WorkflowJob) error { case "completed": // ignore the error here. A completed job may not have a runner name set // if it was never assigned to a runner, and was canceled. - runnerName, _ := r.getRunnerNameFromJob(job) - // Set instance in database to pending delete. - if runnerName == "" { + runnerInfo, err := r.getRunnerDetailsFromJob(job) + if err != nil { // Unassigned jobs will have an empty runner_name. - // There is nothing to to in this case. - log.Printf("no runner was assigned. Skipping.") + // We also need to ignore not found errors, as we may get a webhook regarding + // a workflow that is handled by a runner at a different hierarchy level. return nil } + // update instance workload state. - if err := r.setInstanceRunnerStatus(runnerName, providerCommon.RunnerTerminated); err != nil { - log.Printf("failed to update runner %s status", runnerName) + if err := r.setInstanceRunnerStatus(runnerInfo.Name, providerCommon.RunnerTerminated); err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + return nil + } + log.Printf("failed to update runner %s status", runnerInfo.Name) return errors.Wrap(err, "updating runner") } - log.Printf("marking instance %s as pending_delete", runnerName) - if err := r.setInstanceStatus(runnerName, providerCommon.InstancePendingDelete, nil); err != nil { - log.Printf("failed to update runner %s status", runnerName) + log.Printf("marking instance %s as pending_delete", runnerInfo.Name) + if err := r.setInstanceStatus(runnerInfo.Name, providerCommon.InstancePendingDelete, nil); err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + return nil + } + log.Printf("failed to update runner %s status", runnerInfo.Name) return errors.Wrap(err, "updating runner") } case "in_progress": // in_progress jobs must have a runner name/ID assigned. Sometimes github will send a hook without // a runner set. In such cases, we attemt to fetch it from the API. - runnerName, err := r.getRunnerNameFromJob(job) + runnerInfo, err := r.getRunnerDetailsFromJob(job) if err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + // This is most likely a runner we're not managing. If we define a repo from within an org + // and also define that same org, we will get a hook from github from both the repo and the org + // regarding the same workflow. We look for the runner in the database, and make sure it exists and is + // part of a pool that this manager is responsible for. A not found error here will most likely mean + // that we are not responsible for that runner, and we should ignore it. + return nil + } return errors.Wrap(err, "determining runner name") } + // update instance workload state. - if err := r.setInstanceRunnerStatus(runnerName, providerCommon.RunnerActive); err != nil { - log.Printf("failed to update runner %s status", job.WorkflowJob.RunnerName) + if err := r.setInstanceRunnerStatus(runnerInfo.Name, providerCommon.RunnerActive); err != nil { + if errors.Is(err, runnerErrors.ErrNotFound) { + return nil + } + log.Printf("failed to update runner %s status", runnerInfo.Name) return errors.Wrap(err, "updating runner") } } @@ -725,10 +782,11 @@ func (r *basePoolManager) retryFailedInstancesForOnePool(pool params.Pool) { existingInstances, err := r.store.ListPoolInstances(r.ctx, pool.ID) if err != nil { - log.Printf("failed to ensure minimum idle workers for pool %s: %s", pool.ID, err) + log.Printf("retrying failed instances: failed to list instances for pool %s: %s", pool.ID, err) return } + wg := sync.WaitGroup{} for _, instance := range existingInstances { if instance.Status != providerCommon.InstanceError { continue @@ -736,28 +794,33 @@ func (r *basePoolManager) retryFailedInstancesForOnePool(pool params.Pool) { if instance.CreateAttempt >= maxCreateAttempts { continue } + wg.Add(1) + go func(inst params.Instance) { + defer wg.Done() + // NOTE(gabriel-samfira): this is done in parallel. If there are many failed instances + // this has the potential to create many API requests to the target provider. + // TODO(gabriel-samfira): implement request throttling. + if err := r.deleteInstanceFromProvider(inst); err != nil { + log.Printf("failed to delete instance %s from provider: %s", inst.Name, err) + } - // NOTE(gabriel-samfira): this is done in parallel. If there are many failed instances - // this has the potential to create many API requests to the target provider. - // TODO(gabriel-samfira): implement request throttling. - if instance.ProviderID == "" && instance.Name == "" { - // This really should not happen, but no harm in being extra paranoid. The name is set - // when creating a db entity for the runner, so we should at least have a name. - return - } - // TODO(gabriel-samfira): Incrementing CreateAttempt should be done within a transaction. - // It's fairly safe to do here (for now), as there should be no other code path that updates - // an instance in this state. - updateParams := params.UpdateInstanceParams{ - CreateAttempt: instance.CreateAttempt + 1, - Status: providerCommon.InstancePendingCreate, - } - log.Printf("queueing previously failed instance %s for retry", instance.Name) - // Set instance to pending create and wait for retry. - if err := r.updateInstance(instance.Name, updateParams); err != nil { - log.Printf("failed to update runner %s status", instance.Name) - } + // TODO(gabriel-samfira): Incrementing CreateAttempt should be done within a transaction. + // It's fairly safe to do here (for now), as there should be no other code path that updates + // an instance in this state. + var tokenFetched bool = false + updateParams := params.UpdateInstanceParams{ + CreateAttempt: inst.CreateAttempt + 1, + TokenFetched: &tokenFetched, + Status: providerCommon.InstancePendingCreate, + } + log.Printf("queueing previously failed instance %s for retry", inst.Name) + // Set instance to pending create and wait for retry. + if err := r.updateInstance(inst.Name, updateParams); err != nil { + log.Printf("failed to update runner %s status", inst.Name) + } + }(instance) } + wg.Wait() } func (r *basePoolManager) retryFailedInstances() { @@ -816,9 +879,6 @@ func (r *basePoolManager) deleteInstanceFromProvider(instance params.Instance) e return errors.Wrap(err, "removing instance") } - if err := r.store.DeleteInstance(r.ctx, pool.ID, instance.Name); err != nil { - return errors.Wrap(err, "deleting instance from database") - } return nil } @@ -855,6 +915,10 @@ func (r *basePoolManager) deletePendingInstances() { if err != nil { log.Printf("failed to delete instance from provider: %+v", err) } + + if err := r.store.DeleteInstance(r.ctx, instance.PoolID, instance.Name); err != nil { + return errors.Wrap(err, "deleting instance from database") + } return }(instance) } @@ -969,6 +1033,10 @@ func (r *basePoolManager) WebhookSecret() string { return r.helper.WebhookSecret() } +func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { + return r.helper.GetGithubRegistrationToken() +} + func (r *basePoolManager) ID() string { return r.helper.ID() } diff --git a/runner/pool/repository.go b/runner/pool/repository.go index 598ac939..9c69915a 100644 --- a/runner/pool/repository.go +++ b/runner/pool/repository.go @@ -75,18 +75,25 @@ type repository struct { mux sync.Mutex } -func (r *repository) GetRunnerNameFromWorkflow(job params.WorkflowJob) (string, error) { +func (r *repository) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (params.RunnerInfo, error) { + if err := r.ValidateOwner(job); err != nil { + return params.RunnerInfo{}, errors.Wrap(err, "validating owner") + } workflow, ghResp, err := r.ghcli.GetWorkflowJobByID(r.ctx, job.Repository.Owner.Login, job.Repository.Name, job.WorkflowJob.ID) if err != nil { if ghResp.StatusCode == http.StatusUnauthorized { - return "", errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runner name") + return params.RunnerInfo{}, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching workflow info") } - return "", errors.Wrap(err, "fetching workflow info") + return params.RunnerInfo{}, errors.Wrap(err, "fetching workflow info") } + if workflow.RunnerName != nil { - return *workflow.RunnerName, nil + return params.RunnerInfo{ + Name: *workflow.RunnerName, + Labels: workflow.Labels, + }, nil } - return "", fmt.Errorf("failed to find runner name from workflow") + return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } func (r *repository) UpdateState(param params.UpdatePoolStateParams) error { @@ -193,6 +200,10 @@ func (r *repository) GetCallbackURL() string { return r.cfgInternal.InstanceCallbackURL } +func (r *repository) GetMetadataURL() string { + return r.cfgInternal.InstanceMetadataURL +} + func (r *repository) FindPoolByTags(labels []string) (params.Pool, error) { pool, err := r.store.FindRepositoryPoolByTags(r.ctx, r.id, labels) if err != nil { diff --git a/runner/providers/external/util.go b/runner/providers/external/util.go index 4a39b2e3..33375122 100644 --- a/runner/providers/external/util.go +++ b/runner/providers/external/util.go @@ -29,7 +29,6 @@ func bootstrapParamsToEnv(param params.BootstrapInstance) []string { fmt.Sprintf("%s_BOOTSTRAP_CALLBACK_URL='%s'", envPrefix, param.CallbackURL), fmt.Sprintf("%s_BOOTSTRAP_REPO_URL='%s'", envPrefix, param.RepoURL), fmt.Sprintf("%s_BOOTSTRAP_LABELS='%s'", envPrefix, strings.Join(param.Labels, ",")), - fmt.Sprintf("%s_BOOTSTRAP_GITHUB_ACCESS_TOKEN='%s'", envPrefix, param.GithubRunnerAccessToken), } for idx, tool := range param.Tools { diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 2466d006..e372f6dd 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -27,7 +27,6 @@ import ( "garm/runner/common" runnerCommonMocks "garm/runner/common/mocks" runnerMocks "garm/runner/mocks" - "sort" "testing" "github.com/stretchr/testify/mock" @@ -58,47 +57,6 @@ type RepoTestSuite struct { Runner *Runner } -func (s *RepoTestSuite) reposMapValues(repos map[string]params.Repository) []params.Repository { - reposSlice := []params.Repository{} - for _, value := range repos { - reposSlice = append(reposSlice, value) - } - return reposSlice -} - -func (s *RepoTestSuite) equalReposByName(expected, actual []params.Repository) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].Name > expected[j].Name }) - sort.Slice(actual, func(i, j int) bool { return actual[i].Name > actual[j].Name }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].Name, actual[i].Name) - } -} - -func (s *RepoTestSuite) equalPoolsByID(expected, actual []params.Pool) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - -func (s *RepoTestSuite) equalInstancesByID(expected, actual []params.Instance) { - s.Require().Equal(len(expected), len(actual)) - - sort.Slice(expected, func(i, j int) bool { return expected[i].ID > expected[j].ID }) - sort.Slice(actual, func(i, j int) bool { return actual[i].ID > actual[j].ID }) - - for i := 0; i < len(expected); i++ { - s.Require().Equal(expected[i].ID, actual[i].ID) - } -} - func (s *RepoTestSuite) SetupTest() { adminCtx := auth.GetAdminContext() @@ -270,7 +228,7 @@ func (s *RepoTestSuite) TestListRepositories() { repos, err := s.Runner.ListRepositories(s.Fixtures.AdminContext) s.Require().Nil(err) - s.equalReposByName(s.reposMapValues(s.Fixtures.StoreRepos), repos) + garmTesting.EqualDBEntityByName(s.T(), garmTesting.DBEntityMapToSlice(s.Fixtures.StoreRepos), repos) } func (s *RepoTestSuite) TestListRepositoriesErrUnauthorized() { @@ -495,7 +453,7 @@ func (s *RepoTestSuite) TestListRepoPools() { pools, err := s.Runner.ListRepoPools(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID) s.Require().Nil(err) - s.equalPoolsByID(repoPools, pools) + garmTesting.EqualDBEntityID(s.T(), repoPools, pools) } func (s *RepoTestSuite) TestListRepoPoolsErrUnauthorized() { @@ -522,7 +480,7 @@ func (s *RepoTestSuite) TestListPoolInstances() { instances, err := s.Runner.ListPoolInstances(s.Fixtures.AdminContext, pool.ID) s.Require().Nil(err) - s.equalInstancesByID(poolInstances, instances) + garmTesting.EqualDBEntityID(s.T(), poolInstances, instances) } func (s *RepoTestSuite) TestListPoolInstancesErrUnauthorized() { @@ -583,7 +541,7 @@ func (s *RepoTestSuite) TestListRepoInstances() { instances, err := s.Runner.ListRepoInstances(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID) s.Require().Nil(err) - s.equalInstancesByID(poolInstances, instances) + garmTesting.EqualDBEntityID(s.T(), poolInstances, instances) } func (s *RepoTestSuite) TestListRepoInstancesErrUnauthorized() { @@ -614,5 +572,6 @@ func (s *RepoTestSuite) TestFindRepoPoolManagerFetchPoolMgrFailed() { } func TestRepoTestSuite(t *testing.T) { + t.Parallel() suite.Run(t, new(RepoTestSuite)) } diff --git a/runner/runner.go b/runner/runner.go index 3c7b00ba..b7db8060 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -239,6 +239,7 @@ func (p *poolManagerCtrl) getInternalConfig(credsName string) (params.Internal, OAuth2Token: creds.OAuth2Token, ControllerID: p.controllerID, InstanceCallbackURL: p.config.Default.CallbackURL, + InstanceMetadataURL: p.config.Default.MetadataURL, JWTSecret: p.config.JWTAuth.Secret, GithubCredentialsDetails: params.GithubCredentials{ Name: creds.Name, @@ -583,8 +584,10 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ switch HookTargetType(hookTargetType) { case RepoHook: + log.Printf("got hook for repo %s/%s", job.Repository.Owner.Login, job.Repository.Name) poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name) case OrganizationHook: + log.Printf("got hook for org %s", job.Organization.Login) poolManager, err = r.findOrgPoolManager(job.Organization.Login) case EnterpriseHook: poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug) @@ -703,7 +706,7 @@ func (r *Runner) AddInstanceStatusMessage(ctx context.Context, param params.Inst return runnerErrors.ErrUnauthorized } - if err := r.store.AddInstanceStatusMessage(ctx, instanceID, param.Message); err != nil { + if err := r.store.AddInstanceEvent(ctx, instanceID, params.StatusEvent, params.EventInfo, param.Message); err != nil { return errors.Wrap(err, "adding status update") } @@ -722,6 +725,95 @@ func (r *Runner) AddInstanceStatusMessage(ctx context.Context, param params.Inst return nil } +func (r *Runner) GetInstanceGithubRegistrationToken(ctx context.Context) (string, error) { + instanceName := auth.InstanceName(ctx) + if instanceName == "" { + return "", runnerErrors.ErrUnauthorized + } + + // Check if this instance already fetched a registration token. We only allow an instance to + // fetch one token. If the instance fails to bootstrap after a token is fetched, we reset the + // token fetched field when re-queueing the instance. + if auth.InstanceTokenFetched(ctx) { + return "", runnerErrors.ErrUnauthorized + } + + status := auth.InstanceRunnerStatus(ctx) + if status != providerCommon.RunnerPending && status != providerCommon.RunnerInstalling { + return "", runnerErrors.ErrUnauthorized + } + + instance, err := r.store.GetInstanceByName(ctx, instanceName) + if err != nil { + return "", errors.Wrap(err, "fetching instance") + } + + poolMgr, err := r.getPoolManagerFromInstance(ctx, instance) + if err != nil { + return "", errors.Wrap(err, "fetching pool manager for instance") + } + + token, err := poolMgr.GithubRunnerRegistrationToken() + if err != nil { + return "", errors.Wrap(err, "fetching runner token") + } + + tokenFetched := true + updateParams := params.UpdateInstanceParams{ + TokenFetched: &tokenFetched, + } + + if _, err := r.store.UpdateInstance(r.ctx, instance.ID, updateParams); err != nil { + return "", errors.Wrap(err, "setting token_fetched for instance") + } + + if err := r.store.AddInstanceEvent(ctx, instance.ID, params.FetchTokenEvent, params.EventInfo, "runner registration token was retrieved"); err != nil { + return "", errors.Wrap(err, "recording event") + } + + return token, nil +} + +func (r *Runner) getPoolManagerFromInstance(ctx context.Context, instance params.Instance) (common.PoolManager, error) { + pool, err := r.store.GetPoolByID(ctx, instance.PoolID) + if err != nil { + return nil, errors.Wrap(err, "fetching pool") + } + + var poolMgr common.PoolManager + + if pool.RepoID != "" { + repo, err := r.store.GetRepositoryByID(ctx, pool.RepoID) + if err != nil { + return nil, errors.Wrap(err, "fetching repo") + } + poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name) + if err != nil { + return nil, errors.Wrapf(err, "fetching pool manager for repo %s", pool.RepoName) + } + } else if pool.OrgID != "" { + org, err := r.store.GetOrganizationByID(ctx, pool.OrgID) + if err != nil { + return nil, errors.Wrap(err, "fetching org") + } + poolMgr, err = r.findOrgPoolManager(org.Name) + if err != nil { + return nil, errors.Wrapf(err, "fetching pool manager for org %s", pool.OrgName) + } + } else if pool.EnterpriseID != "" { + enterprise, err := r.store.GetEnterpriseByID(ctx, pool.EnterpriseID) + if err != nil { + return nil, errors.Wrap(err, "fetching enterprise") + } + poolMgr, err = r.findEnterprisePoolManager(enterprise.Name) + if err != nil { + return nil, errors.Wrapf(err, "fetching pool manager for enterprise %s", pool.EnterpriseName) + } + } + + return poolMgr, nil +} + func (r *Runner) ForceDeleteRunner(ctx context.Context, instanceName string) error { if !auth.IsAdmin(ctx) { return runnerErrors.ErrUnauthorized @@ -732,46 +824,9 @@ func (r *Runner) ForceDeleteRunner(ctx context.Context, instanceName string) err return errors.Wrap(err, "fetching instance") } - switch instance.Status { - case providerCommon.InstanceRunning, providerCommon.InstanceError: - default: - return runnerErrors.NewBadRequestError("runner must be in %q or %q state", providerCommon.InstanceRunning, providerCommon.InstanceError) - } - - pool, err := r.store.GetPoolByID(ctx, instance.PoolID) + poolMgr, err := r.getPoolManagerFromInstance(ctx, instance) if err != nil { - return errors.Wrap(err, "fetching pool") - } - - var poolMgr common.PoolManager - - if pool.RepoID != "" { - repo, err := r.store.GetRepositoryByID(ctx, pool.RepoID) - if err != nil { - return errors.Wrap(err, "fetching repo") - } - poolMgr, err = r.findRepoPoolManager(repo.Owner, repo.Name) - if err != nil { - return errors.Wrapf(err, "fetching pool manager for repo %s", pool.RepoName) - } - } else if pool.OrgID != "" { - org, err := r.store.GetOrganizationByID(ctx, pool.OrgID) - if err != nil { - return errors.Wrap(err, "fetching org") - } - poolMgr, err = r.findOrgPoolManager(org.Name) - if err != nil { - return errors.Wrapf(err, "fetching pool manager for org %s", pool.OrgName) - } - } else if pool.EnterpriseID != "" { - enterprise, err := r.store.GetEnterpriseByID(ctx, pool.EnterpriseID) - if err != nil { - return errors.Wrap(err, "fetching enterprise") - } - poolMgr, err = r.findEnterprisePoolManager(enterprise.Name) - if err != nil { - return errors.Wrapf(err, "fetching pool manager for enterprise %s", pool.EnterpriseName) - } + return errors.Wrap(err, "fetching pool manager for instance") } if err := poolMgr.ForceDeleteRunner(instance); err != nil { diff --git a/testdata/config.toml b/testdata/config.toml index 7f1c0d42..7aff4207 100644 --- a/testdata/config.toml +++ b/testdata/config.toml @@ -4,9 +4,17 @@ # the github actions runner. Status messages can be seen by querying the # runner status in garm. callback_url = "https://garm.example.com/api/v1/callbacks/status" + +# This URL is used by instances to retrieve information they need to set themselves +# up. Access to this URL is granted using the same JWT token used to send back +# status updates. Once the instance transitions to "installed" or "failed" state, +# access to both the status and metadata endpoints is disabled. +metadata_url = "https://garm.example.com/api/v1/metadata" + # This folder is defined here for future use. Right now, we create a SSH # public/private key-pair. config_dir = "/etc/garm" + # Uncomment this line if you'd like to log to a file instead of standard output. # log_file = "/tmp/runner-manager.log" diff --git a/util/util.go b/util/util.go index 8658f6a5..4834be32 100644 --- a/util/util.go +++ b/util/util.go @@ -24,7 +24,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" "net/http" "os" "path" @@ -39,6 +38,7 @@ import ( "garm/runner/common" "github.com/google/go-github/v48/github" + gorillaHandlers "github.com/gorilla/handlers" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" @@ -147,7 +147,7 @@ func GetLoggingWriter(cfg *config.Config) (io.Writer, error) { } func ConvertFileToBase64(file string) (string, error) { - bytes, err := ioutil.ReadFile(file) + bytes, err := os.ReadFile(file) if err != nil { return "", errors.Wrap(err, "reading file") } @@ -213,7 +213,7 @@ func GetCloudConfig(bootstrapParams params.BootstrapInstance, tools github.Runne FileName: *tools.Filename, DownloadURL: *tools.DownloadURL, TempDownloadToken: tempToken, - GithubToken: bootstrapParams.GithubRunnerAccessToken, + MetadataURL: bootstrapParams.MetadataURL, RunnerUsername: config.DefaultUser, RunnerGroup: config.DefaultUser, RepoURL: bootstrapParams.RepoURL, @@ -321,3 +321,9 @@ func PaswsordToBcrypt(password string) (string, error) { } return string(hashedPassword), nil } + +func NewLoggingMiddleware(writer io.Writer) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return gorillaHandlers.CombinedLoggingHandler(writer, next) + } +}