Add enterprise support

This change adds enterprise support throughout garm.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2022-10-13 16:09:28 +00:00
parent f40420bfb6
commit 296333412a
No known key found for this signature in database
GPG key ID: 7D073DCC2C074CB5
34 changed files with 2028 additions and 112 deletions

View file

@ -0,0 +1,282 @@
// Copyright 2022 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package controllers
import (
"encoding/json"
"log"
"net/http"
"garm/apiserver/params"
gErrors "garm/errors"
runnerParams "garm/params"
"github.com/gorilla/mux"
)
func (a *APIController) CreateEnterpriseHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var enterpriseData runnerParams.CreateEnterpriseParams
if err := json.NewDecoder(r.Body).Decode(&enterpriseData); err != nil {
handleError(w, gErrors.ErrBadRequest)
return
}
enterprise, err := a.r.CreateEnterprise(ctx, enterpriseData)
if err != nil {
log.Printf("error creating enterprise: %+v", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(enterprise)
}
func (a *APIController) ListEnterprisesHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
enterprise, err := a.r.ListEnterprises(ctx)
if err != nil {
log.Printf("listing enterprise: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(enterprise)
}
func (a *APIController) GetEnterpriseByIDHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
enterprise, err := a.r.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
log.Printf("fetching enterprise: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(enterprise)
}
func (a *APIController) DeleteEnterpriseHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
if err := a.r.DeleteEnterprise(ctx, enterpriseID); err != nil {
log.Printf("removing enterprise: %+v", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
}
func (a *APIController) UpdateEnterpriseHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
var updatePayload runnerParams.UpdateRepositoryParams
if err := json.NewDecoder(r.Body).Decode(&updatePayload); err != nil {
handleError(w, gErrors.ErrBadRequest)
return
}
enterprise, err := a.r.UpdateEnterprise(ctx, enterpriseID, updatePayload)
if err != nil {
log.Printf("error updating enterprise: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(enterprise)
}
func (a *APIController) CreateEnterprisePoolHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
var poolData runnerParams.CreatePoolParams
if err := json.NewDecoder(r.Body).Decode(&poolData); err != nil {
log.Printf("failed to decode: %s", err)
handleError(w, gErrors.ErrBadRequest)
return
}
pool, err := a.r.CreateEnterprisePool(ctx, enterpriseID, poolData)
if err != nil {
log.Printf("error creating enterprise pool: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(pool)
}
func (a *APIController) ListEnterprisePoolsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
pools, err := a.r.ListEnterprisePools(ctx, enterpriseID)
if err != nil {
log.Printf("listing pools: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(pools)
}
func (a *APIController) GetEnterprisePoolHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, enterpriseOk := vars["enterpriseID"]
poolID, poolOk := vars["poolID"]
if !enterpriseOk || !poolOk {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise or pool ID specified",
})
return
}
pool, err := a.r.GetEnterprisePoolByID(ctx, enterpriseID, poolID)
if err != nil {
log.Printf("listing pools: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(pool)
}
func (a *APIController) DeleteEnterprisePoolHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, enterpriseOk := vars["enterpriseID"]
poolID, poolOk := vars["poolID"]
if !enterpriseOk || !poolOk {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise or pool ID specified",
})
return
}
if err := a.r.DeleteEnterprisePool(ctx, enterpriseID, poolID); err != nil {
log.Printf("removing pool: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
}
func (a *APIController) UpdateEnterprisePoolHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, enterpriseOk := vars["enterpriseID"]
poolID, poolOk := vars["poolID"]
if !enterpriseOk || !poolOk {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise or pool ID specified",
})
return
}
var poolData runnerParams.UpdatePoolParams
if err := json.NewDecoder(r.Body).Decode(&poolData); err != nil {
log.Printf("failed to decode: %s", err)
handleError(w, gErrors.ErrBadRequest)
return
}
pool, err := a.r.UpdateEnterprisePool(ctx, enterpriseID, poolID, poolData)
if err != nil {
log.Printf("error creating enterprise pool: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(pool)
}

View file

@ -136,7 +136,31 @@ func (a *APIController) ListOrgInstancesHandler(w http.ResponseWriter, r *http.R
instances, err := a.r.ListOrgInstances(ctx, orgID)
if err != nil {
log.Printf("listing pools: %s", err)
log.Printf("listing instances: %s", err)
handleError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(instances)
}
func (a *APIController) ListEnterpriseInstancesHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
enterpriseID, ok := vars["enterpriseID"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(params.APIErrorResponse{
Error: "Bad Request",
Details: "No enterprise ID specified",
})
return
}
instances, err := a.r.ListOrgInstances(ctx, enterpriseID)
if err != nil {
log.Printf("listing instances: %s", err)
handleError(w, err)
return
}

View file

@ -165,6 +165,45 @@ func NewAPIRouter(han *controllers.APIController, logWriter io.Writer, authMiddl
apiRouter.Handle("/organizations/", log(logWriter, http.HandlerFunc(han.CreateOrgHandler))).Methods("POST", "OPTIONS")
apiRouter.Handle("/organizations", log(logWriter, http.HandlerFunc(han.CreateOrgHandler))).Methods("POST", "OPTIONS")
/////////////////////////////
// Enterprises and pools //
/////////////////////////////
// Get pool
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(os.Stdout, http.HandlerFunc(han.GetEnterprisePoolHandler))).Methods("GET", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(os.Stdout, http.HandlerFunc(han.GetEnterprisePoolHandler))).Methods("GET", "OPTIONS")
// Delete pool
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(os.Stdout, http.HandlerFunc(han.DeleteEnterprisePoolHandler))).Methods("DELETE", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(os.Stdout, http.HandlerFunc(han.DeleteEnterprisePoolHandler))).Methods("DELETE", "OPTIONS")
// Update pool
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}/", log(os.Stdout, http.HandlerFunc(han.UpdateEnterprisePoolHandler))).Methods("PUT", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/pools/{poolID}", log(os.Stdout, http.HandlerFunc(han.UpdateEnterprisePoolHandler))).Methods("PUT", "OPTIONS")
// List pools
apiRouter.Handle("/enterprises/{enterpriseID}/pools/", log(os.Stdout, http.HandlerFunc(han.ListEnterprisePoolsHandler))).Methods("GET", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/pools", log(os.Stdout, http.HandlerFunc(han.ListEnterprisePoolsHandler))).Methods("GET", "OPTIONS")
// Create pool
apiRouter.Handle("/enterprises/{enterpriseID}/pools/", log(os.Stdout, http.HandlerFunc(han.CreateEnterprisePoolHandler))).Methods("POST", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/pools", log(os.Stdout, http.HandlerFunc(han.CreateEnterprisePoolHandler))).Methods("POST", "OPTIONS")
// Repo instances list
apiRouter.Handle("/enterprises/{enterpriseID}/instances/", log(os.Stdout, http.HandlerFunc(han.ListEnterpriseInstancesHandler))).Methods("GET", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}/instances", log(os.Stdout, http.HandlerFunc(han.ListEnterpriseInstancesHandler))).Methods("GET", "OPTIONS")
// Get org
apiRouter.Handle("/enterprises/{enterpriseID}/", log(os.Stdout, http.HandlerFunc(han.GetEnterpriseByIDHandler))).Methods("GET", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}", log(os.Stdout, http.HandlerFunc(han.GetEnterpriseByIDHandler))).Methods("GET", "OPTIONS")
// Update org
apiRouter.Handle("/enterprises/{enterpriseID}/", log(os.Stdout, http.HandlerFunc(han.UpdateEnterpriseHandler))).Methods("PUT", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}", log(os.Stdout, http.HandlerFunc(han.UpdateEnterpriseHandler))).Methods("PUT", "OPTIONS")
// Delete org
apiRouter.Handle("/enterprises/{enterpriseID}/", log(os.Stdout, http.HandlerFunc(han.DeleteEnterpriseHandler))).Methods("DELETE", "OPTIONS")
apiRouter.Handle("/enterprises/{enterpriseID}", log(os.Stdout, http.HandlerFunc(han.DeleteEnterpriseHandler))).Methods("DELETE", "OPTIONS")
// List orgs
apiRouter.Handle("/enterprises/", log(os.Stdout, http.HandlerFunc(han.ListEnterprisesHandler))).Methods("GET", "OPTIONS")
apiRouter.Handle("/enterprises", log(os.Stdout, http.HandlerFunc(han.ListEnterprisesHandler))).Methods("GET", "OPTIONS")
// Create org
apiRouter.Handle("/enterprises/", log(os.Stdout, http.HandlerFunc(han.CreateEnterpriseHandler))).Methods("POST", "OPTIONS")
apiRouter.Handle("/enterprises", log(os.Stdout, http.HandlerFunc(han.CreateEnterpriseHandler))).Methods("POST", "OPTIONS")
// Credentials and providers
apiRouter.Handle("/credentials/", log(logWriter, http.HandlerFunc(han.ListCredentials))).Methods("GET", "OPTIONS")
apiRouter.Handle("/credentials", log(logWriter, http.HandlerFunc(han.ListCredentials))).Methods("GET", "OPTIONS")

View file

@ -65,12 +65,12 @@ const (
// of time and no new updates have been made to it's state, it will be removed.
DefaultRunnerBootstrapTimeout = 20
// DefaultGithubURL is the default URL where Github or Github Enterprise can be accessed
GithubBaseURL = "https://github.com"
// DefaultGithubURL is the default URL where Github or Github Enterprise can be accessed.
DefaultGithubURL = "https://github.com"
// defaultBaseURL is the default URL for the github API
// defaultBaseURL is the default URL for the github API.
defaultBaseURL = "https://api.github.com/"
// uploadBaseURL is the default URL for guthub uploads
// uploadBaseURL is the default URL for guthub uploads.
uploadBaseURL = "https://uploads.github.com/"
)
@ -250,7 +250,7 @@ func (g *Github) BaseEndpoint() string {
if g.BaseURL != "" {
return g.BaseURL
}
return GithubBaseURL
return DefaultGithubURL
}
func (g *Github) Validate() error {

View file

@ -19,7 +19,7 @@ import (
"garm/params"
)
type Store interface {
type RepoStore interface {
CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string) (params.Repository, error)
GetRepository(ctx context.Context, owner, name string) (params.Repository, error)
GetRepositoryByID(ctx context.Context, repoID string) (params.Repository, error)
@ -27,6 +27,18 @@ type Store interface {
DeleteRepository(ctx context.Context, repoID string) error
UpdateRepository(ctx context.Context, repoID string, param params.UpdateRepositoryParams) (params.Repository, error)
CreateRepositoryPool(ctx context.Context, repoId string, param params.CreatePoolParams) (params.Pool, error)
GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error)
DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error
UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error)
ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error)
ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error)
}
type OrgStore interface {
CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string) (params.Organization, error)
GetOrganization(ctx context.Context, name string) (params.Organization, error)
GetOrganizationByID(ctx context.Context, orgID string) (params.Organization, error)
@ -34,53 +46,77 @@ type Store interface {
DeleteOrganization(ctx context.Context, orgID string) error
UpdateOrganization(ctx context.Context, orgID string, param params.UpdateRepositoryParams) (params.Organization, error)
CreateRepositoryPool(ctx context.Context, repoId string, param params.CreatePoolParams) (params.Pool, error)
CreateOrganizationPool(ctx context.Context, orgId string, param params.CreatePoolParams) (params.Pool, error)
GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error)
GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error)
DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error
UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error)
FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error)
ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error)
ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error)
}
type EnterpriseStore interface {
CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string) (params.Enterprise, error)
GetEnterprise(ctx context.Context, name string) (params.Enterprise, error)
GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error)
ListEnterprises(ctx context.Context) ([]params.Enterprise, error)
DeleteEnterprise(ctx context.Context, enterpriseID string) error
UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateRepositoryParams) (params.Enterprise, error)
CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error)
GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error)
DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error
UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error)
ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error)
ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error)
}
type PoolStore interface {
// Probably a bad idea without some king of filter or at least pagination
// TODO: add filter/pagination
ListAllPools(ctx context.Context) ([]params.Pool, error)
GetPoolByID(ctx context.Context, poolID string) (params.Pool, error)
DeletePoolByID(ctx context.Context, poolID string) error
DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error
DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error
UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error)
FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error)
FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error)
CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error)
DeleteInstance(ctx context.Context, poolID string, instanceName string) error
UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error)
ListPoolInstances(ctx context.Context, poolID string) ([]params.Instance, error)
ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error)
ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error)
PoolInstanceCount(ctx context.Context, poolID string) (int64, error)
// Probably a bad idea without some king of filter or at least pagination
// TODO: add filter/pagination
ListAllInstances(ctx context.Context) ([]params.Instance, error)
GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error)
GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error)
AddInstanceStatusMessage(ctx context.Context, instanceID string, statusMessage string) error
}
type UserStore interface {
GetUser(ctx context.Context, user string) (params.User, error)
GetUserByID(ctx context.Context, userID string) (params.User, error)
CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error)
UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error)
HasAdminUser(ctx context.Context) bool
}
type InstanceStore interface {
CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error)
DeleteInstance(ctx context.Context, poolID string, instanceName string) error
UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error)
// Probably a bad idea without some king of filter or at least pagination
// TODO: add filter/pagination
ListAllInstances(ctx context.Context) ([]params.Instance, error)
GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error)
AddInstanceStatusMessage(ctx context.Context, instanceID string, statusMessage string) error
}
//go:generate mockery --name=Store
type Store interface {
RepoStore
OrgStore
EnterpriseStore
PoolStore
UserStore
InstanceStore
ControllerInfo() (params.ControllerInfo, error)
InitController() (params.ControllerInfo, error)

View file

@ -49,6 +49,48 @@ func (_m *Store) ControllerInfo() (params.ControllerInfo, error) {
return r0, r1
}
// CreateEnterprise provides a mock function with given fields: ctx, name, credentialsName, webhookSecret
func (_m *Store) CreateEnterprise(ctx context.Context, name string, credentialsName string, webhookSecret string) (params.Enterprise, error) {
ret := _m.Called(ctx, name, credentialsName, webhookSecret)
var r0 params.Enterprise
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) params.Enterprise); ok {
r0 = rf(ctx, name, credentialsName, webhookSecret)
} else {
r0 = ret.Get(0).(params.Enterprise)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, name, credentialsName, webhookSecret)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, param
func (_m *Store) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) {
ret := _m.Called(ctx, enterpriseID, param)
var r0 params.Pool
if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok {
r0 = rf(ctx, enterpriseID, param)
} else {
r0 = ret.Get(0).(params.Pool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok {
r1 = rf(ctx, enterpriseID, param)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateInstance provides a mock function with given fields: ctx, poolID, param
func (_m *Store) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) {
ret := _m.Called(ctx, poolID, param)
@ -175,6 +217,34 @@ func (_m *Store) CreateUser(ctx context.Context, user params.NewUserParams) (par
return r0, r1
}
// DeleteEnterprise provides a mock function with given fields: ctx, enterpriseID
func (_m *Store) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
ret := _m.Called(ctx, enterpriseID)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, enterpriseID)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID
func (_m *Store) DeleteEnterprisePool(ctx context.Context, enterpriseID string, poolID string) error {
ret := _m.Called(ctx, enterpriseID, poolID)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, enterpriseID, poolID)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteInstance provides a mock function with given fields: ctx, poolID, instanceName
func (_m *Store) DeleteInstance(ctx context.Context, poolID string, instanceName string) error {
ret := _m.Called(ctx, poolID, instanceName)
@ -259,6 +329,27 @@ func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID
return r0
}
// FindEnterprisePoolByTags provides a mock function with given fields: ctx, enterpriseID, tags
func (_m *Store) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) {
ret := _m.Called(ctx, enterpriseID, tags)
var r0 params.Pool
if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok {
r0 = rf(ctx, enterpriseID, tags)
} else {
r0 = ret.Get(0).(params.Pool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
r1 = rf(ctx, enterpriseID, tags)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindOrganizationPoolByTags provides a mock function with given fields: ctx, orgID, tags
func (_m *Store) FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) {
ret := _m.Called(ctx, orgID, tags)
@ -301,6 +392,69 @@ func (_m *Store) FindRepositoryPoolByTags(ctx context.Context, repoID string, ta
return r0, r1
}
// GetEnterprise provides a mock function with given fields: ctx, name
func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) {
ret := _m.Called(ctx, name)
var r0 params.Enterprise
if rf, ok := ret.Get(0).(func(context.Context, string) params.Enterprise); ok {
r0 = rf(ctx, name)
} else {
r0 = ret.Get(0).(params.Enterprise)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, name)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetEnterpriseByID provides a mock function with given fields: ctx, enterpriseID
func (_m *Store) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) {
ret := _m.Called(ctx, enterpriseID)
var r0 params.Enterprise
if rf, ok := ret.Get(0).(func(context.Context, string) params.Enterprise); ok {
r0 = rf(ctx, enterpriseID)
} else {
r0 = ret.Get(0).(params.Enterprise)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, enterpriseID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID
func (_m *Store) GetEnterprisePool(ctx context.Context, enterpriseID string, poolID string) (params.Pool, error) {
ret := _m.Called(ctx, enterpriseID, poolID)
var r0 params.Pool
if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok {
r0 = rf(ctx, enterpriseID, poolID)
} else {
r0 = ret.Get(0).(params.Pool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, enterpriseID, poolID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetInstanceByName provides a mock function with given fields: ctx, instanceName
func (_m *Store) GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error) {
ret := _m.Called(ctx, instanceName)
@ -613,6 +767,75 @@ func (_m *Store) ListAllPools(ctx context.Context) ([]params.Pool, error) {
return r0, r1
}
// ListEnterpriseInstances provides a mock function with given fields: ctx, enterpriseID
func (_m *Store) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) {
ret := _m.Called(ctx, enterpriseID)
var r0 []params.Instance
if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok {
r0 = rf(ctx, enterpriseID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]params.Instance)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, enterpriseID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListEnterprisePools provides a mock function with given fields: ctx, enterpriseID
func (_m *Store) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) {
ret := _m.Called(ctx, enterpriseID)
var r0 []params.Pool
if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok {
r0 = rf(ctx, enterpriseID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]params.Pool)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, enterpriseID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListEnterprises provides a mock function with given fields: ctx
func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) {
ret := _m.Called(ctx)
var r0 []params.Enterprise
if rf, ok := ret.Get(0).(func(context.Context) []params.Enterprise); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]params.Enterprise)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListOrgInstances provides a mock function with given fields: ctx, orgID
func (_m *Store) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) {
ret := _m.Called(ctx, orgID)
@ -795,6 +1018,48 @@ func (_m *Store) PoolInstanceCount(ctx context.Context, poolID string) (int64, e
return r0, r1
}
// UpdateEnterprise provides a mock function with given fields: ctx, enterpriseID, param
func (_m *Store) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateRepositoryParams) (params.Enterprise, error) {
ret := _m.Called(ctx, enterpriseID, param)
var r0 params.Enterprise
if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateRepositoryParams) params.Enterprise); ok {
r0 = rf(ctx, enterpriseID, param)
} else {
r0 = ret.Get(0).(params.Enterprise)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, params.UpdateRepositoryParams) error); ok {
r1 = rf(ctx, enterpriseID, param)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID, param
func (_m *Store) UpdateEnterprisePool(ctx context.Context, enterpriseID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
ret := _m.Called(ctx, enterpriseID, poolID, param)
var r0 params.Pool
if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok {
r0 = rf(ctx, enterpriseID, poolID, param)
} else {
r0 = ret.Get(0).(params.Pool)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok {
r1 = rf(ctx, enterpriseID, poolID, param)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateInstance provides a mock function with given fields: ctx, instanceID, param
func (_m *Store) UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error) {
ret := _m.Called(ctx, instanceID, param)

367
database/sql/enterprise.go Normal file
View file

@ -0,0 +1,367 @@
package sql
import (
"context"
"fmt"
runnerErrors "garm/errors"
"garm/params"
"garm/util"
"github.com/pkg/errors"
uuid "github.com/satori/go.uuid"
"gorm.io/gorm"
)
func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string) (params.Enterprise, error) {
secret := []byte{}
var err error
if webhookSecret != "" {
secret, err = util.Aes256EncodeString(webhookSecret, s.cfg.Passphrase)
if err != nil {
return params.Enterprise{}, fmt.Errorf("failed to encrypt string")
}
}
newEnterprise := Enterprise{
Name: name,
WebhookSecret: secret,
CredentialsName: credentialsName,
}
q := s.conn.Create(&newEnterprise)
if q.Error != nil {
return params.Enterprise{}, errors.Wrap(q.Error, "creating enterprise")
}
param := s.sqlToCommonEnterprise(newEnterprise)
param.WebhookSecret = webhookSecret
return param, nil
}
func (s *sqlDatabase) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) {
enterprise, err := s.getEnterprise(ctx, name)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
param := s.sqlToCommonEnterprise(enterprise)
secret, err := util.Aes256DecodeString(enterprise.WebhookSecret, s.cfg.Passphrase)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "decrypting secret")
}
param.WebhookSecret = secret
return param, nil
}
func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID, "Pools")
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
param := s.sqlToCommonEnterprise(enterprise)
secret, err := util.Aes256DecodeString(enterprise.WebhookSecret, s.cfg.Passphrase)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "decrypting secret")
}
param.WebhookSecret = secret
return param, nil
}
func (s *sqlDatabase) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) {
var enterprises []Enterprise
q := s.conn.Find(&enterprises)
if q.Error != nil {
return []params.Enterprise{}, errors.Wrap(q.Error, "fetching enterprise from database")
}
ret := make([]params.Enterprise, len(enterprises))
for idx, val := range enterprises {
ret[idx] = s.sqlToCommonEnterprise(val)
}
return ret, nil
}
func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}
q := s.conn.Unscoped().Delete(&enterprise)
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
return errors.Wrap(q.Error, "deleting enterprise")
}
return nil
}
func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateRepositoryParams) (params.Enterprise, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
if param.CredentialsName != "" {
enterprise.CredentialsName = param.CredentialsName
}
if param.WebhookSecret != "" {
secret, err := util.Aes256EncodeString(param.WebhookSecret, s.cfg.Passphrase)
if err != nil {
return params.Enterprise{}, fmt.Errorf("failed to encrypt string")
}
enterprise.WebhookSecret = secret
}
q := s.conn.Save(&enterprise)
if q.Error != nil {
return params.Enterprise{}, errors.Wrap(q.Error, "saving enterprise")
}
newParams := s.sqlToCommonEnterprise(enterprise)
secret, err := util.Aes256DecodeString(enterprise.WebhookSecret, s.cfg.Passphrase)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "decrypting secret")
}
newParams.WebhookSecret = secret
return newParams, nil
}
func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) {
if len(param.Tags) == 0 {
return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified")
}
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching enterprise")
}
newPool := Pool{
ProviderName: param.ProviderName,
MaxRunners: param.MaxRunners,
MinIdleRunners: param.MinIdleRunners,
Image: param.Image,
Flavor: param.Flavor,
OSType: param.OSType,
OSArch: param.OSArch,
EnterpriseID: enterprise.ID,
Enabled: param.Enabled,
RunnerBootstrapTimeout: param.RunnerBootstrapTimeout,
}
_, err = s.getEnterprisePoolByUniqueFields(ctx, enterpriseID, newPool.ProviderName, newPool.Image, newPool.Flavor)
if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return params.Pool{}, errors.Wrap(err, "creating pool")
}
} else {
return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider")
}
tags := []Tag{}
for _, val := range param.Tags {
t, err := s.getOrCreateTag(val)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching tag")
}
tags = append(tags, t)
}
q := s.conn.Create(&newPool)
if q.Error != nil {
return params.Pool{}, errors.Wrap(q.Error, "adding pool")
}
for _, tt := range tags {
if err := s.conn.Model(&newPool).Association("Tags").Append(&tt); err != nil {
return params.Pool{}, errors.Wrap(err, "saving tag")
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.sqlToCommonPool(pool), nil
}
func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID, "Tags", "Instances")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.sqlToCommonPool(pool), nil
}
func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID)
if err != nil {
return errors.Wrap(err, "looking up enterprise pool")
}
q := s.conn.Unscoped().Delete(&pool)
if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) {
return errors.Wrap(q.Error, "deleting pool")
}
return nil
}
func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.updatePool(pool, param)
}
func (s *sqlDatabase) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) {
pool, err := s.findPoolByTags(enterpriseID, "enterprise_id", tags)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool, nil
}
func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) {
pools, err := s.getEnterprisePools(ctx, enterpriseID, "Tags")
if err != nil {
return nil, errors.Wrap(err, "fetching pools")
}
ret := make([]params.Pool, len(pools))
for idx, pool := range pools {
ret[idx] = s.sqlToCommonPool(pool)
}
return ret, nil
}
func (s *sqlDatabase) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) {
pools, err := s.getEnterprisePools(ctx, enterpriseID, "Instances")
if err != nil {
return nil, errors.Wrap(err, "fetching enterprise")
}
ret := []params.Instance{}
for _, pool := range pools {
for _, instance := range pool.Instances {
ret = append(ret, s.sqlToParamsInstance(instance))
}
}
return ret, nil
}
func (s *sqlDatabase) getEnterprise(ctx context.Context, name string) (Enterprise, error) {
var enterprise Enterprise
q := s.conn.Where("name = ? COLLATE NOCASE", name)
q = q.First(&enterprise)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Enterprise{}, runnerErrors.ErrNotFound
}
return Enterprise{}, errors.Wrap(q.Error, "fetching enterprise from database")
}
return enterprise, nil
}
func (s *sqlDatabase) getEnterpriseByID(ctx context.Context, id string, preload ...string) (Enterprise, error) {
u, err := uuid.FromString(id)
if err != nil {
return Enterprise{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var enterprise Enterprise
q := s.conn
if len(preload) > 0 {
for _, field := range preload {
q = q.Preload(field)
}
}
q = q.Where("id = ?", u).First(&enterprise)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Enterprise{}, runnerErrors.ErrNotFound
}
return Enterprise{}, errors.Wrap(q.Error, "fetching enterprise from database")
}
return enterprise, nil
}
func (s *sqlDatabase) getEnterprisePoolByUniqueFields(ctx context.Context, enterpriseID string, provider, image, flavor string) (Pool, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching enterprise")
}
q := s.conn
var pool []Pool
err = q.Model(&enterprise).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
}
return pool[0], nil
}
func (s *sqlDatabase) getEnterprisePool(ctx context.Context, enterpriseID, poolID string, preload ...string) (Pool, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching enterprise")
}
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
q := s.conn
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
var pool []Pool
err = q.Model(&enterprise).Association("Pools").Find(&pool, "id = ?", u)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
}
return pool[0], nil
}
func (s *sqlDatabase) getEnterprisePools(ctx context.Context, enterpriseID string, preload ...string) ([]Pool, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return nil, errors.Wrap(err, "fetching enterprise")
}
var pools []Pool
q := s.conn.Model(&enterprise)
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
err = q.Association("Pools").Find(&pools)
if err != nil {
return nil, errors.Wrap(err, "fetching pool")
}
return pools, nil
}

