From 0883fcd5cd8caebf7d5cbafa68ab9d234186f463 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 28 Apr 2022 16:13:20 +0000 Subject: [PATCH] Add some basic auth --- apiserver/controllers/controllers.go | 108 ++++++- apiserver/params/params.go | 18 ++ apiserver/routers/routers.go | 79 ++++- auth/auth.go | 144 ++++++++++ auth/context.go | 111 ++++++++ auth/init_required.go | 34 +++ auth/interfaces.go | 8 + auth/jwt.go | 115 +++++++- cmd/runner-manager/dbcreate.go | 285 +++++++++++++------ cmd/runner-manager/main.go | 226 ++++++--------- config/config.go | 283 ++++-------------- database/common/common.go | 12 +- database/sql/models.go | 40 ++- database/sql/sql.go | 411 ++++++++++++++++++++++----- go.mod | 4 +- go.sum | 4 +- params/params.go | 97 +++---- params/requests.go | 94 +++++- runner/common/provider.go | 2 + runner/pool/repository.go | 11 +- runner/providers/lxd/lxd.go | 7 + runner/runner.go | 188 ++++++++++-- testdata/config.toml | 30 +- util/util.go | 50 +++- 24 files changed, 1687 insertions(+), 674 deletions(-) create mode 100644 auth/auth.go create mode 100644 auth/context.go create mode 100644 auth/init_required.go create mode 100644 auth/interfaces.go diff --git a/apiserver/controllers/controllers.go b/apiserver/controllers/controllers.go index b72596e4..5056ef55 100644 --- a/apiserver/controllers/controllers.go +++ b/apiserver/controllers/controllers.go @@ -8,6 +8,7 @@ import ( "net/http" "runner-manager/apiserver/params" + "runner-manager/auth" gErrors "runner-manager/errors" runnerParams "runner-manager/params" "runner-manager/runner" @@ -15,14 +16,16 @@ import ( "github.com/pkg/errors" ) -func NewAPIController(r *runner.Runner) (*APIController, error) { +func NewAPIController(r *runner.Runner, auth *auth.Authenticator) (*APIController, error) { return &APIController{ - r: r, + r: r, + auth: auth, }, nil } type APIController struct { - r *runner.Runner + r *runner.Runner + auth *auth.Authenticator } func handleError(w http.ResponseWriter, err error) { @@ -113,3 +116,102 @@ func (a *APIController) NotFoundHandler(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusNotFound) json.NewEncoder(w).Encode(apiErr) } + +// LoginHandler returns a jwt token +func (a *APIController) LoginHandler(w http.ResponseWriter, r *http.Request) { + var loginInfo runnerParams.PasswordLoginParams + if err := json.NewDecoder(r.Body).Decode(&loginInfo); err != nil { + handleError(w, gErrors.ErrBadRequest) + return + } + + if err := loginInfo.Validate(); err != nil { + handleError(w, err) + return + } + + ctx := r.Context() + ctx, err := a.auth.AuthenticateUser(ctx, loginInfo) + if err != nil { + handleError(w, err) + return + } + + tokenString, err := a.auth.GetJWTToken(ctx) + if err != nil { + handleError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(runnerParams.JWTResponse{Token: tokenString}) +} + +func (a *APIController) FirstRunHandler(w http.ResponseWriter, r *http.Request) { + if a.auth.IsInitialized() { + err := gErrors.NewConflictError("already initialized") + handleError(w, err) + return + } + + ctx := r.Context() + + var newUserParams runnerParams.NewUserParams + if err := json.NewDecoder(r.Body).Decode(&newUserParams); err != nil { + handleError(w, gErrors.ErrBadRequest) + return + } + + newUser, err := a.auth.InitController(ctx, newUserParams) + if err != nil { + handleError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(newUser) + +} + +func (a *APIController) ListCredentials(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + creds, err := a.r.ListCredentials(ctx) + if err != nil { + handleError(w, err) + return + } + + json.NewEncoder(w).Encode(creds) +} + +func (a *APIController) ListProviders(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + providers, err := a.r.ListProviders(ctx) + if err != nil { + handleError(w, err) + return + } + + json.NewEncoder(w).Encode(providers) +} + +func (a *APIController) CreateRepoHandler(w http.ResponseWriter, r *http.Request) { + // ctx := r.Context() + + // var repoData runnerParams.CreateRepoParams + // if err := json.NewDecoder(r.Body).Decode(&repoData); err != nil { + // handleError(w, gErrors.ErrBadRequest) + // return + // } + + // pasteInfo, err := p.paster.Create( + // ctx, pasteData.Data, pasteData.Name, + // pasteData.Language, pasteData.Description, + // pasteData.Expires, pasteData.Public, "", + // pasteData.Metadata) + // if err != nil { + // handleError(w, err) + // return + // } + // w.Header().Set("Content-Type", "application/json") + // json.NewEncoder(w).Encode(pasteInfo) + +} diff --git a/apiserver/params/params.go b/apiserver/params/params.go index e88e26ea..f367b824 100644 --- a/apiserver/params/params.go +++ b/apiserver/params/params.go @@ -5,3 +5,21 @@ type APIErrorResponse struct { Error string `json:"error"` Details string `json:"details"` } + +var ( + // NotFoundResponse is returned when a resource is not found + NotFoundResponse = APIErrorResponse{ + Error: "Not Found", + Details: "The resource you are looking for was not found", + } + // UnauthorizedResponse is a canned response for unauthorized access + UnauthorizedResponse = APIErrorResponse{ + Error: "Not Authorized", + Details: "You do not have the required permissions to access this resource", + } + // InitializationRequired is returned if gopherbin has not beed properly initialized + InitializationRequired = APIErrorResponse{ + Error: "init_required", + Details: "Missing superuser", + } +) diff --git a/apiserver/routers/routers.go b/apiserver/routers/routers.go index 1b02fe92..bc89503b 100644 --- a/apiserver/routers/routers.go +++ b/apiserver/routers/routers.go @@ -3,19 +3,92 @@ package routers import ( "io" "net/http" + "os" gorillaHandlers "github.com/gorilla/handlers" "github.com/gorilla/mux" "runner-manager/apiserver/controllers" + "runner-manager/auth" ) -func NewAPIRouter(han *controllers.APIController, logWriter io.Writer) *mux.Router { +func NewAPIRouter(han *controllers.APIController, logWriter io.Writer, authMiddleware, initMiddleware auth.Middleware) *mux.Router { router := mux.NewRouter() log := gorillaHandlers.CombinedLoggingHandler - apiRouter := router.PathPrefix("").Subrouter() - apiRouter.PathPrefix("/").Handler(log(logWriter, http.HandlerFunc(han.CatchAll))) + // Handles github webhooks + webhookRouter := router.PathPrefix("/webhooks").Subrouter() + webhookRouter.PathPrefix("/").Handler(log(logWriter, http.HandlerFunc(han.CatchAll))) + + // Handles API calls + apiSubRouter := router.PathPrefix("/api/v1").Subrouter() + + // FirstRunHandler + firstRunRouter := apiSubRouter.PathPrefix("/first-run").Subrouter() + firstRunRouter.Handle("/", log(os.Stdout, http.HandlerFunc(han.FirstRunHandler))).Methods("POST", "OPTIONS") + + // Login + authRouter := apiSubRouter.PathPrefix("/auth").Subrouter() + authRouter.Handle("/{login:login\\/?}", log(os.Stdout, http.HandlerFunc(han.LoginHandler))).Methods("POST", "OPTIONS") + authRouter.Use(initMiddleware.Middleware) + + apiRouter := apiSubRouter.PathPrefix("").Subrouter() + apiRouter.Use(initMiddleware.Middleware) + apiRouter.Use(authMiddleware.Middleware) + + ///////////////////// + // Repos and pools // + ///////////////////// + // Get pool + apiRouter.Handle("/repositories/{repoID}/pools/{poolID:poolID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Delete pool + apiRouter.Handle("/repositories/{repoID}/pools/{poolID:poolID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("DELETE", "OPTIONS") + // List pools + apiRouter.Handle("/repositories/{repoID}/pools/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Create pool + apiRouter.Handle("/repositories/{repoID}/pools/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories/{repoID}/pools", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + + // Get repo + apiRouter.Handle("/repositories/{repoID:repoID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Delete repo + apiRouter.Handle("/repositories/{repoID:repoID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("DELETE", "OPTIONS") + // List repos + apiRouter.Handle("/repositories/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + apiRouter.Handle("/repositories", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Create repo + apiRouter.Handle("/repositories/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + apiRouter.Handle("/repositories", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + + ///////////////////////////// + // Organizations and pools // + ///////////////////////////// + // Get pool + apiRouter.Handle("/organizations/{repoID}/pools/{poolID:poolID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Delete pool + apiRouter.Handle("/organizations/{repoID}/pools/{poolID:poolID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("DELETE", "OPTIONS") + // List pools + apiRouter.Handle("/organizations/{repoID}/pools/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations/{repoID}/pools", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Create pool + apiRouter.Handle("/organizations/{repoID}/pools/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations/{repoID}/pools", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + + // Get repo + apiRouter.Handle("/organizations/{repoID:repoID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Delete repo + apiRouter.Handle("/organizations/{repoID:repoID\\/?}", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("DELETE", "OPTIONS") + // List repos + apiRouter.Handle("/organizations/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + apiRouter.Handle("/organizations", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("GET", "OPTIONS") + // Create repo + apiRouter.Handle("/organizations/", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + apiRouter.Handle("/organizations", log(os.Stdout, http.HandlerFunc(han.CatchAll))).Methods("POST", "OPTIONS") + + // Credentials and providers + apiRouter.Handle("/credentials", log(os.Stdout, http.HandlerFunc(han.ListCredentials))).Methods("GET", "OPTIONS") + apiRouter.Handle("/providers", log(os.Stdout, http.HandlerFunc(han.ListProviders))).Methods("GET", "OPTIONS") return router } diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..ae4b6cc0 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,144 @@ +package auth + +import ( + "context" + "runner-manager/config" + "runner-manager/database/common" + runnerErrors "runner-manager/errors" + "runner-manager/params" + "runner-manager/util" + "time" + + "github.com/golang-jwt/jwt" + "github.com/nbutton23/zxcvbn-go" + "github.com/pkg/errors" + "golang.org/x/crypto/bcrypt" +) + +func NewAuthenticator(cfg config.JWTAuth, store common.Store) *Authenticator { + return &Authenticator{ + cfg: cfg, + store: store, + } +} + +type Authenticator struct { + store common.Store + cfg config.JWTAuth +} + +func (a *Authenticator) IsInitialized() bool { + info, err := a.store.ControllerInfo() + if err != nil { + return false + } + + if info.ControllerID.String() == "" { + return false + } + + return true +} + +func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) { + tokenID, err := util.GetRandomString(16) + if err != nil { + return "", errors.Wrap(err, "generating random string") + } + expireToken := time.Now().Add(a.cfg.TimeToLive.Duration()).Unix() + claims := JWTClaims{ + StandardClaims: jwt.StandardClaims{ + ExpiresAt: expireToken, + // TODO: make this configurable + Issuer: "runner-manager", + }, + UserID: UserID(ctx), + TokenID: tokenID, + IsAdmin: IsAdmin(ctx), + FullName: FullName(ctx), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(a.cfg.Secret)) + if err != nil { + return "", errors.Wrap(err, "fetching token string") + } + + return tokenString, nil +} + +func (a *Authenticator) InitController(ctx context.Context, param params.NewUserParams) (params.User, error) { + _, err := a.store.ControllerInfo() + if err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return params.User{}, errors.Wrap(err, "initializing controller") + } + } + if a.store.HasAdminUser(ctx) { + return params.User{}, runnerErrors.ErrNotFound + } + + if param.Email == "" || param.Username == "" { + return params.User{}, runnerErrors.NewBadRequestError("missing username or email") + } + + if !util.IsValidEmail(param.Email) { + return params.User{}, runnerErrors.NewBadRequestError("invalid email address") + } + + // username is varchar(64) + if len(param.Username) > 64 || !util.IsAlphanumeric(param.Username) { + return params.User{}, runnerErrors.NewBadRequestError("invalid username") + } + + param.IsAdmin = true + param.Enabled = true + + passwordStenght := zxcvbn.PasswordStrength(param.Password, nil) + if passwordStenght.Score < 4 { + return params.User{}, runnerErrors.NewBadRequestError("password is too weak") + } + + hashed, err := util.PaswsordToBcrypt(param.Password) + if err != nil { + return params.User{}, errors.Wrap(err, "creating user") + } + + param.Password = hashed + + if _, err := a.store.InitController(); err != nil { + return params.User{}, errors.Wrap(err, "initializing controller") + } + return a.store.CreateUser(ctx, param) +} + +func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.PasswordLoginParams) (context.Context, error) { + if info.Username == "" { + return ctx, runnerErrors.ErrUnauthorized + } + + if info.Password == "" { + return ctx, runnerErrors.ErrUnauthorized + } + + user, err := a.store.GetUser(ctx, info.Username) + + if err != nil { + if err == runnerErrors.ErrNotFound { + return ctx, runnerErrors.ErrUnauthorized + } + return ctx, errors.Wrap(err, "authenticating") + } + if !user.Enabled { + return ctx, runnerErrors.ErrUnauthorized + } + + if user.Password == "" { + return ctx, runnerErrors.ErrUnauthorized + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(info.Password)); err != nil { + return ctx, runnerErrors.ErrUnauthorized + } + + return PopulateContext(ctx, user), nil +} diff --git a/auth/context.go b/auth/context.go new file mode 100644 index 00000000..9e99f425 --- /dev/null +++ b/auth/context.go @@ -0,0 +1,111 @@ +package auth + +import ( + "context" + + "runner-manager/params" +) + +type contextFlags string + +const ( + isAdminKey contextFlags = "is_admin" + fullNameKey contextFlags = "full_name" + // UserIDFlag is the User ID flag we set in the context + UserIDFlag contextFlags = "user_id" + isEnabledFlag contextFlags = "is_enabled" + jwtTokenFlag contextFlags = "jwt_token" +) + +// PopulateContext sets the appropriate fields in the context, based on +// the user object +func PopulateContext(ctx context.Context, user params.User) context.Context { + ctx = SetUserID(ctx, user.ID) + ctx = SetAdmin(ctx, user.IsAdmin) + ctx = SetIsEnabled(ctx, user.Enabled) + ctx = SetFullName(ctx, user.FullName) + return ctx +} + +// SetFullName sets the user full name in the context +func SetFullName(ctx context.Context, fullName string) context.Context { + return context.WithValue(ctx, fullNameKey, fullName) +} + +// FullName returns the full name from context +func FullName(ctx context.Context) string { + name := ctx.Value(fullNameKey) + if name == nil { + return "" + } + return name.(string) +} + +// SetJWTClaim will set the JWT claim in the context +func SetJWTClaim(ctx context.Context, claim JWTClaims) context.Context { + return context.WithValue(ctx, jwtTokenFlag, claim) +} + +// JWTClaim returns the JWT claim saved in the context +func JWTClaim(ctx context.Context) JWTClaims { + jwtClaim := ctx.Value(jwtTokenFlag) + if jwtClaim == nil { + return JWTClaims{} + } + return jwtClaim.(JWTClaims) +} + +// SetIsEnabled sets a flag indicating if account is enabled +func SetIsEnabled(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, isEnabledFlag, enabled) +} + +// IsEnabled returns the a boolean indicating if the enabled flag is +// set and is true or false +func IsEnabled(ctx context.Context) bool { + elem := ctx.Value(isEnabledFlag) + if elem == nil { + return false + } + return elem.(bool) +} + +// SetAdmin sets the isAdmin flag on the context +func SetAdmin(ctx context.Context, isAdmin bool) context.Context { + return context.WithValue(ctx, isAdminKey, isAdmin) +} + +// IsAdmin returns a boolean indicating whether +// or not the context belongs to a logged in user +// and if that context has the admin flag set +func IsAdmin(ctx context.Context) bool { + elem := ctx.Value(isAdminKey) + if elem == nil { + return false + } + return elem.(bool) +} + +// SetUserID sets the userID in the context +func SetUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, UserIDFlag, userID) +} + +// UserID returns the userID from the context +func UserID(ctx context.Context) string { + userID := ctx.Value(UserIDFlag) + if userID == nil { + return "" + } + return userID.(string) +} + +// GetAdminContext will return an admin context. This can be used internally +// when fetching users. +func GetAdminContext() context.Context { + ctx := context.Background() + ctx = SetUserID(ctx, "") + ctx = SetAdmin(ctx, true) + ctx = SetIsEnabled(ctx, true) + return ctx +} diff --git a/auth/init_required.go b/auth/init_required.go new file mode 100644 index 00000000..a88f18d7 --- /dev/null +++ b/auth/init_required.go @@ -0,0 +1,34 @@ +package auth + +import ( + "encoding/json" + "net/http" + "runner-manager/apiserver/params" + "runner-manager/database/common" +) + +// NewjwtMiddleware returns a populated jwtMiddleware +func NewInitRequiredMiddleware(store common.Store) (Middleware, error) { + return &initRequired{ + store: store, + }, nil +} + +type initRequired struct { + store common.Store +} + +// Middleware implements the middleware interface +func (i *initRequired) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctrlInfo, err := i.store.ControllerInfo() + if err != nil || ctrlInfo.ControllerID.String() == "" { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusConflict) + json.NewEncoder(w).Encode(params.InitializationRequired) + return + } + ctx := r.Context() + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/auth/interfaces.go b/auth/interfaces.go new file mode 100644 index 00000000..5ae25cd8 --- /dev/null +++ b/auth/interfaces.go @@ -0,0 +1,8 @@ +package auth + +import "net/http" + +// Middleware defines an authentication middleware +type Middleware interface { + Middleware(next http.Handler) http.Handler +} diff --git a/auth/jwt.go b/auth/jwt.go index 932eb7ae..699072e7 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -1,9 +1,19 @@ package auth import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + apiParams "runner-manager/apiserver/params" + "runner-manager/config" + dbCommon "runner-manager/database/common" + runnerErrors "runner-manager/errors" "runner-manager/params" "runner-manager/runner/common" - "time" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -21,6 +31,15 @@ type InstanceJWTClaims struct { jwt.StandardClaims } +// JWTClaims holds JWT claims +type JWTClaims struct { + UserID string `json:"user"` + TokenID string `json:"token_id"` + FullName string `json:"full_name"` + IsAdmin bool `json:"is_admin"` + jwt.StandardClaims +} + func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolType common.PoolType) (string, error) { // make TTL configurable? expireToken := time.Now().Add(3 * time.Hour).Unix() @@ -43,3 +62,97 @@ func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolTy return tokenString, nil } + +// jwtMiddleware is the authentication middleware +// used with gorilla +type jwtMiddleware struct { + store dbCommon.Store + auth *Authenticator + cfg config.JWTAuth +} + +// NewjwtMiddleware returns a populated jwtMiddleware +func NewjwtMiddleware(store dbCommon.Store, cfg config.JWTAuth) (Middleware, error) { + return &jwtMiddleware{ + store: store, + cfg: cfg, + }, nil +} + +func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims) (context.Context, error) { + if claims == nil { + return ctx, runnerErrors.ErrUnauthorized + } + + if claims.UserID == "" { + return nil, runnerErrors.ErrUnauthorized + } + + userInfo, err := amw.store.GetUser(ctx, claims.UserID) + if err != nil { + return ctx, runnerErrors.ErrUnauthorized + } + + ctx = PopulateContext(ctx, userInfo) + return ctx, nil +} + +func invalidAuthResponse(w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode( + apiParams.APIErrorResponse{ + Error: "Authentication failed", + Details: "Invalid authentication token", + }) +} + +// Middleware implements the middleware interface +func (amw *jwtMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO: Log error details when authentication fails + ctx := r.Context() + authorizationHeader := r.Header.Get("authorization") + if authorizationHeader == "" { + invalidAuthResponse(w) + return + } + + bearerToken := strings.Split(authorizationHeader, " ") + if len(bearerToken) != 2 { + invalidAuthResponse(w) + return + } + + claims := &JWTClaims{} + token, err := jwt.ParseWithClaims(bearerToken[1], claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("invalid signing method") + } + return []byte(amw.cfg.Secret), nil + }) + + if err != nil { + invalidAuthResponse(w) + return + } + + if !token.Valid { + invalidAuthResponse(w) + return + } + + ctx, err = amw.claimsToContext(ctx, claims) + if err != nil { + invalidAuthResponse(w) + return + } + if !IsEnabled(ctx) { + invalidAuthResponse(w) + return + } + + ctx = SetJWTClaim(ctx, *claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/cmd/runner-manager/dbcreate.go b/cmd/runner-manager/dbcreate.go index 8a7b6b3a..2477b56b 100644 --- a/cmd/runner-manager/dbcreate.go +++ b/cmd/runner-manager/dbcreate.go @@ -1,107 +1,214 @@ package main -import ( - "context" - "flag" - "fmt" - "log" - "os/signal" - "runner-manager/config" - "runner-manager/database/sql" - "runner-manager/params" - "runner-manager/util" -) +// import ( +// "context" +// "flag" +// "fmt" +// "log" +// "os/signal" +// "runner-manager/config" +// "runner-manager/database/sql" +// "runner-manager/params" +// "runner-manager/util" +// ) -var ( - conf = flag.String("config", config.DefaultConfigFilePath, "runner-manager config file") - version = flag.Bool("version", false, "prints version") -) +// var ( +// conf = flag.String("config", config.DefaultConfigFilePath, "runner-manager config file") +// version = flag.Bool("version", false, "prints version") +// ) -var Version string +// var Version string -func main() { - flag.Parse() - if *version { - fmt.Println(Version) - return - } - ctx, stop := signal.NotifyContext(context.Background(), signals...) - defer stop() - fmt.Println(ctx) +// func main() { +// flag.Parse() +// if *version { +// fmt.Println(Version) +// return +// } +// ctx, stop := signal.NotifyContext(context.Background(), signals...) +// defer stop() +// fmt.Println(ctx) - cfg, err := config.NewConfig(*conf) - if err != nil { - log.Fatalf("Fetching config: %+v", err) - } +// cfg, err := config.NewConfig(*conf) +// if err != nil { +// log.Fatalf("Fetching config: %+v", err) +// } - db, err := sql.NewSQLDatabase(ctx, cfg.Database) - if err != nil { - log.Fatal(err) - } +// db, err := sql.NewSQLDatabase(ctx, cfg.Database) +// if err != nil { +// log.Fatal(err) +// } - fmt.Println(db) +// fmt.Println(db) - txt := "ana are mere prune și alune" +// txt := "ana are mere prune și alune" - enc, err := util.Aes256EncodeString(txt, "pamkotepAyksemfeghoibidEwCivbaut") - if err != nil { - log.Fatal(err) - } +// enc, err := util.Aes256EncodeString(txt, "pamkotepAyksemfeghoibidEwCivbaut") +// if err != nil { +// log.Fatal(err) +// } - fmt.Printf("encrypted: %d\n", len(enc)) +// fmt.Printf("encrypted: %d\n", len(enc)) - dec, err := util.Aes256DecodeString(enc, "pamkotepAyksemfeghoibidEwCivbaut") - if err != nil { - log.Fatal(err) - } +// dec, err := util.Aes256DecodeString(enc, "pamkotepAyksemfeghoibidEwCivbaut") +// if err != nil { +// log.Fatal(err) +// } - fmt.Println(dec) +// fmt.Println(dec) - repo, err := db.CreateRepository(ctx, "gabriel-samfira", "scripts", "") - if err != nil { - log.Fatal(err) - } +// repo, err := db.CreateRepository(ctx, "gabriel-samfira", "", "scripts", "") +// if err != nil { +// log.Fatal(err) +// } - pool, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ - ProviderName: "lxd_local", - MaxRunners: 10, - MinIdleRunners: 1, - Image: "ubuntu:20.04", - Flavor: "default", - Tags: []string{ - "myrunner", - "superAwesome", - }, - OSType: config.Linux, - OSArch: config.Amd64, - }) - if err != nil { - log.Fatal(err) - } - fmt.Println(pool) +// pool, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ +// ProviderName: "lxd_local", +// MaxRunners: 10, +// MinIdleRunners: 1, +// Image: "ubuntu:20.04", +// Flavor: "default", +// Tags: []string{ +// "myrunner", +// "superAwesome", +// }, +// OSType: config.Linux, +// OSArch: config.Amd64, +// }) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Println(pool) - pool2, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ - ProviderName: "lxd_local2", - MaxRunners: 10, - MinIdleRunners: 1, - Image: "ubuntu:20.04", - Flavor: "default", - Tags: []string{ - "myrunner", - "superAwesome2", - }, - OSType: config.Linux, - OSArch: config.Amd64, - }) - if err != nil { - log.Fatal(err) - } - fmt.Println(pool2) +// pool2, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ +// ProviderName: "lxd_local2",package main - pool3, err := db.FindRepositoryPoolByTags(ctx, repo.ID, []string{"myrunner", "superAwesome2"}) - if err != nil { - log.Fatal(err) - } +// import ( +// "context" +// "flag" +// "fmt" +// "log" +// "os/signal" +// "runner-manager/config" +// "runner-manager/database/sql" +// "runner-manager/params" +// "runner-manager/util" +// ) - fmt.Println(pool3) -} +// var ( +// conf = flag.String("config", config.DefaultConfigFilePath, "runner-manager config file") +// version = flag.Bool("version", false, "prints version") +// ) + +// var Version string + +// func main() { +// flag.Parse() +// if *version { +// fmt.Println(Version) +// return +// } +// ctx, stop := signal.NotifyContext(context.Background(), signals...) +// defer stop() +// fmt.Println(ctx) + +// cfg, err := config.NewConfig(*conf) +// if err != nil { +// log.Fatalf("Fetching config: %+v", err) +// } + +// db, err := sql.NewSQLDatabase(ctx, cfg.Database) +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Println(db) + +// txt := "ana are mere prune și alune" + +// enc, err := util.Aes256EncodeString(txt, "pamkotepAyksemfeghoibidEwCivbaut") +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Printf("encrypted: %d\n", len(enc)) + +// dec, err := util.Aes256DecodeString(enc, "pamkotepAyksemfeghoibidEwCivbaut") +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Println(dec) + +// repo, err := db.CreateRepository(ctx, "gabriel-samfira", "", "scripts", "") +// if err != nil { +// log.Fatal(err) +// } + +// pool, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ +// ProviderName: "lxd_local", +// MaxRunners: 10, +// MinIdleRunners: 1, +// Image: "ubuntu:20.04", +// Flavor: "default", +// Tags: []string{ +// "myrunner", +// "superAwesome", +// }, +// OSType: config.Linux, +// OSArch: config.Amd64, +// }) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Println(pool) + +// pool2, err := db.CreateRepositoryPool(ctx, repo.ID, params.CreatePoolParams{ +// ProviderName: "lxd_local2", +// MaxRunners: 10, +// MinIdleRunners: 1, +// Image: "ubuntu:20.04", +// Flavor: "default", +// Tags: []string{ +// "myrunner", +// "superAwesome2", +// }, +// OSType: config.Linux, +// OSArch: config.Amd64, +// }) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Println(pool2) + +// pool3, err := db.FindRepositoryPoolByTags(ctx, repo.ID, []string{"myrunner", "superAwesome2"}) +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Println(pool3) +// } + +// MaxRunners: 10, +// MinIdleRunners: 1, +// Image: "ubuntu:20.04", +// Flavor: "default", +// Tags: []string{ +// "myrunner", +// "superAwesome2", +// }, +// OSType: config.Linux, +// OSArch: config.Amd64, +// }) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Println(pool2) + +// pool3, err := db.FindRepositoryPoolByTags(ctx, repo.ID, []string{"myrunner", "superAwesome2"}) +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Println(pool3) +// } diff --git a/cmd/runner-manager/main.go b/cmd/runner-manager/main.go index 02e96954..e7e2e694 100644 --- a/cmd/runner-manager/main.go +++ b/cmd/runner-manager/main.go @@ -1,160 +1,106 @@ package main -// import ( -// "context" -// "flag" -// "fmt" -// "log" -// "net" -// "net/http" -// "os/signal" +import ( + "context" + "flag" + "fmt" + "log" + "net" + "net/http" + "os/signal" -// "runner-manager/apiserver/controllers" -// "runner-manager/apiserver/routers" -// "runner-manager/config" -// "runner-manager/util" -// // "github.com/google/go-github/v43/github" -// // "golang.org/x/oauth2" -// // "gopkg.in/yaml.v3" -// ) + "runner-manager/apiserver/controllers" + "runner-manager/apiserver/routers" + "runner-manager/auth" + "runner-manager/config" + "runner-manager/database" + "runner-manager/runner" + "runner-manager/util" + // "github.com/google/go-github/v43/github" + // "golang.org/x/oauth2" + // "gopkg.in/yaml.v3" +) -// var ( -// conf = flag.String("config", config.DefaultConfigFilePath, "runner-manager config file") -// version = flag.Bool("version", false, "prints version") -// ) +var ( + conf = flag.String("config", config.DefaultConfigFilePath, "runner-manager config file") + version = flag.Bool("version", false, "prints version") +) -// var Version string +var Version string -// // var token = "super secret token" +// var token = "super secret token" -// func main() { -// flag.Parse() -// if *version { -// fmt.Println(Version) -// return -// } -// ctx, stop := signal.NotifyContext(context.Background(), signals...) -// defer stop() -// fmt.Println(ctx) +func main() { + flag.Parse() + if *version { + fmt.Println(Version) + return + } + ctx, stop := signal.NotifyContext(context.Background(), signals...) + defer stop() + fmt.Println(ctx) -// cfg, err := config.NewConfig(*conf) -// if err != nil { -// log.Fatalf("Fetching config: %+v", err) -// } + cfg, err := config.NewConfig(*conf) + if err != nil { + log.Fatalf("Fetching config: %+v", err) + } -// // ts := oauth2.StaticTokenSource( -// // &oauth2.Token{AccessToken: cfg.Github.OAuth2Token}, -// // ) + logWriter, err := util.GetLoggingWriter(cfg) + if err != nil { + log.Fatalf("fetching log writer: %+v", err) + } + log.SetOutput(logWriter) -// // tc := oauth2.NewClient(ctx, ts) + runner, err := runner.NewRunner(ctx, *cfg) + if err != nil { + log.Fatalf("failed to create controller: %+v", err) + } -// // ghClient := github.NewClient(tc) + db, err := database.NewDatabase(ctx, cfg.Database) + if err != nil { + log.Fatal(err) + } -// // // list all repositories for the authenticated user -// // repos, _, err := client.Repositories.List(ctx, "", nil) + authenticator := auth.NewAuthenticator(cfg.JWTAuth, db) + controller, err := controllers.NewAPIController(runner, authenticator) + if err != nil { + log.Fatalf("failed to create controller: %+v", err) + } -// // fmt.Println(repos, err) + jwtMiddleware, err := auth.NewjwtMiddleware(db, cfg.JWTAuth) + if err != nil { + log.Fatal(err) + } -// logWriter, err := util.GetLoggingWriter(cfg) -// if err != nil { -// log.Fatalf("fetching log writer: %+v", err) -// } -// log.SetOutput(logWriter) + initMiddleware, err := auth.NewInitRequiredMiddleware(db) + if err != nil { + log.Fatal(err) + } -// controller, err := controllers.NewAPIController() -// if err != nil { -// log.Fatalf("failed to create controller: %+v", err) -// } + router := routers.NewAPIRouter(controller, logWriter, jwtMiddleware, initMiddleware) -// router := routers.NewAPIRouter(controller, logWriter) + tlsCfg, err := cfg.APIServer.APITLSConfig() + if err != nil { + log.Fatalf("failed to get TLS config: %q", err) + } -// tlsCfg, err := cfg.APIServer.APITLSConfig() -// if err != nil { -// log.Fatalf("failed to get TLS config: %q", err) -// } + srv := &http.Server{ + Addr: cfg.APIServer.BindAddress(), + TLSConfig: tlsCfg, + // Pass our instance of gorilla/mux in. + Handler: router, + } -// srv := &http.Server{ -// Addr: cfg.APIServer.BindAddress(), -// TLSConfig: tlsCfg, -// // Pass our instance of gorilla/mux in. -// Handler: router, -// } + listener, err := net.Listen("tcp", srv.Addr) + if err != nil { + log.Fatalf("creating listener: %q", err) + } -// listener, err := net.Listen("tcp", srv.Addr) -// if err != nil { -// log.Fatalf("creating listener: %q", err) -// } + go func() { + if err := srv.Serve(listener); err != nil { + log.Fatalf("Listening: %+v", err) + } + }() -// go func() { -// if err := srv.Serve(listener); err != nil { -// log.Fatalf("Listening: %+v", err) -// } -// }() - -// <-ctx.Done() - -// // runner, err := runner.NewRunner(ctx, *cfg) -// // if err != nil { -// // log.Fatal(err) -// // } - -// // fmt.Println(runner) -// // controllerID := "026d374d-6a8a-4241-8ed9-a246fff6762f" -// // provider, err := lxd.NewProvider(ctx, &cfg.Providers[0], controllerID) -// // if err != nil { -// // log.Fatal(err) -// // } - -// // if err := provider.RemoveAllInstances(ctx); err != nil { -// // log.Fatal(err) -// // } - -// // fmt.Println(provider) - -// // if err := provider.DeleteInstance(ctx, "runner-manager-2fbe5354-be28-4e00-95a8-11479912368d"); err != nil { -// // log.Fatal(err) -// // } - -// // instances, err := provider.ListInstances(ctx) - -// // asJs, err := json.MarshalIndent(instances, "", " ") -// // fmt.Println(string(asJs), err) - -// // log.Print("Fetching tools") -// // tools, _, err := ghClient.Actions.ListRunnerApplicationDownloads(ctx, cfg.Repositories[0].Owner, cfg.Repositories[0].Name) -// // if err != nil { -// // log.Fatal(err) -// // } - -// // toolsAsYaml, err := yaml.Marshal(tools) -// // if err != nil { -// // log.Fatal(err) -// // } -// // log.Printf("got tools:\n%s\n", string(toolsAsYaml)) - -// // log.Print("fetching runner token") -// // ghRunnerToken, _, err := ghClient.Actions.CreateRegistrationToken(ctx, cfg.Repositories[0].Owner, cfg.Repositories[0].Name) -// // if err != nil { -// // log.Fatal(err) -// // } -// // log.Printf("got token %v", ghRunnerToken) - -// // bootstrapArgs := params.BootstrapInstance{ -// // Tools: tools, -// // RepoURL: cfg.Repositories[0].String(), -// // GithubRunnerAccessToken: *ghRunnerToken.Token, -// // RunnerType: cfg.Repositories[0].Pool.Runners[0].Name, -// // CallbackURL: "", -// // InstanceToken: "", -// // SSHKeys: []string{ -// // "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQC2oT7j/+elHY9U2ibgk2RYJgCvqIwewYKJTtHslTQFDWlHLeDam93BBOFlQJm9/wKX/qjC8d26qyzjeeeVf2EEAztp+jQfEq9OU+EtgQUi589jxtVmaWuYED8KVNbzLuP79SrBtEZD4xqgmnNotPhRshh3L6eYj4XzLWDUuOD6kzNdsJA2QOKeMOIFpBN6urKJHRHYD+oUPUX1w5QMv1W1Srlffl4m5uE+0eJYAMr02980PG4+jS4bzM170wYdWwUI0pSZsEDC8Fn7jef6QARU2CgHJYlaTem+KWSXislOUTaCpR0uhakP1ezebW20yuuc3bdRNgSlZi9B7zAPALGZpOshVqwF+KmLDi6XiFwG+NnwAFa6zaQfhOxhw/rF5Jk/wVjHIHkNNvYewycZPbKui0E3QrdVtR908N3VsPtLhMQ59BEMl3xlURSi0fiOU3UjnwmOkOoFDy/WT8qk//gFD93tUxlf4eKXDgNfME3zNz8nVi2uCPvG5NT/P/VWR8NMqW6tZcmWyswM/GgL6Y84JQ3ESZq/7WvAetdc1gVIDQJ2ejYbSHBcQpWvkocsiuMTCwiEvQ0sr+UE5jmecQvLPUyXOhuMhw43CwxnLk1ZSeYeCorxbskyqIXH71o8zhbPoPiEbwgB+i9WEoq02u7c8CmCmO8Y9aOnh8MzTKxIgQ==", -// // }, -// // } - -// // instance, err := provider.CreateInstance(ctx, bootstrapArgs) -// // if err != nil { -// // log.Fatal(err) -// // } - -// // fmt.Println(instance) -// } + <-ctx.Done() +} diff --git a/config/config.go b/config/config.go index 2d0cf5b7..672259e8 100644 --- a/config/config.go +++ b/config/config.go @@ -87,15 +87,12 @@ func NewConfig(cfgFile string) (*Config, error) { } type Config struct { - Default Default `toml:"default" json:"default"` - APIServer APIServer `toml:"apiserver,omitempty" json:"apiserver,omitempty"` - Database Database `toml:"database,omitempty" json:"database,omitempty"` - Repositories []Repository `toml:"repository,omitempty" json:"repository,omitempty"` - Organizations []Organization `toml:"organization,omitempty" json:"organization,omitempty"` - Providers []Provider `toml:"provider,omitempty" json:"provider,omitempty"` - Github []Github `toml:"github,omitempty"` - // LogFile is the location of the log file. - LogFile string `toml:"log_file,omitempty"` + Default Default `toml:"default" json:"default"` + APIServer APIServer `toml:"apiserver,omitempty" json:"apiserver,omitempty"` + Database Database `toml:"database,omitempty" json:"database,omitempty"` + Providers []Provider `toml:"provider,omitempty" json:"provider,omitempty"` + Github []Github `toml:"github,omitempty"` + JWTAuth JWTAuth `toml:"jwt_auth" json:"jwt-auth"` } // Validate validates the config @@ -117,54 +114,16 @@ func (c *Config) Validate() error { } } + if err := c.JWTAuth.Validate(); err != nil { + return errors.Wrap(err, "validating jwt config") + } + for _, provider := range c.Providers { if err := provider.Validate(); err != nil { return errors.Wrap(err, "validating provider") } } - for _, repo := range c.Repositories { - if err := repo.Validate(); err != nil { - return errors.Wrap(err, "validating repository") - } - - // We also need to validate that the provider used for this - // repo, has been defined in the providers section. Multiple - // repos can use the same provider. - found := false - for _, provider := range c.Providers { - if provider.Name == repo.Pool.ProviderName { - found = true - break - } - } - - if !found { - return fmt.Errorf("provider %s defined in repo %s/%s is not defined", repo.Pool.ProviderName, repo.Owner, repo.Name) - } - } - - for _, org := range c.Organizations { - if err := org.Validate(); err != nil { - return errors.Wrap(err, "validating organization") - } - - // We also need to validate that the provider used for this - // repo, has been defined in the providers section. Multiple - // repos can use the same provider. - found := false - for _, provider := range c.Providers { - if provider.Name == org.Pool.ProviderName { - found = true - break - } - } - - if !found { - return fmt.Errorf("provider %s defined in org %s is not defined", org.Pool.ProviderName, org.Name) - } - } - return nil } @@ -174,10 +133,8 @@ type Default struct { // may be used to access the runner instances. ConfigDir string `toml:"config_dir,omitempty" json:"config-dir,omitempty"` CallbackURL string `toml:"callback_url" json:"callback-url"` - - // JWTSecret is used to sign JWT tokens that will be used by instances to - // call home. - JWTSecret string `toml:"jwt_secret" json:"jwt-secret"` + // LogFile is the location of the log file. + LogFile string `toml:"log_file,omitempty" json:"log-file"` } func (d *Default) Validate() error { @@ -185,47 +142,9 @@ func (d *Default) Validate() error { return fmt.Errorf("missing callback_url") } - if d.JWTSecret == "" { - return fmt.Errorf("missing jwt secret") - } - - passwordStenght := zxcvbn.PasswordStrength(d.JWTSecret, nil) - if passwordStenght.Score < 4 { - return fmt.Errorf("jwt_secret is too weak") - } - return nil } -// Organization represents a Github organization for which we can manage runners. -type Organization struct { - // Name is the name of the organization. - Name string `toml:"name" json:"name"` - // WebsocketSecret is the shared secret used to create the hash of - // the webhook body. We use this to validate that the webhook message - // came in from the correct repo. - WebhookSecret string `toml:"webhook_secret" json:"webhook-secret"` - - // Pool is the pool defined for this repository. - Pool Pool `toml:"pool" json:"pool"` -} - -func (o *Organization) Validate() error { - if o.Name == "" { - return fmt.Errorf("missing org name") - } - - if err := o.Pool.Validate(); err != nil { - return errors.Wrap(err, "validating org pool") - } - - return nil -} - -func (o *Organization) String() string { - return fmt.Sprintf("https://github.com/%s", o.Name) -} - // Github hold configuration options specific to interacting with github. // Currently that is just a OAuth2 personal token. type Github struct { @@ -265,142 +184,6 @@ func (p *Provider) Validate() error { return nil } -// Runner represents a runner type. The runner type is defined by the labels -// it has, the image it runs on and the size of the compute system that was -// requested. -type Runner struct { - // Name is the name of this runner. The name needs to be unique within a provider, - // and is used as an ID. If you wish to change the name, you must make sure all - // runners of this type are deleted. - Name string `toml:"name" json:"name"` - // Labels is a list of labels that will be set for this runner in github. - // The labels will be used in workflows to request a particular kind of - // runner. - Labels []string `toml:"labels" json:"labels"` - // MaxRunners is the maximum number of self hosted action runners - // of any type that are spun up for this repo. If current worker count - // is not enough to handle jobs comming in, a new runner will be spun up, - // until MaxWorkers count is hit. Set this to 0 to disable MaxRunners. - MaxRunners int `toml:"max_runners" json:"max-runners"` - // MinIdleRunners is the minimum number of idle self hosted runners that will - // be maintained for this repo. Ensuring a few idle runners, speeds up jobs, especially - // on providers where cold boot takes a long time. The pool will attempt to maintain at - // least this many idle workers, unless MaxRunners is hit. Set this to 0, for on-demand. - MinIdleRunners int `toml:"min_idle_runners" json:"min-runners"` - - // Flavor is the size of the VM that will be spun up. - Flavor string `toml:"flavor" json:"flavor"` - // Image is the image that the VM will run. Each - Image string `toml:"image" json:"image"` - - // OSType overrides the OS type that comes in from the Image. If the image - // on a particular provider does not have this information set within it's metadata - // you must set this option, so the runner-manager knows how to configure - // the worker. - OSType OSType `toml:"os_type" json:"os-type"` - // OSArch overrides the OS architecture that comes in from the Image. - // If the image metadata does not include information about the OS architecture, - // you must set this option, so the runner-manager knows how to configure the worker. - OSArch OSArch `toml:"os_arch" json:"os-arch"` -} - -func (r *Runner) HasAllLabels(labels []string) bool { - hashed := map[string]struct{}{} - for _, val := range r.Labels { - hashed[val] = struct{}{} - } - - for _, val := range labels { - if _, ok := hashed[val]; !ok { - return false - } - } - - return true -} - -// TODO: validate rest -func (r *Runner) Validate() error { - if len(r.Labels) == 0 { - return fmt.Errorf("missing labels") - } - - if r.Name == "" { - return fmt.Errorf("name is not set") - } - - return nil -} - -type Pool struct { - // ProviderName is the name of the provider that will be used for this pool. - // A provider with the name specified in this setting, must be defined in - // the Providers array in the main config. - ProviderName string `toml:"provider_name" json:"provider-name"` - - // QueueSize defines the number of jobs this pool can handle simultaneously. - QueueSize uint `toml:"queue_size" json:"queue-size"` - - // Runners represents a list of runner types defined for this pool. - Runners []Runner `toml:"runners" json:"runners"` -} - -func (p *Pool) Validate() error { - if p.ProviderName == "" { - return fmt.Errorf("missing provider_name") - } - - if len(p.Runners) == 0 { - return fmt.Errorf("no runners defined for pool") - } - - for _, runner := range p.Runners { - if err := runner.Validate(); err != nil { - return errors.Wrap(err, "validating runner for pool") - } - } - return nil -} - -// Repository defines the settings for a pool associated with a particular repository. -type Repository struct { - // Owner is the user under which the repo is created - Owner string `toml:"owner" json:"owner"` - // Name is the name of the repo. - Name string `toml:"name" json:"name"` - // WebsocketSecret is the shared secret used to create the hash of - // the webhook body. We use this to validate that the webhook message - // came in from the correct repo. - WebhookSecret string `toml:"webhook_secret" json:"webhook-secret"` - - // Pool is the pool defined for this repository. - Pool Pool `toml:"pool" json:"pool"` -} - -func (r *Repository) String() string { - return fmt.Sprintf("https://github.com/%s/%s", r.Owner, r.Name) -} - -func (r *Repository) Validate() error { - if r.Owner == "" { - return fmt.Errorf("missing owner") - } - - if r.Name == "" { - return fmt.Errorf("missing repo name") - } - - if r.WebhookSecret == "" { - return fmt.Errorf("missing webhook_secret") - } - - if err := r.Pool.Validate(); err != nil { - return errors.Wrapf(err, "validating pool for %s", r) - } - - return nil -} - // Database is the database config entry type Database struct { Debug bool `toml:"debug" json:"debug"` @@ -606,3 +389,45 @@ func (a *APIServer) Validate() error { } return nil } + +type timeToLive string + +func (d *timeToLive) Duration() time.Duration { + duration, err := time.ParseDuration(string(*d)) + if err != nil { + return DefaultJWTTTL + } + return duration +} + +func (d *timeToLive) UnmarshalText(text []byte) error { + _, err := time.ParseDuration(string(text)) + if err != nil { + return errors.Wrap(err, "parsing time_to_live") + } + + *d = timeToLive(text) + return nil +} + +// JWTAuth holds settings used to generate JWT tokens +type JWTAuth struct { + Secret string `toml:"secret" json:"secret"` + TimeToLive timeToLive `toml:"time_to_live" json:"time-to-live"` +} + +// Validate validates the JWTAuth config +func (j *JWTAuth) Validate() error { + // TODO: Set defaults somewhere else. + if j.TimeToLive.Duration() < DefaultJWTTTL { + j.TimeToLive = timeToLive(DefaultJWTTTL.String()) + } + if j.Secret == "" { + return fmt.Errorf("invalid JWT secret") + } + passwordStenght := zxcvbn.PasswordStrength(j.Secret, nil) + if passwordStenght.Score < 4 { + return fmt.Errorf("jwt_secret is too weak") + } + return nil +} diff --git a/database/common/common.go b/database/common/common.go index 5e294073..7a4e597d 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -6,12 +6,12 @@ import ( ) type Store interface { - CreateRepository(ctx context.Context, owner, name, webhookSecret string) (params.Repository, error) + CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string) (params.Repository, error) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) ListRepositories(ctx context.Context) ([]params.Repository, error) DeleteRepository(ctx context.Context, owner, name string) error - CreateOrganization(ctx context.Context, name, webhookSecret string) (params.Organization, error) + CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string) (params.Organization, error) GetOrganization(ctx context.Context, name string) (params.Organization, error) ListOrganizations(ctx context.Context) ([]params.Organization, error) DeleteOrganization(ctx context.Context, name string) error @@ -41,4 +41,12 @@ type Store interface { // GetInstance(ctx context.Context, poolID string, instanceID string) (params.Instance, error) GetInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) + + CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) + GetUser(ctx context.Context, user string) (params.User, error) + UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) + HasAdminUser(ctx context.Context) bool + + ControllerInfo() (params.ControllerInfo, error) + InitController() (params.ControllerInfo, error) } diff --git a/database/sql/models.go b/database/sql/models.go index 6a87973f..df6efa46 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -5,6 +5,7 @@ import ( "runner-manager/runner/providers/common" "time" + "github.com/pkg/errors" uuid "github.com/satori/go.uuid" "gorm.io/gorm" ) @@ -21,7 +22,11 @@ func (b *Base) BeforeCreate(tx *gorm.DB) error { if b.ID != emptyId { return nil } - b.ID = uuid.NewV4() + newID, err := uuid.NewV4() + if err != nil { + return errors.Wrap(err, "generating id") + } + b.ID = newID return nil } @@ -57,18 +62,20 @@ type Pool struct { type Repository struct { Base - Owner string `gorm:"index:idx_owner,unique"` - Name string `gorm:"index:idx_owner,unique"` - WebhookSecret []byte - Pools []Pool `gorm:"foreignKey:RepoID"` + CredentialsName string + Owner string `gorm:"index:idx_owner,unique"` + Name string `gorm:"index:idx_owner,unique"` + WebhookSecret []byte + Pools []Pool `gorm:"foreignKey:RepoID"` } type Organization struct { Base - Name string `gorm:"uniqueIndex"` - WebhookSecret []byte - Pools []Pool `gorm:"foreignKey:OrgID"` + CredentialsName string + Name string `gorm:"uniqueIndex"` + WebhookSecret []byte + Pools []Pool `gorm:"foreignKey:OrgID"` } type Address struct { @@ -95,3 +102,20 @@ type Instance struct { PoolID uuid.UUID Pool Pool `gorm:"foreignKey:PoolID"` } + +type User struct { + Base + + Username string `gorm:"uniqueIndex;varchar(64)"` + FullName string `gorm:"type:varchar(254)"` + Email string `gorm:"type:varchar(254);unique;index:idx_email"` + Password string `gorm:"type:varchar(60)"` + IsAdmin bool + Enabled bool +} + +type ControllerInfo struct { + Base + + ControllerID uuid.UUID +} diff --git a/database/sql/sql.go b/database/sql/sql.go index 6dd95d27..8dbbbc79 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -9,8 +9,8 @@ import ( "runner-manager/params" "runner-manager/util" - "github.com/pborman/uuid" "github.com/pkg/errors" + uuid "github.com/satori/go.uuid" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -41,12 +41,13 @@ type sqlDatabase struct { func (s *sqlDatabase) migrateDB() error { if err := s.conn.AutoMigrate( &Tag{}, - // &Runner{}, &Pool{}, &Repository{}, &Organization{}, &Address{}, &Instance{}, + &ControllerInfo{}, + &User{}, ); err != nil { return err } @@ -85,10 +86,11 @@ func (s *sqlDatabase) sqlToCommonPool(pool Pool) params.Pool { func (s *sqlDatabase) sqlToCommonRepository(repo Repository) params.Repository { ret := params.Repository{ - ID: repo.ID.String(), - Name: repo.Name, - Owner: repo.Owner, - Pools: make([]params.Pool, len(repo.Pools)), + ID: repo.ID.String(), + Name: repo.Name, + Owner: repo.Owner, + CredentialsName: repo.CredentialsName, + Pools: make([]params.Pool, len(repo.Pools)), } for idx, pool := range repo.Pools { @@ -100,15 +102,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) params.Repository { func (s *sqlDatabase) sqlToCommonOrganization(org Organization) params.Organization { ret := params.Organization{ - ID: org.ID.String(), - Name: org.Name, - Pools: make([]params.Pool, len(org.Pools)), + ID: org.ID.String(), + Name: org.Name, + CredentialsName: org.CredentialsName, + Pools: make([]params.Pool, len(org.Pools)), } return ret } -func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhookSecret string) (params.Repository, error) { +func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string) (params.Repository, error) { secret := []byte{} var err error if webhookSecret != "" { @@ -118,9 +121,10 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhook } } newRepo := Repository{ - Name: name, - Owner: owner, - WebhookSecret: secret, + Name: name, + Owner: owner, + WebhookSecret: secret, + CredentialsName: credentialsName, } q := s.conn.Create(&newRepo) @@ -134,12 +138,18 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhook return param, nil } -func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string) (Repository, error) { +func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string, preloadAll bool) (Repository, error) { var repo Repository - q := s.conn.Preload(clause.Associations). - Where("name = ? and owner = ?", name, owner). + + q := s.conn.Where("name = ? and owner = ?", name, owner). First(&repo) + if preloadAll { + q = q.Preload(clause.Associations) + } + + q = q.First(&repo) + if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Repository{}, runnerErrors.ErrNotFound @@ -150,12 +160,12 @@ func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string) (Reposito } func (s *sqlDatabase) getRepoByID(ctx context.Context, id string) (Repository, error) { - u := uuid.Parse(id) - if u == nil { - return Repository{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id") + u, err := uuid.FromString(id) + if err != nil { + return Repository{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var repo Repository - q := s.conn.Preload(clause.Associations). + q := s.conn. Where("id = ?", u). First(&repo) @@ -169,7 +179,7 @@ func (s *sqlDatabase) getRepoByID(ctx context.Context, id string) (Repository, e } func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) { - repo, err := s.getRepo(ctx, owner, name) + repo, err := s.getRepo(ctx, owner, name, false) if err != nil { return params.Repository{}, errors.Wrap(err, "fetching repo") } @@ -200,7 +210,7 @@ func (s *sqlDatabase) ListRepositories(ctx context.Context) ([]params.Repository } func (s *sqlDatabase) DeleteRepository(ctx context.Context, owner, name string) error { - repo, err := s.getRepo(ctx, owner, name) + repo, err := s.getRepo(ctx, owner, name, false) if err != nil { if err == runnerErrors.ErrNotFound { return nil @@ -216,7 +226,7 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, owner, name string) return nil } -func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecret string) (params.Organization, error) { +func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string) (params.Organization, error) { secret := []byte{} var err error if webhookSecret != "" { @@ -226,8 +236,9 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecre } } newOrg := Organization{ - Name: name, - WebhookSecret: secret, + Name: name, + WebhookSecret: secret, + CredentialsName: credentialsName, } q := s.conn.Create(&newOrg) @@ -241,9 +252,15 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecre return param, nil } -func (s *sqlDatabase) getOrg(ctx context.Context, name string) (Organization, error) { +func (s *sqlDatabase) getOrg(ctx context.Context, name string, preloadAll bool) (Organization, error) { var org Organization - q := s.conn.Preload(clause.Associations).Where("name = ?", name).First(&org) + + q := s.conn.Where("name = ?", name) + if preloadAll { + q = q.Preload(clause.Associations) + } + + q = q.First(&org) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Organization{}, runnerErrors.ErrNotFound @@ -253,13 +270,19 @@ func (s *sqlDatabase) getOrg(ctx context.Context, name string) (Organization, er return org, nil } -func (s *sqlDatabase) getOrgByID(ctx context.Context, id string) (Organization, error) { - u := uuid.Parse(id) - if u == nil { - return Organization{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id") +func (s *sqlDatabase) getOrgByID(ctx context.Context, id string, preloadAll bool) (Organization, error) { + u, err := uuid.FromString(id) + if err != nil { + return Organization{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } + + q := s.conn.Where("id = ?", u) + if preloadAll { + q = q.Preload(clause.Associations) + } + var org Organization - q := s.conn.Preload(clause.Associations).Where("id = ?", u).First(&org) + q = q.First(&org) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return Organization{}, runnerErrors.ErrNotFound @@ -270,7 +293,7 @@ func (s *sqlDatabase) getOrgByID(ctx context.Context, id string) (Organization, } func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) { - org, err := s.getOrg(ctx, name) + org, err := s.getOrg(ctx, name, false) if err != nil { return params.Organization{}, errors.Wrap(err, "fetching repo") } @@ -301,7 +324,7 @@ func (s *sqlDatabase) ListOrganizations(ctx context.Context) ([]params.Organizat } func (s *sqlDatabase) DeleteOrganization(ctx context.Context, name string) error { - org, err := s.getOrg(ctx, name) + org, err := s.getOrg(ctx, name, false) if err != nil { if err == runnerErrors.ErrNotFound { return nil @@ -377,11 +400,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoId string, p s.conn.Model(&newPool).Association("Tags").Append(&tt) } - repo, err = s.getRepoByID(ctx, repoId) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching repo") - } - return s.sqlToCommonPool(repo.Pools[0]), nil + return s.sqlToCommonPool(newPool), nil } func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string, param params.CreatePoolParams) (params.Pool, error) { @@ -389,7 +408,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string, return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") } - org, err := s.getOrgByID(ctx, orgId) + org, err := s.getOrgByID(ctx, orgId, false) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching org") } @@ -422,14 +441,53 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string, return s.sqlToCommonPool(newPool), nil } +func (s *sqlDatabase) getRepoPools(ctx context.Context, repoID string, preloadAll bool) ([]Pool, error) { + repo, err := s.getRepoByID(ctx, repoID) + if err != nil { + return nil, errors.Wrap(err, "fetching repo") + } + + var pools []Pool + q := s.conn.Model(&repo) + if preloadAll { + q = q.Preload(clause.Associations) + } + err = q.Association("Pools").Find(&pools) + if err != nil { + return nil, errors.Wrap(err, "fetching pool") + } + + return pools, nil +} + +func (s *sqlDatabase) getOrgPools(ctx context.Context, orgID string, preloadAll bool) ([]Pool, error) { + org, err := s.getOrgByID(ctx, orgID, preloadAll) + if err != nil { + return nil, errors.Wrap(err, "fetching repo") + } + + var pools []Pool + q := s.conn.Model(&org) + if preloadAll { + q = q.Preload(clause.Associations) + } + err = q.Association("Pools").Find(&pools) + if err != nil { + return nil, errors.Wrap(err, "fetching pool") + } + + return pools, nil +} + func (s *sqlDatabase) getRepoPool(ctx context.Context, repoID, poolID string) (Pool, error) { repo, err := s.getRepoByID(ctx, repoID) if err != nil { return Pool{}, errors.Wrap(err, "fetching repo") } - u := uuid.Parse(poolID) - if u == nil { - return Pool{}, fmt.Errorf("invalid pool id") + + u, err := uuid.FromString(poolID) + if err != nil { + return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var pool []Pool err = s.conn.Model(&repo).Association("Pools").Find(&pool, "id = ?", u) @@ -451,22 +509,24 @@ func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID stri return s.sqlToCommonPool(pool), nil } -func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string) (Pool, error) { - org, err := s.getOrgByID(ctx, orgID) +func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string, preloadAll bool) (Pool, error) { + org, err := s.getOrgByID(ctx, orgID, preloadAll) if err != nil { return Pool{}, errors.Wrap(err, "fetching repo") } - u := uuid.Parse(poolID) - if u == nil { - return Pool{}, fmt.Errorf("invalid pool id") + u, err := uuid.FromString(poolID) + if err != nil { + return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var pool []Pool - err = s.conn.Model(&org). - Association(clause.Associations). - Find(&pool, "id = ?", u) + q := s.conn.Model(&org) + if preloadAll { + q = q.Preload(clause.Associations) + } + q = q.Find(&pool, "id = ?", u) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") + if q.Error != nil { + return Pool{}, errors.Wrap(q.Error, "fetching pool") } if len(pool) == 0 { return Pool{}, runnerErrors.ErrNotFound @@ -475,15 +535,18 @@ func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string) (Poo return pool[0], nil } -func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string) (Pool, error) { - u := uuid.Parse(poolID) - if u == nil { +func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string, preloadAll bool) (Pool, error) { + u, err := uuid.FromString(poolID) + if err != nil { return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var pool Pool - q := s.conn.Model(&Pool{}). - Preload(clause.Associations). - Where("id = ?", u).First(&pool) + q := s.conn.Model(&Pool{}) + if preloadAll { + q = q.Preload(clause.Associations) + } + + q = q.Where("id = ?", u).First(&pool) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { @@ -495,7 +558,7 @@ func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string) (Pool, err } func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) { - pool, err := s.getOrgPool(ctx, orgID, poolID) + pool, err := s.getOrgPool(ctx, orgID, poolID, false) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -518,7 +581,7 @@ func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID s } func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error { - pool, err := s.getOrgPool(ctx, orgID, poolID) + pool, err := s.getOrgPool(ctx, orgID, poolID, false) if err != nil { if errors.Is(err, runnerErrors.ErrNotFound) { return nil @@ -536,9 +599,9 @@ func (s *sqlDatabase) findPoolByTags(id, poolType string, tags []string) (params if len(tags) == 0 { return params.Pool{}, runnerErrors.NewBadRequestError("missing tags") } - u := uuid.Parse(id) - if u == nil { - return params.Pool{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id") + u, err := uuid.FromString(id) + if err != nil { + return params.Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var pool Pool @@ -548,7 +611,7 @@ func (s *sqlDatabase) findPoolByTags(id, poolType string, tags []string) (params Group("pools.id"). Preload("Tags"). Having("count(1) = ?", len(tags)). - Where(where, tags, id).First(&pool) + Where(where, tags, u).First(&pool) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { @@ -605,7 +668,7 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) params.Instance { } func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { - pool, err := s.getPoolByID(ctx, param.Pool) + pool, err := s.getPoolByID(ctx, param.Pool, false) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } @@ -631,8 +694,8 @@ func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param p // } func (s *sqlDatabase) getInstanceByID(ctx context.Context, instanceID string) (Instance, error) { - u := uuid.Parse(instanceID) - if u == nil { + u, err := uuid.FromString(instanceID) + if err != nil { return Instance{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } var instance Instance @@ -647,7 +710,7 @@ func (s *sqlDatabase) getInstanceByID(ctx context.Context, instanceID string) (I } func (s *sqlDatabase) getInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(ctx, poolID, false) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return Instance{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching instance") @@ -738,7 +801,7 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceID string, par } func (s *sqlDatabase) ListInstances(ctx context.Context, poolID string) ([]params.Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) + pool, err := s.getPoolByID(ctx, poolID, true) if err != nil { return nil, errors.Wrap(err, "fetching pool") } @@ -751,13 +814,13 @@ func (s *sqlDatabase) ListInstances(ctx context.Context, poolID string) ([]param } func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { - repo, err := s.getRepoByID(ctx, repoID) + pools, err := s.getRepoPools(ctx, repoID, true) if err != nil { return nil, errors.Wrap(err, "fetching repo") } ret := []params.Instance{} - for _, pool := range repo.Pools { + for _, pool := range pools { for _, instance := range pool.Instances { ret = append(ret, s.sqlToParamsInstance(instance)) } @@ -766,7 +829,7 @@ func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]p } func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { - org, err := s.getOrgByID(ctx, orgID) + org, err := s.getOrgByID(ctx, orgID, true) if err != nil { return nil, errors.Wrap(err, "fetching org") } @@ -779,10 +842,208 @@ func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]par return ret, nil } +func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (params.Pool, error) { + if param.Enabled != nil && pool.Enabled != *param.Enabled { + pool.Enabled = *param.Enabled + } + + if param.Flavor != "" { + pool.Flavor = param.Flavor + } + + if param.Image != "" { + pool.Image = param.Image + } + + if param.MaxRunners != nil { + pool.MaxRunners = *param.MaxRunners + } + + if param.MinIdleRunners != nil { + pool.MinIdleRunners = *param.MinIdleRunners + } + + if param.OSArch != "" { + pool.OSArch = param.OSArch + } + + if param.OSType != "" { + pool.OSType = param.OSType + } + + if q := s.conn.Save(&pool); q.Error != nil { + return params.Pool{}, errors.Wrap(q.Error, "saving database entry") + } + + if len(param.Tags) > 0 { + tags := make([]Tag, len(param.Tags)) + for idx, t := range param.Tags { + tags[idx] = Tag{ + Name: t.Name, + } + } + + if err := s.conn.Model(&pool).Association("Tags").Replace(&tags); err != nil { + return params.Pool{}, errors.Wrap(err, "replacing tags") + } + } + + return s.sqlToCommonPool(pool), nil +} + func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - return params.Pool{}, nil + pool, err := s.getRepoPool(ctx, repoID, poolID) + if err != nil { + return params.Pool{}, errors.Wrap(err, "fetching pool") + } + + return s.updatePool(pool, param) } func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - return params.Pool{}, nil + pool, err := s.getOrgPool(ctx, orgID, poolID, true) + if err != nil { + return params.Pool{}, errors.Wrap(err, "fetching pool") + } + + return s.updatePool(pool, param) +} + +func (s *sqlDatabase) sqlToParamsUser(user User) params.User { + return params.User{ + ID: user.ID.String(), + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Email: user.Email, + Username: user.Username, + FullName: user.FullName, + Password: user.Password, + Enabled: user.Enabled, + IsAdmin: user.IsAdmin, + } +} + +func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) { + field := "username" + if util.IsValidEmail(user) { + field = "email" + } + query := fmt.Sprintf("%s = ?", field) + + var dbUser User + q := s.conn.Model(&User{}).Where(query, user).First(&dbUser) + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return User{}, runnerErrors.ErrNotFound + } + return User{}, errors.Wrap(q.Error, "fetching user") + } + return dbUser, nil +} + +func (s *sqlDatabase) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) { + if user.Username == "" || user.Email == "" { + return params.User{}, runnerErrors.NewBadRequestError("missing username or email") + } + if _, err := s.getUserByUsernameOrEmail(user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { + return params.User{}, runnerErrors.NewConflictError("username already exists") + } + if _, err := s.getUserByUsernameOrEmail(user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { + return params.User{}, runnerErrors.NewConflictError("email already exists") + } + + newUser := User{ + Username: user.Username, + Password: user.Password, + FullName: user.FullName, + Enabled: user.Enabled, + Email: user.Email, + IsAdmin: user.IsAdmin, + } + + q := s.conn.Save(&newUser) + if q.Error != nil { + return params.User{}, errors.Wrap(q.Error, "creating user") + } + return params.User{}, nil +} + +func (s *sqlDatabase) HasAdminUser(ctx context.Context) bool { + var user User + q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user) + if q.Error != nil { + return false + } + return true +} + +func (s *sqlDatabase) GetUser(ctx context.Context, user string) (params.User, error) { + dbUser, err := s.getUserByUsernameOrEmail(user) + if err != nil { + return params.User{}, errors.Wrap(err, "fetching user") + } + return s.sqlToParamsUser(dbUser), nil +} + +func (s *sqlDatabase) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) { + dbUser, err := s.getUserByUsernameOrEmail(user) + if err != nil { + return params.User{}, errors.Wrap(err, "fetching user") + } + + if param.FullName != "" { + dbUser.FullName = param.FullName + } + + if param.Enabled != nil { + dbUser.Enabled = *param.Enabled + } + + if param.Password != "" { + dbUser.Password = param.Password + } + + if q := s.conn.Save(&dbUser); q.Error != nil { + return params.User{}, errors.Wrap(q.Error, "saving user") + } + + return s.sqlToParamsUser(dbUser), nil +} + +func (s *sqlDatabase) ControllerInfo() (params.ControllerInfo, error) { + var info ControllerInfo + q := s.conn.Model(&ControllerInfo{}).First(&info) + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return params.ControllerInfo{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info") + } + return params.ControllerInfo{}, errors.Wrap(q.Error, "fetching controller info") + } + return params.ControllerInfo{ + ControllerID: info.ControllerID, + }, nil +} + +func (s *sqlDatabase) InitController() (params.ControllerInfo, error) { + if _, err := s.ControllerInfo(); err == nil { + return params.ControllerInfo{}, runnerErrors.NewConflictError("controller already initialized") + } + + newID, err := uuid.NewV4() + if err != nil { + return params.ControllerInfo{}, errors.Wrap(err, "generating UUID") + } + + newInfo := ControllerInfo{ + ControllerID: newID, + } + + q := s.conn.Save(&newInfo) + if q.Error != nil { + return params.ControllerInfo{}, errors.Wrap(q.Error, "saving controller info") + } + + return params.ControllerInfo{ + ControllerID: newInfo.ControllerID, + }, nil } diff --git a/go.mod b/go.mod index 3d642309..de115722 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,8 @@ require ( github.com/gorilla/mux v1.8.0 github.com/lxc/lxd v0.0.0-20220415052741-1170f2806124 github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 - github.com/pborman/uuid v1.2.1 github.com/pkg/errors v0.9.1 - github.com/satori/go.uuid v1.2.0 + github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be gopkg.in/natefinch/lumberjack.v2 v2.0.0 @@ -38,6 +37,7 @@ require ( github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/fs v0.1.0 // indirect github.com/mattn/go-sqlite3 v1.14.12 // indirect + github.com/pborman/uuid v1.2.1 // indirect github.com/pkg/sftp v1.13.4 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/rogpeppe/fastuuid v1.2.0 // indirect diff --git a/go.sum b/go.sum index 8085af31..e6019bff 100644 --- a/go.sum +++ b/go.sum @@ -117,8 +117,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b h1:gQZ0qzfKHQIybLANtM3mBXNUtOfsCFXeTsnBqCsx1KM= +github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/params/params.go b/params/params.go index 8b0e11ce..6992cd5f 100644 --- a/params/params.go +++ b/params/params.go @@ -3,8 +3,10 @@ package params import ( "runner-manager/config" "runner-manager/runner/providers/common" + "time" "github.com/google/go-github/v43/github" + uuid "github.com/satori/go.uuid" ) type AddressType string @@ -19,20 +21,6 @@ type Address struct { Type AddressType `json:"type"` } -type UpdateInstanceParams struct { - ProviderID string `json:"provider_id,omitempty"` - // OSName is the name of the OS. Eg: ubuntu, centos, etc. - OSName string `json:"os_name,omitempty"` - // OSVersion is the version of the operating system. - OSVersion string `json:"os_version,omitempty"` - // Addresses is a list of IP addresses the provider reports - // for this instance. - Addresses []Address `json:"addresses,omitempty"` - // Status is the status of the instance inside the provider (eg: running, stopped, etc) - Status common.InstanceStatus `json:"status"` - RunnerStatus common.RunnerStatus `json:"runner_status"` -} - type Instance struct { // ID is the database ID of this instance. ID string `json:"id"` @@ -117,66 +105,55 @@ type Internal struct { } type Repository struct { - ID string `json:"id"` - Owner string `json:"owner"` - Name string `json:"name"` - Pools []Pool `json:"pool,omitempty"` + ID string `json:"id"` + Owner string `json:"owner"` + Name string `json:"name"` + Pools []Pool `json:"pool,omitempty"` + CredentialsName string `json:"credentials_name"` // Do not serialize sensitive info. WebhookSecret string `json:"-"` Internal Internal `json:"-"` } type Organization struct { - ID string `json:"id"` - Name string `json:"name"` - Pools []Pool `json:"pool,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Pools []Pool `json:"pool,omitempty"` + CredentialsName string `json:"credentials_name"` // Do not serialize sensitive info. WebhookSecret string `json:"-"` Internal Internal `json:"-"` } -type CreatePoolParams struct { - ProviderName string `json:"provider_name"` - MaxRunners uint `json:"max_runners"` - MinIdleRunners uint `json:"min_idle_runners"` - Image string `json:"image"` - Flavor string `json:"flavor"` - OSType config.OSType `json:"os_type"` - OSArch config.OSArch `json:"os_arch"` - Tags []string `json:"tags"` - Enabled bool `json:"enabled"` +// Users holds information about a particular user +type User struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Email string `json:"email"` + Username string `json:"username"` + FullName string `json:"full_name"` + Password string `json:"-"` + Enabled bool `json:"enabled"` + IsAdmin bool `json:"is_admin"` } -type CreateInstanceParams struct { - Name string - OSType config.OSType - OSArch config.OSArch - Status common.InstanceStatus - RunnerStatus common.RunnerStatus - CallbackURL string - - Pool string +// JWTResponse holds the JWT token returned as a result of a +// successful auth +type JWTResponse struct { + Token string `json:"token"` } -/* -type Pool struct { - ID string `json:"id"` - ProviderName string `json:"provider_name"` - MaxRunners uint `json:"max_runners"` - MinIdleRunners uint `json:"min_idle_runners"` - Image string `json:"image"` - Flavor string `json:"flavor"` - OSType config.OSType `json:"os_type"` - OSArch config.OSArch `json:"os_arch"` - Tags []Tag `json:"tags"` - Enabled bool `json:"enabled"` +type ControllerInfo struct { + ControllerID uuid.UUID `json:"controller_id"` } -*/ -type UpdatePoolParams struct { - Tags []Tag `json:"tags"` - Enabled bool `json:"enabled"` - MaxRunners uint `json:"max_runners"` - MinIdleRunners uint `json:"min_idle_runners"` - Image string `json:"image"` - Flavor string `json:"flavor"` + +type GithubCredentials struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` +} + +type Provider struct { + Name string `json:"name"` + ProviderType config.ProviderType `json:"type"` } diff --git a/params/requests.go b/params/requests.go index 8e8bac92..c3113024 100644 --- a/params/requests.go +++ b/params/requests.go @@ -1,9 +1,101 @@ package params -import "runner-manager/config" +import ( + "runner-manager/config" + "runner-manager/errors" + "runner-manager/runner/providers/common" +) type InstanceRequest struct { Name string `json:"name"` OSType config.OSType `json:"os_type"` OSVersion string `json:"os_version"` } + +type CreateRepoParams struct { + Owner string `json:"owner"` + Name string `json:"name"` + CredentialsName string `json:"credentials_name"` + WebhookSecret string `json:"webhook_secret"` +} + +// NewUserParams holds the needed information to create +// a new user +type NewUserParams struct { + Email string `json:"email"` + Username string `json:"username"` + FullName string `json:"full_name"` + Password string `json:"password"` + IsAdmin bool `json:"-"` + Enabled bool `json:"-"` +} + +type UpdatePoolParams struct { + Tags []Tag `json:"tags"` + Enabled *bool `json:"enabled"` + MaxRunners *uint `json:"max_runners"` + MinIdleRunners *uint `json:"min_idle_runners"` + Image string `json:"image"` + Flavor string `json:"flavor"` + OSType config.OSType `json:"os_type"` + OSArch config.OSArch `json:"os_arch"` +} + +type CreateInstanceParams struct { + Name string + OSType config.OSType + OSArch config.OSArch + Status common.InstanceStatus + RunnerStatus common.RunnerStatus + CallbackURL string + + Pool string +} + +type CreatePoolParams struct { + ProviderName string `json:"provider_name"` + MaxRunners uint `json:"max_runners"` + MinIdleRunners uint `json:"min_idle_runners"` + Image string `json:"image"` + Flavor string `json:"flavor"` + OSType config.OSType `json:"os_type"` + OSArch config.OSArch `json:"os_arch"` + Tags []string `json:"tags"` + Enabled bool `json:"enabled"` +} + +type UpdateInstanceParams struct { + ProviderID string `json:"provider_id,omitempty"` + // OSName is the name of the OS. Eg: ubuntu, centos, etc. + OSName string `json:"os_name,omitempty"` + // OSVersion is the version of the operating system. + OSVersion string `json:"os_version,omitempty"` + // Addresses is a list of IP addresses the provider reports + // for this instance. + Addresses []Address `json:"addresses,omitempty"` + // Status is the status of the instance inside the provider (eg: running, stopped, etc) + Status common.InstanceStatus `json:"status"` + RunnerStatus common.RunnerStatus `json:"runner_status"` +} + +type UpdateUserParams struct { + FullName string `json:"full_name"` + Password string `json:"password"` + Enabled *bool `json:"enabled"` +} + +// PasswordLoginParams holds information used during +// password authentication, that will be passed to a +// password login function +type PasswordLoginParams struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// Validate checks if the username and password are set +func (p PasswordLoginParams) Validate() error { + if p.Username == "" || p.Password == "" { + return errors.ErrUnauthorized + } + return nil +} diff --git a/runner/common/provider.go b/runner/common/provider.go index ab7e3734..203ae7cc 100644 --- a/runner/common/provider.go +++ b/runner/common/provider.go @@ -20,4 +20,6 @@ type Provider interface { Stop(ctx context.Context, instance string, force bool) error // Start boots up an instance. Start(ctx context.Context, instance string) error + + AsParams() params.Provider } diff --git a/runner/pool/repository.go b/runner/pool/repository.go index 8ec1a246..5545cc39 100644 --- a/runner/pool/repository.go +++ b/runner/pool/repository.go @@ -46,10 +46,6 @@ func NewRepositoryPoolManager(ctx context.Context, cfg params.Repository, provid quit: make(chan struct{}), done: make(chan struct{}), } - - if err := repo.fetchTools(); err != nil { - return nil, errors.Wrap(err, "initializing tools") - } return repo, nil } @@ -96,6 +92,10 @@ func (r *Repository) getProviderInstances() ([]params.Instance, error) { } func (r *Repository) Start() error { + if err := r.fetchTools(); err != nil { + return errors.Wrap(err, "initializing tools") + } + runners, err := r.getGithubRunners() if err != nil { return errors.Wrap(err, "fetching github runners") @@ -334,7 +334,8 @@ func (r *Repository) ensureMinIdleRunners() { projectedInstanceCount := len(existingInstances) + required if uint(projectedInstanceCount) > pool.MaxRunners { // ensure we don't go above max workers - required = (len(existingInstances) + required) - int(pool.MaxRunners) + delta := projectedInstanceCount - int(pool.MaxRunners) + required = required - delta } } diff --git a/runner/providers/lxd/lxd.go b/runner/providers/lxd/lxd.go index 7f6a72d3..fadb2dbe 100644 --- a/runner/providers/lxd/lxd.go +++ b/runner/providers/lxd/lxd.go @@ -231,6 +231,13 @@ func (l *LXD) getCreateInstanceArgs(bootstrapParams params.BootstrapInstance) (a return args, nil } +func (l *LXD) AsParams() params.Provider { + return params.Provider{ + Name: l.cfg.Name, + ProviderType: l.cfg.ProviderType, + } +} + func (l *LXD) launchInstance(createArgs api.InstancesPost) error { // Get LXD to create the instance (background operation) op, err := l.cli.CreateInstance(createArgs) diff --git a/runner/runner.go b/runner/runner.go index 93c63a98..c0f93b9c 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -21,6 +21,7 @@ import ( gErrors "runner-manager/errors" "runner-manager/params" "runner-manager/runner/common" + "runner-manager/runner/pool" "runner-manager/runner/providers" "runner-manager/util" @@ -46,15 +47,21 @@ func NewRunner(ctx context.Context, cfg config.Config) (*Runner, error) { runner := &Runner{ ctx: ctx, config: cfg, - db: db, + store: db, // ghc: ghc, - providers: providers, + repositories: map[string]common.PoolManager{}, + organizations: map[string]common.PoolManager{}, + providers: providers, } if err := runner.ensureSSHKeys(); err != nil { return nil, errors.Wrap(err, "ensuring SSH keys") } + if err := runner.loadReposAndOrgs(); err != nil { + return nil, errors.Wrap(err, "loading pool managers") + } + return runner, nil } @@ -63,7 +70,7 @@ type Runner struct { ctx context.Context // ghc *github.Client - db dbCommon.Store + store dbCommon.Store controllerID string @@ -73,35 +80,170 @@ type Runner struct { providers map[string]common.Provider } -func (r *Runner) loadPools() error { +func (r *Runner) CreateRepository(ctx context.Context) error { + return nil +} + +func (r *Runner) ListRepositories(ctx context.Context) error { + return nil +} + +func (r *Runner) GetRepository(ctx context.Context) error { + return nil +} + +func (r *Runner) DeleteRepository(ctx context.Context) error { + return nil +} + +func (r *Runner) UpdateRepository(ctx context.Context) error { + return nil +} + +func (r *Runner) CreateRepoPool(ctx context.Context) error { + return nil +} + +func (r *Runner) DeleteRepoPool(ctx context.Context) error { + return nil +} + +func (r *Runner) ListRepoPools(ctx context.Context) error { + return nil +} + +func (r *Runner) UpdateRepoPool(ctx context.Context) error { + return nil +} + +func (r *Runner) ListPoolInstances(ctx context.Context) error { + return nil +} + +func (r *Runner) ListCredentials(ctx context.Context) ([]params.GithubCredentials, error) { + ret := []params.GithubCredentials{} + + for _, val := range r.config.Github { + ret = append(ret, params.GithubCredentials{ + Name: val.Name, + Description: val.Description, + }) + } + return ret, nil +} + +func (r *Runner) ListProviders(ctx context.Context) ([]params.Provider, error) { + ret := []params.Provider{} + + for _, val := range r.providers { + ret = append(ret, val.AsParams()) + } + return ret, nil +} + +func (r *Runner) loadReposAndOrgs() error { r.mux.Lock() defer r.mux.Unlock() - // repos, err := r.db.ListRepositories(r.ctx) - // if err != nil { - // return errors.Wrap(err, "fetching repositories") - // } + repos, err := r.store.ListRepositories(r.ctx) + if err != nil { + return errors.Wrap(err, "fetching repositories") + } + + for _, repo := range repos { + log.Printf("creating pool manager for %s/%s", repo.Owner, repo.Name) + poolManager, err := pool.NewRepositoryPoolManager(r.ctx, repo, r.providers, r.store) + if err != nil { + return errors.Wrap(err, "creating pool manager") + } + r.repositories[repo.ID] = poolManager + } return nil } -func (r *Runner) findRepoPool(owner, name string) (common.PoolManager, error) { +func (r *Runner) Start() error { + for _, repo := range r.repositories { + if err := repo.Start(); err != nil { + return errors.Wrap(err, "starting repo pool manager") + } + } + + for _, org := range r.organizations { + if err := org.Start(); err != nil { + return errors.Wrap(err, "starting org pool manager") + } + } + return nil +} + +func (r *Runner) Stop() error { + for _, repo := range r.repositories { + if err := repo.Stop(); err != nil { + return errors.Wrap(err, "starting repo pool manager") + } + } + + for _, org := range r.organizations { + if err := org.Stop(); err != nil { + return errors.Wrap(err, "starting org pool manager") + } + } + return nil +} + +func (r *Runner) Wait() error { + var wg sync.WaitGroup + + for poolId, repo := range r.repositories { + wg.Add(1) + go func(id string, poolMgr common.PoolManager) { + defer wg.Done() + if err := poolMgr.Wait(); err != nil { + log.Printf("timed out waiting for pool manager %s to exit", id) + } + }(poolId, repo) + } + + for poolId, org := range r.organizations { + wg.Add(1) + go func(id string, poolMgr common.PoolManager) { + defer wg.Done() + if err := poolMgr.Wait(); err != nil { + log.Printf("timed out waiting for pool manager %s to exit", id) + } + }(poolId, org) + } + wg.Wait() + return nil +} + +func (r *Runner) findRepoPoolManager(owner, name string) (common.PoolManager, error) { r.mux.Lock() defer r.mux.Unlock() - // key := fmt.Sprintf("%s/%s", owner, name) - // if repo, ok := r.repositories[key]; ok { - // return pool, nil - // } + repo, err := r.store.GetRepository(r.ctx, owner, name) + if err != nil { + return nil, errors.Wrap(err, "fetching repo") + } - // repo, err := r.db.GetRepository(r.ctx, owner, name) - // r.repositories[key] = repo - return nil, errors.Wrapf(gErrors.ErrNotFound, "repository %s not configured", name) + if repo, ok := r.repositories[repo.ID]; ok { + return repo, nil + } + return nil, errors.Wrapf(gErrors.ErrNotFound, "repository %s/%s not configured", owner, name) } -func (r *Runner) findOrgPool(name string) (common.PoolManager, error) { - if pool, ok := r.organizations[name]; ok { - return pool, nil +func (r *Runner) findOrgPoolManager(name string) (common.PoolManager, error) { + r.mux.Lock() + defer r.mux.Unlock() + + org, err := r.store.GetOrganization(r.ctx, name) + if err != nil { + return nil, errors.Wrap(err, "fetching repo") + } + + if orgPoolMgr, ok := r.organizations[org.ID]; ok { + return orgPoolMgr, nil } return nil, errors.Wrapf(gErrors.ErrNotFound, "organization %s not configured", name) } @@ -165,9 +307,9 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ switch HookTargetType(hookTargetType) { case RepoHook: - poolManager, err = r.findRepoPool(job.Repository.Owner.Login, job.Repository.Name) + poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name) case OrganizationHook: - poolManager, err = r.findOrgPool(job.Organization.Login) + poolManager, err = r.findOrgPoolManager(job.Organization.Login) default: return gErrors.NewBadRequestError("cannot handle hook target type %s", hookTargetType) } @@ -185,6 +327,10 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ return errors.Wrap(err, "validating webhook data") } + if err := poolManager.HandleWorkflowJob(job); err != nil { + return errors.Wrap(err, "handling workflow job") + } + return nil } diff --git a/testdata/config.toml b/testdata/config.toml index 3fdb2f8e..9165586f 100644 --- a/testdata/config.toml +++ b/testdata/config.toml @@ -1,5 +1,12 @@ +[default] +config_dir = "/etc/runner-manager" +callback_url = "https://webhooks.samfira.com/api/v1/instances/status" # log_file = "/tmp/runner-manager.log" +[jwt_auth] +secret = "L&CGG?%VaV;Zs5CnGqaWINDBhx 254 || !rxEmail.MatchString(email) { + return false + } + return true +} + +func IsAlphanumeric(s string) bool { + for _, r := range s { + if !unicode.IsLetter(r) && !unicode.IsNumber(r) { + return false + } + } + return true +} + // GetLoggingWriter returns a new io.Writer suitable for logging. func GetLoggingWriter(cfg *config.Config) (io.Writer, error) { var writer io.Writer = os.Stdout - if cfg.LogFile != "" { - dirname := path.Dir(cfg.LogFile) + if cfg.Default.LogFile != "" { + dirname := path.Dir(cfg.Default.LogFile) if _, err := os.Stat(dirname); err != nil { if !os.IsNotExist(err) { return nil, fmt.Errorf("failed to create log folder") @@ -61,7 +83,7 @@ func GetLoggingWriter(cfg *config.Config) (io.Writer, error) { } } writer = &lumberjack.Logger{ - Filename: cfg.LogFile, + Filename: cfg.Default.LogFile, MaxSize: 500, // megabytes MaxBackups: 3, MaxAge: 28, //days @@ -71,16 +93,6 @@ func GetLoggingWriter(cfg *config.Config) (io.Writer, error) { return writer, nil } -func FindRunnerType(runnerType string, runners []config.Runner) (config.Runner, error) { - for _, runner := range runners { - if runner.Name == runnerType { - return runner, nil - } - } - - return config.Runner{}, runnerErrors.ErrNotFound -} - func ConvertFileToBase64(file string) (string, error) { bytes, err := ioutil.ReadFile(file) if err != nil { @@ -256,3 +268,13 @@ func Aes256DecodeString(target []byte, passphrase string) (string, error) { } return string(plaintext), nil } + +// PaswsordToBcrypt returns a bcrypt hash of the specified password using the default cost +func PaswsordToBcrypt(password string) (string, error) { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + // TODO: make this a fatal error, that should return a 500 error to user + return "", fmt.Errorf("failed to hash password") + } + return string(hashedPassword), nil +}