View file

@ -71,6 +71,9 @@ type Pool struct {
OrgID uuid.UUID `gorm:"index"`
Organization Organization `gorm:"foreignKey:OrgID"`
EnterpriseID uuid.UUID `gorm:"index"`
Enterprise Enterprise `gorm:"foreignKey:EnterpriseID"`
Instances []Instance `gorm:"foreignKey:PoolID"`
}
@ -93,6 +96,15 @@ type Organization struct {
Pools []Pool `gorm:"foreignKey:OrgID"`
}
type Enterprise struct {
Base
CredentialsName string
Name string `gorm:"index:idx_ent_name_nocase,collate:nocase"`
WebhookSecret []byte
Pools []Pool `gorm:"foreignKey:EnterpriseID"`
}
type Address struct {
Base

View file

@ -72,7 +72,7 @@ func (s *sqlDatabase) ListOrganizations(ctx context.Context) ([]params.Organizat
var orgs []Organization
q := s.conn.Find(&orgs)
if q.Error != nil {
return []params.Organization{}, errors.Wrap(q.Error, "fetching user from database")
return []params.Organization{}, errors.Wrap(q.Error, "fetching org from database")
}
ret := make([]params.Organization, len(orgs))
@ -197,7 +197,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string,
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Organization", "Repository")
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -262,7 +262,7 @@ func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]par
}
func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getOrgPool(ctx, orgID, poolID, "Tags", "Instances", "Organization", "Repository")
pool, err := s.getOrgPool(ctx, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}

View file

@ -42,7 +42,7 @@ func (s *sqlDatabase) ListAllPools(ctx context.Context) ([]params.Pool, error) {
}
func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) {
pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Organization", "Repository")
pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool by ID")
}

View file

@ -205,7 +205,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoId string, p
}
}
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Organization", "Repository")
pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -271,7 +271,7 @@ func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]p
}
func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getRepoPool(ctx, repoID, poolID, "Tags", "Instances", "Organization", "Repository")
pool, err := s.getRepoPool(ctx, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}

View file

@ -96,13 +96,14 @@ func (s *sqlDatabase) migrateDB() error {
&Pool{},
&Repository{},
&Organization{},
&Enterprise{},
&Address{},
&InstanceStatusUpdate{},
&Instance{},
&ControllerInfo{},
&User{},
); err != nil {
return err
return errors.Wrap(err, "running auto migrate")
}
return nil

View file

@ -85,6 +85,21 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization) params.Organizat
return ret
}
func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise) params.Enterprise {
ret := params.Enterprise{
ID: enterprise.ID.String(),
Name: enterprise.Name,
CredentialsName: enterprise.CredentialsName,
Pools: make([]params.Pool, len(enterprise.Pools)),
}
for idx, pool := range enterprise.Pools {
ret.Pools[idx] = s.sqlToCommonPool(pool)
}
return ret
}
func (s *sqlDatabase) sqlToCommonPool(pool Pool) params.Pool {
ret := params.Pool{
ID: pool.ID.String(),

4
go.mod
View file

@ -6,7 +6,7 @@ require (
github.com/BurntSushi/toml v0.4.1
github.com/go-resty/resty/v2 v2.7.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/go-github/v43 v43.0.0
github.com/google/go-github/v47 v47.1.0
github.com/google/uuid v1.3.0
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
@ -71,3 +71,5 @@ require (
gopkg.in/macaroon-bakery.v2 v2.3.0 // indirect
gopkg.in/macaroon.v2 v2.1.0 // indirect
)
replace github.com/google/go-github/v47 => github.com/gabriel-samfira/go-github/v47 v47.1.1-0.20221013145953-21e3b4d7b0c1

7
go.sum
View file

@ -64,6 +64,8 @@ github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03D
github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=
github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY=
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
github.com/gabriel-samfira/go-github/v47 v47.1.1-0.20221013145953-21e3b4d7b0c1 h1:CNZ1asZM2ABO6DLFPS86CkGMEp5nFSQnpAECOOhYBGo=
github.com/gabriel-samfira/go-github/v47 v47.1.1-0.20221013145953-21e3b4d7b0c1/go.mod h1:VPZBXNbFSJGjyjFRUKo9vZGawTajnWzC/YjGw/oFKi0=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
@ -116,9 +118,7 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-github/v43 v43.0.0 h1:y+GL7LIsAIF2NZlJ46ZoC/D1W1ivZasT0lnWHMYPZ+U=
github.com/google/go-github/v43 v43.0.0/go.mod h1:ZkTvvmCXBvsfPpTHXnH/d2hP9Y0cTbvN9kr5xqyXOIc=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
@ -438,7 +438,6 @@ golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=

View file

@ -19,7 +19,7 @@ import (
"garm/runner/providers/common"
"time"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
uuid "github.com/satori/go.uuid"
)
@ -167,6 +167,15 @@ type Organization struct {
WebhookSecret string `json:"-"`
}
type Enterprise struct {
ID string `json:"id"`
Name string `json:"name"`
Pools []Pool `json:"pool,omitempty"`
CredentialsName string `json:"credentials_name"`
// Do not serialize sensitive info.
WebhookSecret string `json:"-"`
}
// Users holds information about a particular user
type User struct {
ID string `json:"id"`

View file

@ -66,6 +66,23 @@ func (c *CreateOrgParams) Validate() error {
return nil
}
type CreateEnterpriseParams struct {
Name string `json:"name"`
CredentialsName string `json:"credentials_name"`
WebhookSecret string `json:"webhook_secret"`
}
func (c *CreateEnterpriseParams) Validate() error {
if c.Name == "" {
return errors.NewBadRequestError("missing org name")
}
if c.CredentialsName == "" {
return errors.NewBadRequestError("missing credentials name")
}
return nil
}
// NewUserParams holds the needed information to create
// a new user
type NewUserParams struct {

View file

@ -5,7 +5,7 @@ package mocks
import (
context "context"
github "github.com/google/go-github/v43/github"
github "github.com/google/go-github/v47/github"
mock "github.com/stretchr/testify/mock"
)
@ -78,6 +78,38 @@ func (_m *GithubClient) CreateRegistrationToken(ctx context.Context, owner strin
return r0, r1, r2
}
// GetWorkflowJobByID provides a mock function with given fields: ctx, owner, repo, jobID
func (_m *GithubClient) GetWorkflowJobByID(ctx context.Context, owner string, repo string, jobID int64) (*github.WorkflowJob, *github.Response, error) {
ret := _m.Called(ctx, owner, repo, jobID)
var r0 *github.WorkflowJob
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *github.WorkflowJob); ok {
r0 = rf(ctx, owner, repo, jobID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*github.WorkflowJob)
}
}
var r1 *github.Response
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) *github.Response); ok {
r1 = rf(ctx, owner, repo, jobID)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*github.Response)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, string, string, int64) error); ok {
r2 = rf(ctx, owner, repo, jobID)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// ListOrganizationRunnerApplicationDownloads provides a mock function with given fields: ctx, owner
func (_m *GithubClient) ListOrganizationRunnerApplicationDownloads(ctx context.Context, owner string) ([]*github.RunnerApplicationDownload, *github.Response, error) {
ret := _m.Called(ctx, owner)

View file

@ -0,0 +1,149 @@
// Code generated by mockery v2.14.0. DO NOT EDIT.
package mocks
import (
context "context"
github "github.com/google/go-github/v47/github"
mock "github.com/stretchr/testify/mock"
)
// GithubEnterpriseClient is an autogenerated mock type for the GithubEnterpriseClient type
type GithubEnterpriseClient struct {
mock.Mock
}
// CreateRegistrationToken provides a mock function with given fields: ctx, enterprise
func (_m *GithubEnterpriseClient) CreateRegistrationToken(ctx context.Context, enterprise string) (*github.RegistrationToken, *github.Response, error) {
ret := _m.Called(ctx, enterprise)
var r0 *github.RegistrationToken
if rf, ok := ret.Get(0).(func(context.Context, string) *github.RegistrationToken); ok {
r0 = rf(ctx, enterprise)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*github.RegistrationToken)
}
}
var r1 *github.Response
if rf, ok := ret.Get(1).(func(context.Context, string) *github.Response); ok {
r1 = rf(ctx, enterprise)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*github.Response)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
r2 = rf(ctx, enterprise)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// ListRunnerApplicationDownloads provides a mock function with given fields: ctx, enterprise
func (_m *GithubEnterpriseClient) ListRunnerApplicationDownloads(ctx context.Context, enterprise string) ([]*github.RunnerApplicationDownload, *github.Response, error) {
ret := _m.Called(ctx, enterprise)
var r0 []*github.RunnerApplicationDownload
if rf, ok := ret.Get(0).(func(context.Context, string) []*github.RunnerApplicationDownload); ok {
r0 = rf(ctx, enterprise)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*github.RunnerApplicationDownload)
}
}
var r1 *github.Response
if rf, ok := ret.Get(1).(func(context.Context, string) *github.Response); ok {
r1 = rf(ctx, enterprise)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*github.Response)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
r2 = rf(ctx, enterprise)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// ListRunners provides a mock function with given fields: ctx, enterprise, opts
func (_m *GithubEnterpriseClient) ListRunners(ctx context.Context, enterprise string, opts *github.ListOptions) (*github.Runners, *github.Response, error) {
ret := _m.Called(ctx, enterprise, opts)
var r0 *github.Runners
if rf, ok := ret.Get(0).(func(context.Context, string, *github.ListOptions) *github.Runners); ok {
r0 = rf(ctx, enterprise, opts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*github.Runners)
}
}
var r1 *github.Response
if rf, ok := ret.Get(1).(func(context.Context, string, *github.ListOptions) *github.Response); ok {
r1 = rf(ctx, enterprise, opts)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*github.Response)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, string, *github.ListOptions) error); ok {
r2 = rf(ctx, enterprise, opts)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// RemoveRunner provides a mock function with given fields: ctx, enterprise, runnerID
func (_m *GithubEnterpriseClient) RemoveRunner(ctx context.Context, enterprise string, runnerID int64) (*github.Response, error) {
ret := _m.Called(ctx, enterprise, runnerID)
var r0 *github.Response
if rf, ok := ret.Get(0).(func(context.Context, string, int64) *github.Response); ok {
r0 = rf(ctx, enterprise, runnerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*github.Response)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, int64) error); ok {
r1 = rf(ctx, enterprise, runnerID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
type mockConstructorTestingTNewGithubEnterpriseClient interface {
mock.TestingT
Cleanup(func())
}
// NewGithubEnterpriseClient creates a new instance of GithubEnterpriseClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewGithubEnterpriseClient(t mockConstructorTestingTNewGithubEnterpriseClient) *GithubEnterpriseClient {
mock := &GithubEnterpriseClient{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View file

@ -32,6 +32,7 @@ const (
PoolToolUpdateInterval = 50 * time.Minute
)
//go:generate mockery --all
type PoolManager interface {
ID() string
WebhookSecret() string

View file

@ -19,6 +19,7 @@ import (
"garm/params"
)
//go:generate mockery --all
type Provider interface {
// CreateInstance creates a new compute instance in the provider.
CreateInstance(ctx context.Context, bootstrapParams params.BootstrapInstance) (params.Instance, error)

View file

@ -3,11 +3,13 @@ package common
import (
"context"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
)
// GithubClient that describes the minimum list of functions we need to interact with github.
// Allows for easier testing.
//
//go:generate mockery --all
type GithubClient interface {
// GetWorkflowJobByID gets details about a single workflow job.
GetWorkflowJobByID(ctx context.Context, owner, repo string, jobID int64) (*github.WorkflowJob, *github.Response, error)
@ -31,3 +33,15 @@ type GithubClient interface {
// CreateOrganizationRegistrationToken creates a runner registration token for an organization.
CreateOrganizationRegistrationToken(ctx context.Context, owner string) (*github.RegistrationToken, *github.Response, error)
}
type GithubEnterpriseClient interface {
// ListRunners lists all runners within a repository.
ListRunners(ctx context.Context, enterprise string, opts *github.ListOptions) (*github.Runners, *github.Response, error)
// RemoveRunner removes one runner from an enterprise.
RemoveRunner(ctx context.Context, enterprise string, runnerID int64) (*github.Response, error)
// CreateRegistrationToken creates a runner registration token for an enterprise.
CreateRegistrationToken(ctx context.Context, enterprise string) (*github.RegistrationToken, *github.Response, error)
// ListRunnerApplicationDownloads returns a list of github runner application downloads for the
// various supported operating systems and architectures.
ListRunnerApplicationDownloads(ctx context.Context, enterprise string) ([]*github.RunnerApplicationDownload, *github.Response, error)
}

313
runner/enterprises.go Normal file
View file

@ -0,0 +1,313 @@
package runner
import (
"context"
"garm/auth"
"garm/config"
runnerErrors "garm/errors"
"garm/params"
"garm/runner/common"
"log"
"strings"
"github.com/pkg/errors"
)
func (r *Runner) CreateEnterprise(ctx context.Context, param params.CreateEnterpriseParams) (enterprise params.Enterprise, err error) {
if !auth.IsAdmin(ctx) {
return enterprise, runnerErrors.ErrUnauthorized
}
err = param.Validate()
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "validating params")
}
creds, ok := r.credentials[param.CredentialsName]
if !ok {
return params.Enterprise{}, runnerErrors.NewBadRequestError("credentials %s not defined", param.CredentialsName)
}
_, err = r.store.GetEnterprise(ctx, param.Name)
if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
} else {
return params.Enterprise{}, runnerErrors.NewConflictError("enterprise %s already exists", param.Name)
}
enterprise, err = r.store.CreateEnterprise(ctx, param.Name, creds.Name, param.WebhookSecret)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "creating enterprise")
}
defer func() {
if err != nil {
r.store.DeleteEnterprise(ctx, enterprise.ID)
}
}()
var poolMgr common.PoolManager
poolMgr, err = r.poolManagerCtrl.CreateEnterprisePoolManager(r.ctx, enterprise, r.providers, r.store)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "creating enterprise pool manager")
}
if err := poolMgr.Start(); err != nil {
if deleteErr := r.poolManagerCtrl.DeleteEnterprisePoolManager(enterprise); deleteErr != nil {
log.Printf("failed to cleanup pool manager for enterprise %s", enterprise.ID)
}
return params.Enterprise{}, errors.Wrap(err, "starting enterprise pool manager")
}
return enterprise, nil
}
func (r *Runner) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) {
if !auth.IsAdmin(ctx) {
return nil, runnerErrors.ErrUnauthorized
}
enterprises, err := r.store.ListEnterprises(ctx)
if err != nil {
return nil, errors.Wrap(err, "listing enterprises")
}
return enterprises, nil
}
func (r *Runner) GetEnterpriseByID(ctx context.Context, enterpriseID string) (params.Enterprise, error) {
if !auth.IsAdmin(ctx) {
return params.Enterprise{}, runnerErrors.ErrUnauthorized
}
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
return enterprise, nil
}
func (r *Runner) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
if !auth.IsAdmin(ctx) {
return runnerErrors.ErrUnauthorized
}
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}
pools, err := r.store.ListEnterprisePools(ctx, enterpriseID)
if err != nil {
return errors.Wrap(err, "fetching enterprise pools")
}
if len(pools) > 0 {
poolIds := []string{}
for _, pool := range pools {
poolIds = append(poolIds, pool.ID)
}
return runnerErrors.NewBadRequestError("enterprise has pools defined (%s)", strings.Join(poolIds, ", "))
}
if err := r.poolManagerCtrl.DeleteEnterprisePoolManager(enterprise); err != nil {
return errors.Wrap(err, "deleting enterprise pool manager")
}
if err := r.store.DeleteEnterprise(ctx, enterpriseID); err != nil {
return errors.Wrapf(err, "removing enterprise %s", enterpriseID)
}
return nil
}
func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateRepositoryParams) (params.Enterprise, error) {
if !auth.IsAdmin(ctx) {
return params.Enterprise{}, runnerErrors.ErrUnauthorized
}
r.mux.Lock()
defer r.mux.Unlock()
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "fetching enterprise")
}
if param.CredentialsName != "" {
// Check that credentials are set before saving to db
if _, ok := r.credentials[param.CredentialsName]; !ok {
return params.Enterprise{}, runnerErrors.NewBadRequestError("invalid credentials (%s) for enterprise %s", param.CredentialsName, enterprise.Name)
}
}
enterprise, err = r.store.UpdateEnterprise(ctx, enterpriseID, param)
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
poolMgr, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise)
if err != nil {
newState := params.UpdatePoolStateParams{
WebhookSecret: enterprise.WebhookSecret,
}
// stop the pool mgr
if err := poolMgr.RefreshState(newState); err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise pool manager")
}
} else {
if _, err := r.poolManagerCtrl.CreateEnterprisePoolManager(r.ctx, enterprise, r.providers, r.store); err != nil {
return params.Enterprise{}, errors.Wrap(err, "creating enterprise pool manager")
}
}
return enterprise, nil
}
func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) {
if !auth.IsAdmin(ctx) {
return params.Pool{}, runnerErrors.ErrUnauthorized
}
r.mux.Lock()
defer r.mux.Unlock()
enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching enterprise")
}
if _, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise); err != nil {
return params.Pool{}, runnerErrors.ErrNotFound
}
createPoolParams, err := r.appendTagsToCreatePoolParams(param)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool params")
}
if param.RunnerBootstrapTimeout == 0 {
param.RunnerBootstrapTimeout = config.DefaultRunnerBootstrapTimeout
}
pool, err := r.store.CreateEnterprisePool(ctx, enterpriseID, createPoolParams)
if err != nil {
return params.Pool{}, errors.Wrap(err, "creating pool")
}
return pool, nil
}
func (r *Runner) GetEnterprisePoolByID(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) {
if !auth.IsAdmin(ctx) {
return params.Pool{}, runnerErrors.ErrUnauthorized
}
pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool, nil
}
func (r *Runner) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error {
if !auth.IsAdmin(ctx) {
return runnerErrors.ErrUnauthorized
}
// TODO: dedup instance count verification
pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID)
if err != nil {
return errors.Wrap(err, "fetching pool")
}
instances, err := r.store.ListPoolInstances(ctx, pool.ID)
if err != nil {
return errors.Wrap(err, "fetching instances")
}
// TODO: implement a count function
if len(instances) > 0 {
runnerIDs := []string{}
for _, run := range instances {
runnerIDs = append(runnerIDs, run.ID)
}
return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", "))
}
if err := r.store.DeleteEnterprisePool(ctx, enterpriseID, poolID); err != nil {
return errors.Wrap(err, "deleting pool")
}
return nil
}
func (r *Runner) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) {
if !auth.IsAdmin(ctx) {
return []params.Pool{}, runnerErrors.ErrUnauthorized
}
pools, err := r.store.ListEnterprisePools(ctx, enterpriseID)
if err != nil {
return nil, errors.Wrap(err, "fetching pools")
}
return pools, nil
}
func (r *Runner) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
if !auth.IsAdmin(ctx) {
return params.Pool{}, runnerErrors.ErrUnauthorized
}
pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
maxRunners := pool.MaxRunners
minIdleRunners := pool.MinIdleRunners
if param.MaxRunners != nil {
maxRunners = *param.MaxRunners
}
if param.MinIdleRunners != nil {
minIdleRunners = *param.MinIdleRunners
}
if minIdleRunners > maxRunners {
return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners")
}
newPool, err := r.store.UpdateEnterprisePool(ctx, enterpriseID, poolID, param)
if err != nil {
return params.Pool{}, errors.Wrap(err, "updating pool")
}
return newPool, nil
}
func (r *Runner) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) {
if !auth.IsAdmin(ctx) {
return nil, runnerErrors.ErrUnauthorized
}
instances, err := r.store.ListEnterpriseInstances(ctx, enterpriseID)
if err != nil {
return []params.Instance{}, errors.Wrap(err, "fetching instances")
}
return instances, nil
}
func (r *Runner) findEnterprisePoolManager(name string) (common.PoolManager, error) {
r.mux.Lock()
defer r.mux.Unlock()
enterprise, err := r.store.GetEnterprise(r.ctx, name)
if err != nil {
return nil, errors.Wrap(err, "fetching enterprise")
}
poolManager, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise)
if err != nil {
return nil, errors.Wrap(err, "fetching pool manager for enterprise")
}
return poolManager, nil
}

View file

@ -21,16 +21,31 @@ import (
"garm/runner/common"
)
//go:generate mockery --name=PoolManagerController
type PoolManagerController interface {
type RepoPoolManager interface {
CreateRepoPoolManager(ctx context.Context, repo params.Repository, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
GetRepoPoolManager(repo params.Repository) (common.PoolManager, error)
DeleteRepoPoolManager(repo params.Repository) error
GetRepoPoolManagers() (map[string]common.PoolManager, error)
}
type OrgPoolManager interface {
CreateOrgPoolManager(ctx context.Context, org params.Organization, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
GetOrgPoolManager(org params.Organization) (common.PoolManager, error)
DeleteOrgPoolManager(org params.Organization) error
GetOrgPoolManagers() (map[string]common.PoolManager, error)
}
type EnterprisePoolManager interface {
CreateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error)
DeleteEnterprisePoolManager(enterprise params.Enterprise) error
GetEnterprisePoolManagers() (map[string]common.PoolManager, error)
}
//go:generate mockery --name=PoolManagerController
type PoolManagerController interface {
RepoPoolManager
OrgPoolManager
EnterprisePoolManager
}

View file

@ -18,6 +18,29 @@ type PoolManagerController struct {
mock.Mock
}
// CreateEnterprisePoolManager provides a mock function with given fields: ctx, enterprise, providers, store
func (_m *PoolManagerController) CreateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise, providers map[string]common.Provider, store databasecommon.Store) (common.PoolManager, error) {
ret := _m.Called(ctx, enterprise, providers, store)
var r0 common.PoolManager
if rf, ok := ret.Get(0).(func(context.Context, params.Enterprise, map[string]common.Provider, databasecommon.Store) common.PoolManager); ok {
r0 = rf(ctx, enterprise, providers, store)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(common.PoolManager)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, params.Enterprise, map[string]common.Provider, databasecommon.Store) error); ok {
r1 = rf(ctx, enterprise, providers, store)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateOrgPoolManager provides a mock function with given fields: ctx, org, providers, store
func (_m *PoolManagerController) CreateOrgPoolManager(ctx context.Context, org params.Organization, providers map[string]common.Provider, store databasecommon.Store) (common.PoolManager, error) {
ret := _m.Called(ctx, org, providers, store)
@ -64,6 +87,20 @@ func (_m *PoolManagerController) CreateRepoPoolManager(ctx context.Context, repo
return r0, r1
}
// DeleteEnterprisePoolManager provides a mock function with given fields: enterprise
func (_m *PoolManagerController) DeleteEnterprisePoolManager(enterprise params.Enterprise) error {
ret := _m.Called(enterprise)
var r0 error
if rf, ok := ret.Get(0).(func(params.Enterprise) error); ok {
r0 = rf(enterprise)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteOrgPoolManager provides a mock function with given fields: org
func (_m *PoolManagerController) DeleteOrgPoolManager(org params.Organization) error {
ret := _m.Called(org)
@ -92,6 +129,52 @@ func (_m *PoolManagerController) DeleteRepoPoolManager(repo params.Repository) e
return r0
}
// GetEnterprisePoolManager provides a mock function with given fields: enterprise
func (_m *PoolManagerController) GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error) {
ret := _m.Called(enterprise)
var r0 common.PoolManager
if rf, ok := ret.Get(0).(func(params.Enterprise) common.PoolManager); ok {
r0 = rf(enterprise)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(common.PoolManager)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(params.Enterprise) error); ok {
r1 = rf(enterprise)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetEnterprisePoolManagers provides a mock function with given fields:
func (_m *PoolManagerController) GetEnterprisePoolManagers() (map[string]common.PoolManager, error) {
ret := _m.Called()
var r0 map[string]common.PoolManager
if rf, ok := ret.Get(0).(func() map[string]common.PoolManager); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]common.PoolManager)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetOrgPoolManager provides a mock function with given fields: org
func (_m *PoolManagerController) GetOrgPoolManager(org params.Organization) (common.PoolManager, error) {
ret := _m.Called(org)

201
runner/pool/enterprise.go Normal file
View file

@ -0,0 +1,201 @@
package pool
import (
"context"
"fmt"
"math"
"strings"
"sync"
dbCommon "garm/database/common"
runnerErrors "garm/errors"
"garm/params"
"garm/runner/common"
"garm/util"
"github.com/google/go-github/v47/github"
"github.com/pkg/errors"
)
// test that we implement PoolManager
var _ poolHelper = &organization{}
func NewEnterprisePoolManager(ctx context.Context, cfg params.Enterprise, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
ghc, ghEnterpriseClient, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails)
if err != nil {
return nil, errors.Wrap(err, "getting github client")
}
helper := &enterprise{
cfg: cfg,
cfgInternal: cfgInternal,
ctx: ctx,
ghcli: ghc,
ghcEnterpriseCli: ghEnterpriseClient,
id: cfg.ID,
store: store,
}
repo := &basePool{
ctx: ctx,
store: store,
providers: providers,
controllerID: cfgInternal.ControllerID,
quit: make(chan struct{}),
done: make(chan struct{}),
helper: helper,
credsDetails: cfgInternal.GithubCredentialsDetails,
}
return repo, nil
}
type enterprise struct {
cfg params.Enterprise
cfgInternal params.Internal
ctx context.Context
ghcli common.GithubClient
ghcEnterpriseCli common.GithubEnterpriseClient
id string
store dbCommon.Store
mux sync.Mutex
}
func (r *enterprise) GetRunnerNameFromWorkflow(job params.WorkflowJob) (string, error) {
workflow, _, err := r.ghcli.GetWorkflowJobByID(r.ctx, job.Repository.Owner.Login, job.Repository.Name, job.WorkflowJob.ID)
if err != nil {
return "", errors.Wrap(err, "fetching workflow info")
}
if workflow.RunnerName != nil {
return *workflow.RunnerName, nil
}
return "", fmt.Errorf("failed to find runner name from workflow")
}
func (r *enterprise) UpdateState(param params.UpdatePoolStateParams) error {
r.mux.Lock()
defer r.mux.Unlock()
r.cfg.WebhookSecret = param.WebhookSecret
ghc, ghcEnterprise, err := util.GithubClient(r.ctx, r.GetGithubToken(), r.cfgInternal.GithubCredentialsDetails)
if err != nil {
return errors.Wrap(err, "getting github client")
}
r.ghcli = ghc
r.ghcEnterpriseCli = ghcEnterprise
return nil
}
func (r *enterprise) GetGithubToken() string {
return r.cfgInternal.OAuth2Token
}
func (r *enterprise) GetGithubRunners() ([]*github.Runner, error) {
opts := github.ListOptions{
PerPage: 100,
Page: 1,
}
runners, _, err := r.ghcEnterpriseCli.ListRunners(r.ctx, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
ret := []*github.Runner{}
ret = append(ret, runners.Runners...)
pages := math.Ceil(float64(runners.TotalCount) / float64(100))
if pages > 1 {
for i := 2; i <= int(pages); i++ {
opts.Page = i
runners, _, err = r.ghcEnterpriseCli.ListRunners(r.ctx, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
ret = append(ret, runners.Runners...)
}
}
return ret, nil
}
func (r *enterprise) FetchTools() ([]*github.RunnerApplicationDownload, error) {
r.mux.Lock()
defer r.mux.Unlock()
tools, _, err := r.ghcEnterpriseCli.ListRunnerApplicationDownloads(r.ctx, r.cfg.Name)
if err != nil {
return nil, errors.Wrap(err, "fetching runner tools")
}
return tools, nil
}
func (r *enterprise) FetchDbInstances() ([]params.Instance, error) {
return r.store.ListEnterpriseInstances(r.ctx, r.id)
}
func (r *enterprise) RemoveGithubRunner(runnerID int64) (*github.Response, error) {
return r.ghcEnterpriseCli.RemoveRunner(r.ctx, r.cfg.Name, runnerID)
}
func (r *enterprise) ListPools() ([]params.Pool, error) {
pools, err := r.store.ListEnterprisePools(r.ctx, r.id)
if err != nil {
return nil, errors.Wrap(err, "fetching pools")
}
return pools, nil
}
func (r *enterprise) GithubURL() string {
return fmt.Sprintf("%s/enterprises/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.cfg.Name)
}
func (r *enterprise) JwtToken() string {
return r.cfgInternal.JWTSecret
}
func (r *enterprise) GetGithubRegistrationToken() (string, error) {
tk, _, err := r.ghcEnterpriseCli.CreateRegistrationToken(r.ctx, r.cfg.Name)
if err != nil {
return "", errors.Wrap(err, "creating runner token")
}
return *tk.Token, nil
}
func (r *enterprise) String() string {
return r.cfg.Name
}
func (r *enterprise) WebhookSecret() string {
return r.cfg.WebhookSecret
}
func (r *enterprise) GetCallbackURL() string {
return r.cfgInternal.InstanceCallbackURL
}
func (r *enterprise) FindPoolByTags(labels []string) (params.Pool, error) {
pool, err := r.store.FindEnterprisePoolByTags(r.ctx, r.id, labels)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching suitable pool")
}
return pool, nil
}
func (r *enterprise) GetPoolByID(poolID string) (params.Pool, error) {
pool, err := r.store.GetEnterprisePool(r.ctx, r.id, poolID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return pool, nil
}
func (r *enterprise) ValidateOwner(job params.WorkflowJob) error {
if !strings.EqualFold(job.Enterprise.Slug, r.cfg.Name) {
return runnerErrors.NewBadRequestError("job not meant for this pool manager")
}
return nil
}
func (r *enterprise) ID() string {
return r.id
}

View file

@ -17,7 +17,7 @@ package pool
import (
"garm/params"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
)
type poolHelper interface {

View file

@ -17,6 +17,7 @@ package pool
import (
"context"
"fmt"
"math"
"strings"
"sync"
@ -26,7 +27,7 @@ import (
"garm/runner/common"
"garm/util"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
"github.com/pkg/errors"
)
@ -34,7 +35,7 @@ import (
var _ poolHelper = &organization{}
func NewOrganizationPoolManager(ctx context.Context, cfg params.Organization, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
ghc, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails)
ghc, _, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails)
if err != nil {
return nil, errors.Wrap(err, "getting github client")
}
@ -89,7 +90,7 @@ func (r *organization) UpdateState(param params.UpdatePoolStateParams) error {
r.cfg.WebhookSecret = param.WebhookSecret
ghc, err := util.GithubClient(r.ctx, r.GetGithubToken(), r.cfgInternal.GithubCredentialsDetails)
ghc, _, err := util.GithubClient(r.ctx, r.GetGithubToken(), r.cfgInternal.GithubCredentialsDetails)
if err != nil {
return errors.Wrap(err, "getting github client")
}
@ -102,12 +103,30 @@ func (r *organization) GetGithubToken() string {
}
func (r *organization) GetGithubRunners() ([]*github.Runner, error) {
runners, _, err := r.ghcli.ListOrganizationRunners(r.ctx, r.cfg.Name, nil)
opts := github.ListOptions{
PerPage: 100,
Page: 1,
}
runners, _, err := r.ghcli.ListOrganizationRunners(r.ctx, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
return runners.Runners, nil
ret := []*github.Runner{}
ret = append(ret, runners.Runners...)
pages := math.Ceil(float64(runners.TotalCount) / float64(100))
if pages > 1 {
for i := 2; i <= int(pages); i++ {
opts.Page = i
runners, _, err = r.ghcli.ListOrganizationRunners(r.ctx, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
ret = append(ret, runners.Runners...)
}
}
return ret, nil
}
func (r *organization) FetchTools() ([]*github.RunnerApplicationDownload, error) {

View file

@ -30,7 +30,7 @@ import (
"garm/runner/common"
providerCommon "garm/runner/providers/common"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
"github.com/google/uuid"
"github.com/pkg/errors"
)

View file

@ -17,6 +17,7 @@ package pool
import (
"context"
"fmt"
"math"
"strings"
"sync"
@ -26,7 +27,7 @@ import (
"garm/runner/common"
"garm/util"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
"github.com/pkg/errors"
)
@ -34,7 +35,7 @@ import (
var _ poolHelper = &repository{}
func NewRepositoryPoolManager(ctx context.Context, cfg params.Repository, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
ghc, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails)
ghc, _, err := util.GithubClient(ctx, cfgInternal.OAuth2Token, cfgInternal.GithubCredentialsDetails)
if err != nil {
return nil, errors.Wrap(err, "getting github client")
}
@ -91,7 +92,7 @@ func (r *repository) UpdateState(param params.UpdatePoolStateParams) error {
r.cfg.WebhookSecret = param.WebhookSecret
ghc, err := util.GithubClient(r.ctx, r.GetGithubToken(), r.cfgInternal.GithubCredentialsDetails)
ghc, _, err := util.GithubClient(r.ctx, r.GetGithubToken(), r.cfgInternal.GithubCredentialsDetails)
if err != nil {
return errors.Wrap(err, "getting github client")
}
@ -104,12 +105,29 @@ func (r *repository) GetGithubToken() string {
}
func (r *repository) GetGithubRunners() ([]*github.Runner, error) {
runners, _, err := r.ghcli.ListRunners(r.ctx, r.cfg.Owner, r.cfg.Name, nil)
opts := github.ListOptions{
PerPage: 100,
Page: 1,
}
runners, _, err := r.ghcli.ListRunners(r.ctx, r.cfg.Owner, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
ret := []*github.Runner{}
ret = append(ret, runners.Runners...)
pages := math.Ceil(float64(runners.TotalCount) / float64(100))
if pages > 1 {
for i := 2; i <= int(pages); i++ {
opts.Page = i
runners, _, err = r.ghcli.ListRunners(r.ctx, r.cfg.Owner, r.cfg.Name, &opts)
if err != nil {
return nil, errors.Wrap(err, "fetching runners")
}
ret = append(ret, runners.Runners...)
}
}
return runners.Runners, nil
return ret, nil
}
func (r *repository) FetchTools() ([]*github.RunnerApplicationDownload, error) {

View file

@ -25,7 +25,7 @@ import (
"garm/runner/common"
"garm/util"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
lxd "github.com/lxc/lxd/client"
"github.com/lxc/lxd/shared/api"
"github.com/pkg/errors"

View file

@ -23,9 +23,7 @@ import (
"encoding/json"
"fmt"
"hash"
"io/ioutil"
"log"
"path/filepath"
"strings"
"sync"
"time"
@ -43,7 +41,6 @@ import (
"garm/util"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
func NewRunner(ctx context.Context, cfg config.Config) (*Runner, error) {
@ -100,6 +97,7 @@ type poolManagerCtrl struct {
repositories map[string]common.PoolManager
organizations map[string]common.PoolManager
enterprises map[string]common.PoolManager
}
func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params.Repository, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
@ -184,6 +182,47 @@ func (p *poolManagerCtrl) GetOrgPoolManagers() (map[string]common.PoolManager, e
return p.organizations, nil
}
func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
p.mux.Lock()
defer p.mux.Unlock()
cfgInternal, err := p.getInternalConfig(enterprise.CredentialsName)
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
}
poolManager, err := pool.NewEnterprisePoolManager(ctx, enterprise, cfgInternal, providers, store)
if err != nil {
return nil, errors.Wrap(err, "creating enterprise pool manager")
}
p.enterprises[enterprise.ID] = poolManager
return poolManager, nil
}
func (p *poolManagerCtrl) GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error) {
if enterprisePoolMgr, ok := p.enterprises[enterprise.ID]; ok {
return enterprisePoolMgr, nil
}
return nil, errors.Wrapf(runnerErrors.ErrNotFound, "enterprise %s pool manager not loaded", enterprise.Name)
}
func (p *poolManagerCtrl) DeleteEnterprisePoolManager(enterprise params.Enterprise) error {
p.mux.Lock()
defer p.mux.Unlock()
poolMgr, ok := p.enterprises[enterprise.ID]
if ok {
if err := poolMgr.Stop(); err != nil {
return errors.Wrap(err, "stopping enterprise pool manager")
}
delete(p.enterprises, enterprise.ID)
}
return nil
}
func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolManager, error) {
return p.enterprises, nil
}
func (p *poolManagerCtrl) getInternalConfig(credsName string) (params.Internal, error) {
creds, ok := p.credentials[credsName]
if !ok {
@ -472,6 +511,8 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [
poolManager, err = r.findRepoPoolManager(job.Repository.Owner.Login, job.Repository.Name)
case OrganizationHook:
poolManager, err = r.findOrgPoolManager(job.Organization.Login)
case EnterpriseHook:
poolManager, err = r.findEnterprisePoolManager(job.Enterprise.Slug)
default:
return runnerErrors.NewBadRequestError("cannot handle hook target type %s", hookTargetType)
}
@ -496,45 +537,6 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [
return nil
}
func (r *Runner) sshDir() string {
return filepath.Join(r.config.Default.ConfigDir, "ssh")
}
func (r *Runner) sshKeyPath() string {
keyPath := filepath.Join(r.sshDir(), "runner_rsa_key")
return keyPath
}
func (r *Runner) sshPubKeyPath() string {
keyPath := filepath.Join(r.sshDir(), "runner_rsa_key.pub")
return keyPath
}
func (r *Runner) parseSSHKey() (ssh.Signer, error) {
r.mux.Lock()
defer r.mux.Unlock()
key, err := ioutil.ReadFile(r.sshKeyPath())
if err != nil {
return nil, errors.Wrapf(err, "reading private key %s", r.sshKeyPath())
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil, errors.Wrapf(err, "parsing private key %s", r.sshKeyPath())
}
return signer, nil
}
func (r *Runner) sshPubKey() ([]byte, error) {
key, err := ioutil.ReadFile(r.sshPubKeyPath())
if err != nil {
return nil, errors.Wrapf(err, "reading public key %s", r.sshPubKeyPath())
}
return key, nil
}
func (r *Runner) appendTagsToCreatePoolParams(param params.CreatePoolParams) (params.CreatePoolParams, error) {
if err := param.Validate(); err != nil {
return params.CreatePoolParams{}, errors.Wrapf(runnerErrors.ErrBadRequest, "validating params: %s", err)

View file

@ -21,6 +21,7 @@ type HookTargetType string
const (
RepoHook HookTargetType = "repository"
OrganizationHook HookTargetType = "organization"
EnterpriseHook HookTargetType = "business"
)
var (

View file

@ -38,7 +38,7 @@ import (
"garm/params"
"garm/runner/common"
"github.com/google/go-github/v43/github"
"github.com/google/go-github/v47/github"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
@ -163,13 +163,13 @@ func OSToOSType(os string) (config.OSType, error) {
return osType, nil
}
func GithubClient(ctx context.Context, token string, credsDetails params.GithubCredentials) (common.GithubClient, error) {
func GithubClient(ctx context.Context, token string, credsDetails params.GithubCredentials) (common.GithubClient, common.GithubEnterpriseClient, error) {
var roots *x509.CertPool
if credsDetails.CABundle != nil && len(credsDetails.CABundle) > 0 {
roots = x509.NewCertPool()
ok := roots.AppendCertsFromPEM(credsDetails.CABundle)
if !ok {
return nil, fmt.Errorf("failed to parse CA cert")
return nil, nil, fmt.Errorf("failed to parse CA cert")
}
}
httpTransport := &http.Transport{
@ -185,13 +185,12 @@ func GithubClient(ctx context.Context, token string, credsDetails params.GithubC
)
tc := oauth2.NewClient(ctx, ts)
// ghClient := github.NewClient(tc)
ghClient, err := github.NewEnterpriseClient(credsDetails.APIBaseURL, credsDetails.UploadBaseURL, tc)
if err != nil {
return nil, errors.Wrap(err, "fetching github client")
return nil, nil, errors.Wrap(err, "fetching github client")
}
return ghClient.Actions, nil
return ghClient.Actions, ghClient.Enterprise, nil
}
func GetCloudConfig(bootstrapParams params.BootstrapInstance, tools github.RunnerApplicationDownload, runnerName string) (string, error) {