Use watcher and get rid of RefreshState()

This change uses the database watcher to watch for changes to the
github entities, credentials and controller info.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-06-20 15:28:56 +00:00
parent 38127af747
commit daaca0bd8f
23 changed files with 452 additions and 462 deletions

View file

@ -46,7 +46,20 @@ type InstanceJWTClaims struct {
jwt.RegisteredClaims
}
func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
func NewInstanceTokenGetter(jwtSecret string) (InstanceTokenGetter, error) {
if jwtSecret == "" {
return nil, fmt.Errorf("jwt secret is required")
}
return &instanceToken{
jwtSecret: jwtSecret,
}, nil
}
type instanceToken struct {
jwtSecret string
}
func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
// Token expiration is equal to the bootstrap timeout set on the pool plus the polling
// interval garm uses to check for timed out runners. Runners that have not sent their info
// by the end of this interval are most likely failed and will be reaped by garm anyway.
@ -67,7 +80,7 @@ func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolTy
CreateAttempt: instance.CreateAttempt,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(secret))
tokenString, err := token.SignedString([]byte(i.jwtSecret))
if err != nil {
return "", errors.Wrap(err, "signing token")
}

View file

@ -14,9 +14,17 @@
package auth
import "net/http"
import (
"net/http"
"github.com/cloudbase/garm/params"
)
// Middleware defines an authentication middleware
type Middleware interface {
Middleware(next http.Handler) http.Handler
}
type InstanceTokenGetter interface {
NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error)
}

View file

@ -303,6 +303,10 @@ func parseCredentialsAddParams() (ret params.CreateGithubCredentialsParams, err
func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error) {
var updateParams params.UpdateGithubCredentialsParams
if credentialsAppInstallationID != 0 || credentialsAppID != 0 || credentialsPrivateKeyPath != "" {
updateParams.App = &params.GithubApp{}
}
if credentialsName != "" {
updateParams.Name = &credentialsName
}
@ -312,6 +316,9 @@ func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error
}
if credentialsOAuthToken != "" {
if updateParams.PAT == nil {
updateParams.PAT = &params.GithubPAT{}
}
updateParams.PAT.OAuth2Token = credentialsOAuthToken
}

View file

@ -132,7 +132,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
}
func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}
@ -206,17 +206,13 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
return errors.Wrap(q.Error, "saving enterprise")
}
if creds.ID != 0 {
enterprise.Credentials = creds
}
return nil
})
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

View file

@ -123,7 +123,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
}
func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) {
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching org")
}
@ -198,17 +198,13 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
return errors.Wrap(q.Error, "saving org")
}
if creds.ID != 0 {
org.Credentials = creds
}
return nil
})
if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org")
}
org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Organization{}, errors.Wrap(err, "updating enterprise")
}

View file

@ -122,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
}
func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) {
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching repo")
}
@ -197,16 +197,13 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
return errors.Wrap(q.Error, "saving repo")
}
if creds.ID != 0 {
repo.Credentials = creds
}
return nil
})
if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo")
}
repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Repository{}, errors.Wrap(err, "updating enterprise")
}

View file

@ -32,6 +32,18 @@ func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
}
}
// WithAll returns a filter function that returns true if all of the provided filters return true.
func WithAll(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
for _, filter := range filters {
if !filter(payload) {
return false
}
}
return true
}
}
// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
// The filter function returns true if the payload's entity type matches the provided entity type.
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
@ -139,3 +151,17 @@ func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFun
}
}
}
// WithGithubCredentialsFilter returns a filter function that filters payloads by Github credentials.
func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if payload.EntityType != dbCommon.GithubCredentialsEntityType {
return false
}
credsPayload, ok := payload.Payload.(params.GithubCredentials)
if !ok {
return false
}
return credsPayload.ID == creds.ID
}
}

View file

@ -419,10 +419,13 @@ func (r Repository) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("repository has no ID")
}
return GithubEntity{
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
PoolBalancerType: r.PoolBalancerType,
Credentials: r.Credentials,
WebhookSecret: r.WebhookSecret,
}, nil
}
@ -470,10 +473,12 @@ func (o Organization) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("organization has no ID")
}
return GithubEntity{
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
PoolBalancerType: o.PoolBalancerType,
Credentials: o.Credentials,
}, nil
}
@ -517,10 +522,12 @@ func (e Enterprise) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("enterprise has no ID")
}
return GithubEntity{
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
PoolBalancerType: e.PoolBalancerType,
Credentials: e.Credentials,
}, nil
}
@ -685,11 +692,6 @@ type Provider struct {
// used by swagger client generated code
type Providers []Provider
type UpdatePoolStateParams struct {
WebhookSecret string
InternalConfig *Internal
}
type PoolManagerStatus struct {
IsRunning bool `json:"running"`
FailureReason string `json:"failure_reason,omitempty"`
@ -788,15 +790,23 @@ type UpdateSystemInfoParams struct {
}
type GithubEntity struct {
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
PoolBalancerType PoolBalancerType `json:"pool_balancing_type"`
WebhookSecret string `json:"-"`
}
func (g GithubEntity) GetPoolBalancerType() PoolBalancerType {
if g.PoolBalancerType == "" {
return PoolBalancerTypeRoundRobin
}
return g.PoolBalancerType
}
func (g GithubEntity) LabelScope() string {
switch g.EntityType {
case GithubEntityTypeRepository:

View file

@ -152,24 +152,6 @@ func (_m *PoolManager) InstallWebhook(ctx context.Context, param params.InstallW
return r0, r1
}
// RefreshState provides a mock function with given fields: param
func (_m *PoolManager) RefreshState(param params.UpdatePoolStateParams) error {
ret := _m.Called(param)
if len(ret) == 0 {
panic("no return value specified for RefreshState")
}
var r0 error
if rf, ok := ret.Get(0).(func(params.UpdatePoolStateParams) error); ok {
r0 = rf(param)
} else {
r0 = ret.Error(0)
}
return r0
}
// RootCABundle provides a mock function with given fields:
func (_m *PoolManager) RootCABundle() (params.CertificateBundle, error) {
ret := _m.Called()

View file

@ -53,8 +53,6 @@ type PoolManager interface {
// a repo, org or enterprise, we determine the destination of that webhook, retrieve the pool manager
// for it and call this function with the WorkflowJob as a parameter.
HandleWorkflowJob(job params.WorkflowJob) error
// RefreshState allows us to update webhook secrets and configuration for a pool manager.
RefreshState(param params.UpdatePoolStateParams) error
// DeleteRunner will attempt to remove a runner from the pool. If forceRemove is true, any error
// received from the provider will be ignored and we will proceed to remove the runner from the database.

View file

@ -174,11 +174,9 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
// Use the admin context in the pool manager. Any access control is already done above when
// updating the store.
poolMgr, err := r.poolManagerCtrl.UpdateEnterprisePoolManager(r.ctx, enterprise)
poolMgr, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise)
if err != nil {
return params.Enterprise{}, fmt.Errorf("failed to update enterprise pool manager: %w", err)
return params.Enterprise{}, fmt.Errorf("failed to get enterprise pool manager: %w", err)
}
enterprise.PoolManagerStatus = poolMgr.Status()

View file

@ -45,7 +45,6 @@ type EnterpriseTestFixtures struct {
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
UpdatePoolStateParams params.UpdatePoolStateParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
@ -138,9 +137,6 @@ func (s *EnterpriseTestSuite) SetupTest() {
Image: "test-images-updated",
Flavor: "test-flavor-updated",
},
UpdatePoolStateParams: params.UpdatePoolStateParams{
WebhookSecret: "test-update-repo-webhook-secret",
},
ErrMock: fmt.Errorf("mock error"),
ProviderMock: providerMock,
PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()),
@ -298,7 +294,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolMgrFailed() {
}
func (s *EnterpriseTestSuite) TestUpdateEnterprise() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)
param := s.Fixtures.UpdateRepoParams
@ -330,21 +326,21 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() {
}
func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *EnterpriseTestSuite) TestCreateEnterprisePool() {

View file

@ -24,7 +24,6 @@ import (
type RepoPoolManager interface {
CreateRepoPoolManager(ctx context.Context, repo params.Repository, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error)
GetRepoPoolManager(repo params.Repository) (common.PoolManager, error)
DeleteRepoPoolManager(repo params.Repository) error
GetRepoPoolManagers() (map[string]common.PoolManager, error)
@ -32,7 +31,6 @@ type RepoPoolManager interface {
type OrgPoolManager interface {
CreateOrgPoolManager(ctx context.Context, org params.Organization, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error)
GetOrgPoolManager(org params.Organization) (common.PoolManager, error)
DeleteOrgPoolManager(org params.Organization) error
GetOrgPoolManagers() (map[string]common.PoolManager, error)
@ -40,7 +38,6 @@ type OrgPoolManager interface {
type EnterprisePoolManager interface {
CreateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error)
UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error)
GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error)
DeleteEnterprisePoolManager(enterprise params.Enterprise) error
GetEnterprisePoolManagers() (map[string]common.PoolManager, error)

View file

@ -343,96 +343,6 @@ func (_m *PoolManagerController) GetRepoPoolManagers() (map[string]common.PoolMa
return r0, r1
}
// UpdateEnterprisePoolManager provides a mock function with given fields: ctx, enterprise
func (_m *PoolManagerController) UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error) {
ret := _m.Called(ctx, enterprise)
if len(ret) == 0 {
panic("no return value specified for UpdateEnterprisePoolManager")
}
var r0 common.PoolManager
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, params.Enterprise) (common.PoolManager, error)); ok {
return rf(ctx, enterprise)
}
if rf, ok := ret.Get(0).(func(context.Context, params.Enterprise) common.PoolManager); ok {
r0 = rf(ctx, enterprise)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(common.PoolManager)
}
}
if rf, ok := ret.Get(1).(func(context.Context, params.Enterprise) error); ok {
r1 = rf(ctx, enterprise)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateOrgPoolManager provides a mock function with given fields: ctx, org
func (_m *PoolManagerController) UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error) {
ret := _m.Called(ctx, org)
if len(ret) == 0 {
panic("no return value specified for UpdateOrgPoolManager")
}
var r0 common.PoolManager
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, params.Organization) (common.PoolManager, error)); ok {
return rf(ctx, org)
}
if rf, ok := ret.Get(0).(func(context.Context, params.Organization) common.PoolManager); ok {
r0 = rf(ctx, org)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(common.PoolManager)
}
}
if rf, ok := ret.Get(1).(func(context.Context, params.Organization) error); ok {
r1 = rf(ctx, org)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateRepoPoolManager provides a mock function with given fields: ctx, repo
func (_m *PoolManagerController) UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error) {
ret := _m.Called(ctx, repo)
if len(ret) == 0 {
panic("no return value specified for UpdateRepoPoolManager")
}
var r0 common.PoolManager
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, params.Repository) (common.PoolManager, error)); ok {
return rf(ctx, repo)
}
if rf, ok := ret.Get(0).(func(context.Context, params.Repository) common.PoolManager); ok {
r0 = rf(ctx, repo)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(common.PoolManager)
}
}
if rf, ok := ret.Get(1).(func(context.Context, params.Repository) error); ok {
r1 = rf(ctx, repo)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewPoolManagerController creates a new instance of PoolManagerController. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewPoolManagerController(t interface {

View file

@ -203,11 +203,9 @@ func (r *Runner) UpdateOrganization(ctx context.Context, orgID string, param par
return params.Organization{}, errors.Wrap(err, "updating org")
}
// Use the admin context in the pool manager. Any access control is already done above when
// updating the store.
poolMgr, err := r.poolManagerCtrl.UpdateOrgPoolManager(r.ctx, org)
poolMgr, err := r.poolManagerCtrl.GetOrgPoolManager(org)
if err != nil {
return params.Organization{}, fmt.Errorf("updating org pool manager: %w", err)
return params.Organization{}, fmt.Errorf("failed to get org pool manager: %w", err)
}
org.PoolManagerStatus = poolMgr.Status()

View file

@ -34,22 +34,21 @@ import (
)
type OrgTestFixtures struct {
AdminContext context.Context
DBFile string
Store dbCommon.Store
StoreOrgs map[string]params.Organization
Providers map[string]common.Provider
Credentials map[string]params.GithubCredentials
CreateOrgParams params.CreateOrgParams
CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
UpdatePoolStateParams params.UpdatePoolStateParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
PoolMgrCtrlMock *runnerMocks.PoolManagerController
AdminContext context.Context
DBFile string
Store dbCommon.Store
StoreOrgs map[string]params.Organization
Providers map[string]common.Provider
Credentials map[string]params.GithubCredentials
CreateOrgParams params.CreateOrgParams
CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
PoolMgrCtrlMock *runnerMocks.PoolManagerController
}
type OrgTestSuite struct {
@ -139,9 +138,6 @@ func (s *OrgTestSuite) SetupTest() {
Image: "test-images-updated",
Flavor: "test-flavor-updated",
},
UpdatePoolStateParams: params.UpdatePoolStateParams{
WebhookSecret: "test-update-repo-webhook-secret",
},
ErrMock: fmt.Errorf("mock error"),
ProviderMock: providerMock,
PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()),
@ -312,7 +308,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolMgrFailed() {
}
func (s *OrgTestSuite) TestUpdateOrganization() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)
org, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams)
@ -326,7 +322,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() {
func (s *OrgTestSuite) TestUpdateRepositoryBalancingType() {
s.Fixtures.UpdateRepoParams.PoolBalancerType = params.PoolBalancerTypePack
s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)
param := s.Fixtures.UpdateRepoParams
@ -355,21 +351,21 @@ func (s *OrgTestSuite) TestUpdateOrganizationInvalidCreds() {
}
func (s *OrgTestSuite) TestUpdateOrganizationPoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("updating org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *OrgTestSuite) TestUpdateOrganizationCreateOrgPoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("updating org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *OrgTestSuite) TestCreateOrgPool() {

View file

@ -35,6 +35,7 @@ import (
"github.com/cloudbase/garm-provider-common/util"
"github.com/cloudbase/garm/auth"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/runner/common"
garmUtil "github.com/cloudbase/garm/util"
@ -61,16 +62,9 @@ const (
maxCreateAttempts = 5
)
type urls struct {
callbackURL string
metadataURL string
webhookURL string
controllerWebhookURL string
}
func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, instanceTokenGetter auth.InstanceTokenGetter, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) {
ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType))
ghc, err := garmUtil.GithubClient(ctx, entity, cfgInternal.GithubCredentialsDetails)
ghc, err := garmUtil.GithubClient(ctx, entity, entity.Credentials)
if err != nil {
return nil, errors.Wrap(err, "getting github client")
}
@ -79,38 +73,47 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgIn
return nil, errors.New("webhook secret is empty")
}
controllerInfo, err := store.ControllerInfo()
if err != nil {
return nil, errors.Wrap(err, "getting controller info")
}
consumerID := fmt.Sprintf("pool-manager-%s", entity.String())
consumer, err := watcher.RegisterConsumer(
ctx, consumerID,
composeWatcherFilters(entity),
)
if err != nil {
return nil, errors.Wrap(err, "registering consumer")
}
wg := &sync.WaitGroup{}
keyMuxes := &keyMutex{}
repo := &basePoolManager{
ctx: ctx,
cfgInternal: cfgInternal,
entity: entity,
ghcli: ghc,
ctx: ctx,
entity: entity,
ghcli: ghc,
controllerInfo: controllerInfo,
instanceTokenGetter: instanceTokenGetter,
store: store,
providers: providers,
controllerID: cfgInternal.ControllerID,
urls: urls{
webhookURL: cfgInternal.BaseWebhookURL,
callbackURL: cfgInternal.InstanceCallbackURL,
metadataURL: cfgInternal.InstanceMetadataURL,
controllerWebhookURL: cfgInternal.ControllerWebhookURL,
},
quit: make(chan struct{}),
credsDetails: cfgInternal.GithubCredentialsDetails,
wg: wg,
keyMux: keyMuxes,
store: store,
providers: providers,
quit: make(chan struct{}),
wg: wg,
keyMux: keyMuxes,
consumer: consumer,
}
return repo, nil
}
type basePoolManager struct {
ctx context.Context
controllerID string
entity params.GithubEntity
ghcli common.GithubClient
cfgInternal params.Internal
ctx context.Context
entity params.GithubEntity
ghcli common.GithubClient
controllerInfo params.ControllerInfo
instanceTokenGetter auth.InstanceTokenGetter
consumer dbCommon.Consumer
store dbCommon.Store
@ -118,13 +121,9 @@ type basePoolManager struct {
tools []commonParams.RunnerApplicationDownload
quit chan struct{}
credsDetails params.GithubCredentials
managerIsRunning bool
managerErrorReason string
urls urls
mux sync.Mutex
wg *sync.WaitGroup
keyMux *keyMutex
@ -353,9 +352,9 @@ func (r *basePoolManager) updateTools() error {
tools, err := r.FetchTools()
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
r.ctx, "failed to update tools for repo")
r.ctx, "failed to update tools for entity", "entity", r.entity.String())
r.setPoolRunningState(false, err.Error())
return fmt.Errorf("failed to update tools for repo %s: %w", r.entity.String(), err)
return fmt.Errorf("failed to update tools for entity %s: %w", r.entity.String(), err)
}
r.mux.Lock()
r.tools = tools
@ -381,7 +380,7 @@ func (r *basePoolManager) cleanupOrphanedProviderRunners(runners []*github.Runne
runnerNames := map[string]bool{}
for _, run := range runners {
if !isManagedRunner(labelsFromRunner(run), r.controllerID) {
if !isManagedRunner(labelsFromRunner(run), r.controllerInfo.ControllerID.String()) {
slog.DebugContext(
r.ctx, "runner is not managed by a pool we manage",
"runner_name", run.GetName())
@ -457,7 +456,7 @@ func (r *basePoolManager) reapTimedOutRunners(runners []*github.Runner) error {
runnersByName := map[string]*github.Runner{}
for _, run := range runners {
if !isManagedRunner(labelsFromRunner(run), r.controllerID) {
if !isManagedRunner(labelsFromRunner(run), r.controllerInfo.ControllerID.String()) {
slog.DebugContext(
r.ctx, "runner is not managed by a pool we manage",
"runner_name", run.GetName())
@ -515,7 +514,7 @@ func (r *basePoolManager) cleanupOrphanedGithubRunners(runners []*github.Runner)
poolInstanceCache := map[string][]commonParams.ProviderInstance{}
g, ctx := errgroup.WithContext(r.ctx)
for _, runner := range runners {
if !isManagedRunner(labelsFromRunner(runner), r.controllerID) {
if !isManagedRunner(labelsFromRunner(runner), r.controllerInfo.ControllerID.String()) {
slog.DebugContext(
r.ctx, "runner is not managed by a pool we manage",
"runner_name", runner.GetName())
@ -741,8 +740,8 @@ func (r *basePoolManager) AddRunner(ctx context.Context, poolID string, aditiona
RunnerStatus: params.RunnerPending,
OSArch: pool.OSArch,
OSType: pool.OSType,
CallbackURL: r.urls.callbackURL,
MetadataURL: r.urls.metadataURL,
CallbackURL: r.controllerInfo.CallbackURL,
MetadataURL: r.controllerInfo.MetadataURL,
CreateAttempt: 1,
GitHubRunnerGroup: pool.GitHubRunnerGroup,
AditionalLabels: aditionalLabels,
@ -832,7 +831,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error
jwtValidity := pool.RunnerTimeout()
entity := r.entity.String()
jwtToken, err := auth.NewInstanceJWTToken(instance, r.cfgInternal.JWTSecret, entity, pool.PoolType(), jwtValidity)
jwtToken, err := r.instanceTokenGetter.NewInstanceJWTToken(instance, entity, pool.PoolType(), jwtValidity)
if err != nil {
return errors.Wrap(err, "fetching instance jwt token")
}
@ -852,7 +851,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error
Image: pool.Image,
ExtraSpecs: pool.ExtraSpecs,
PoolID: instance.PoolID,
CACertBundle: r.credsDetails.CABundle,
CACertBundle: r.entity.Credentials.CABundle,
GitHubRunnerGroup: instance.GitHubRunnerGroup,
JitConfigEnabled: hasJITConfig,
}
@ -954,7 +953,7 @@ func (r *basePoolManager) poolLabel(poolID string) string {
}
func (r *basePoolManager) controllerLabel() string {
return fmt.Sprintf("%s%s", controllerLabelPrefix, r.controllerID)
return fmt.Sprintf("%s%s", controllerLabelPrefix, r.controllerInfo.ControllerID.String())
}
func (r *basePoolManager) updateArgsFromProviderInstance(providerInstance commonParams.ProviderInstance) params.UpdateInstanceParams {
@ -1525,6 +1524,7 @@ func (r *basePoolManager) Start() error {
initialToolUpdate <- struct{}{}
}()
go r.runWatcher()
go func() {
select {
case <-r.quit:
@ -1552,37 +1552,6 @@ func (r *basePoolManager) Stop() error {
return nil
}
func (r *basePoolManager) RefreshState(param params.UpdatePoolStateParams) error {
r.mux.Lock()
if param.WebhookSecret != "" {
r.entity.WebhookSecret = param.WebhookSecret
}
if param.InternalConfig != nil {
r.cfgInternal = *param.InternalConfig
r.urls = urls{
webhookURL: r.cfgInternal.BaseWebhookURL,
callbackURL: r.cfgInternal.InstanceCallbackURL,
metadataURL: r.cfgInternal.InstanceMetadataURL,
controllerWebhookURL: r.cfgInternal.ControllerWebhookURL,
}
}
ghc, err := garmUtil.GithubClient(r.ctx, r.entity, r.cfgInternal.GithubCredentialsDetails)
if err != nil {
return errors.Wrap(err, "getting github client")
}
r.ghcli = ghc
r.mux.Unlock()
// Update the tools as soon as state is updated. This should revive a stopped pool manager
// or stop one if the supplied credentials are not okay.
if err := r.updateTools(); err != nil {
return fmt.Errorf("failed to update tools: %w", err)
}
return nil
}
func (r *basePoolManager) WebhookSecret() string {
return r.entity.WebhookSecret
}
@ -1688,7 +1657,7 @@ func (r *basePoolManager) consumeQueuedJobs() error {
}
poolsCache := poolsForTags{
poolCacheType: r.PoolBalancerType(),
poolCacheType: r.entity.GetPoolBalancerType(),
}
slog.DebugContext(
@ -1812,7 +1781,7 @@ func (r *basePoolManager) consumeQueuedJobs() error {
}
func (r *basePoolManager) UninstallWebhook(ctx context.Context) error {
if r.urls.controllerWebhookURL == "" {
if r.controllerInfo.ControllerWebhookURL == "" {
return errors.Wrap(runnerErrors.ErrBadRequest, "controller webhook url is empty")
}
@ -1823,8 +1792,8 @@ func (r *basePoolManager) UninstallWebhook(ctx context.Context) error {
var controllerHookID int64
var baseHook string
trimmedBase := strings.TrimRight(r.urls.webhookURL, "/")
trimmedController := strings.TrimRight(r.urls.controllerWebhookURL, "/")
trimmedBase := strings.TrimRight(r.controllerInfo.WebhookURL, "/")
trimmedController := strings.TrimRight(r.controllerInfo.ControllerWebhookURL, "/")
for _, hook := range allHooks {
hookInfo := hookToParamsHookInfo(hook)
@ -1859,7 +1828,7 @@ func (r *basePoolManager) InstallHook(ctx context.Context, req *github.Hook) (pa
return params.HookInfo{}, errors.Wrap(err, "listing hooks")
}
if err := validateHookRequest(r.cfgInternal.ControllerID, r.cfgInternal.BaseWebhookURL, allHooks, req); err != nil {
if err := validateHookRequest(r.controllerInfo.ControllerID.String(), r.controllerInfo.WebhookURL, allHooks, req); err != nil {
return params.HookInfo{}, errors.Wrap(err, "validating hook request")
}
@ -1879,7 +1848,7 @@ func (r *basePoolManager) InstallHook(ctx context.Context, req *github.Hook) (pa
}
func (r *basePoolManager) InstallWebhook(ctx context.Context, param params.InstallWebhookParams) (params.HookInfo, error) {
if r.urls.controllerWebhookURL == "" {
if r.controllerInfo.ControllerWebhookURL == "" {
return params.HookInfo{}, errors.Wrap(runnerErrors.ErrBadRequest, "controller webhook url is empty")
}
@ -1890,7 +1859,7 @@ func (r *basePoolManager) InstallWebhook(ctx context.Context, param params.Insta
req := &github.Hook{
Active: github.Bool(true),
Config: map[string]interface{}{
"url": r.urls.controllerWebhookURL,
"url": r.controllerInfo.ControllerWebhookURL,
"content_type": "json",
"insecure_ssl": insecureSSL,
"secret": r.WebhookSecret(),
@ -1978,21 +1947,14 @@ func (r *basePoolManager) GetGithubRunners() ([]*github.Runner, error) {
return allRunners, nil
}
func (r *basePoolManager) PoolBalancerType() params.PoolBalancerType {
if r.cfgInternal.PoolBalancerType == "" {
return params.PoolBalancerTypeRoundRobin
}
return r.cfgInternal.PoolBalancerType
}
func (r *basePoolManager) GithubURL() string {
switch r.entity.EntityType {
case params.GithubEntityTypeRepository:
return fmt.Sprintf("%s/%s/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner, r.entity.Name)
return fmt.Sprintf("%s/%s/%s", r.entity.Credentials.BaseURL, r.entity.Owner, r.entity.Name)
case params.GithubEntityTypeOrganization:
return fmt.Sprintf("%s/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner)
return fmt.Sprintf("%s/%s", r.entity.Credentials.BaseURL, r.entity.Owner)
case params.GithubEntityTypeEnterprise:
return fmt.Sprintf("%s/enterprises/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner)
return fmt.Sprintf("%s/enterprises/%s", r.entity.Credentials.BaseURL, r.entity.Owner)
}
return ""
}
@ -2002,8 +1964,8 @@ func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo,
if err != nil {
return params.HookInfo{}, errors.Wrap(err, "listing hooks")
}
trimmedBase := strings.TrimRight(r.urls.webhookURL, "/")
trimmedController := strings.TrimRight(r.urls.controllerWebhookURL, "/")
trimmedBase := strings.TrimRight(r.controllerInfo.WebhookURL, "/")
trimmedController := strings.TrimRight(r.controllerInfo.ControllerWebhookURL, "/")
var controllerHookInfo *params.HookInfo
var baseHookInfo *params.HookInfo
@ -2034,5 +1996,5 @@ func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo,
}
func (r *basePoolManager) RootCABundle() (params.CertificateBundle, error) {
return r.credsDetails.RootCertificateBundle()
return r.entity.Credentials.RootCertificateBundle()
}

View file

@ -0,0 +1,57 @@
package pool
import (
"context"
"github.com/google/go-github/v57/github"
"github.com/cloudbase/garm/params"
)
type stubGithubClient struct {
err error
}
func (s *stubGithubClient) ListEntityHooks(_ context.Context, _ *github.ListOptions) ([]*github.Hook, *github.Response, error) {
return nil, nil, s.err
}
func (s *stubGithubClient) GetEntityHook(_ context.Context, _ int64) (*github.Hook, error) {
return nil, s.err
}
func (s *stubGithubClient) CreateEntityHook(_ context.Context, _ *github.Hook) (*github.Hook, error) {
return nil, s.err
}
func (s *stubGithubClient) DeleteEntityHook(_ context.Context, _ int64) (*github.Response, error) {
return nil, s.err
}
func (s *stubGithubClient) PingEntityHook(_ context.Context, _ int64) (*github.Response, error) {
return nil, s.err
}
func (s *stubGithubClient) ListEntityRunners(_ context.Context, _ *github.ListOptions) (*github.Runners, *github.Response, error) {
return nil, nil, s.err
}
func (s *stubGithubClient) ListEntityRunnerApplicationDownloads(_ context.Context) ([]*github.RunnerApplicationDownload, *github.Response, error) {
return nil, nil, s.err
}
func (s *stubGithubClient) RemoveEntityRunner(_ context.Context, _ int64) (*github.Response, error) {
return nil, s.err
}
func (s *stubGithubClient) CreateEntityRegistrationToken(_ context.Context) (*github.RegistrationToken, *github.Response, error) {
return nil, nil, s.err
}
func (s *stubGithubClient) GetEntityJITConfig(_ context.Context, _ string, _ params.Pool, _ []string) (map[string]string, *github.Runner, error) {
return nil, nil, s.err
}
func (s *stubGithubClient) GetWorkflowJobByID(_ context.Context, _, _ string, _ int64) (*github.WorkflowJob, *github.Response, error) {
return nil, nil, s.err
}

View file

@ -10,6 +10,8 @@ import (
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonParams "github.com/cloudbase/garm-provider-common/params"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
)
@ -116,3 +118,19 @@ func isManagedRunner(labels []string, controllerID string) bool {
runnerControllerID := controllerIDFromLabels(labels)
return runnerControllerID == controllerID
}
func composeWatcherFilters(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
// We want to watch for changes in either the controller or the
// entity itself.
return watcher.WithAny(
watcher.WithAll(
// Updates to the controller
watcher.WithEntityTypeFilter(dbCommon.ControllerEntityType),
watcher.WithOperationTypeFilter(dbCommon.UpdateOperation),
),
// Any operation on the entity we're managing the pool for.
watcher.WithEntityFilter(entity),
// Watch for changes to the github credentials
watcher.WithGithubCredentialsFilter(entity.Credentials),
)
}

154
runner/pool/watcher.go Normal file
View file

@ -0,0 +1,154 @@
package pool
import (
"log/slog"
"github.com/pkg/errors"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/params"
runnerCommon "github.com/cloudbase/garm/runner/common"
garmUtil "github.com/cloudbase/garm/util"
)
// entityGetter is implemented by all github entities (repositories, organizations and enterprises)
type entityGetter interface {
GetEntity() (params.GithubEntity, error)
}
func (r *basePoolManager) handleControllerUpdateEvent(controllerInfo params.ControllerInfo) {
r.mux.Lock()
defer r.mux.Unlock()
slog.DebugContext(r.ctx, "updating controller info", "controller_info", controllerInfo)
r.controllerInfo = controllerInfo
}
func (r *basePoolManager) getClientOrStub() runnerCommon.GithubClient {
var err error
var ghc runnerCommon.GithubClient
ghc, err = garmUtil.GithubClient(r.ctx, r.entity, r.entity.Credentials)
if err != nil {
slog.WarnContext(r.ctx, "failed to create github client", "error", err)
ghc = &stubGithubClient{
err: errors.Wrapf(runnerErrors.ErrUnauthorized, "failed to create github client; please update credentials: %v", err),
}
}
return ghc
}
func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity) {
slog.DebugContext(r.ctx, "received entity update", "entity", entity.ID)
credentialsUpdate := r.entity.Credentials.ID != entity.Credentials.ID
defer func() {
slog.DebugContext(r.ctx, "deferred tools update", "credentials_update", credentialsUpdate)
if !credentialsUpdate {
return
}
slog.DebugContext(r.ctx, "updating tools", "entity", entity.ID)
if err := r.updateTools(); err != nil {
slog.ErrorContext(r.ctx, "failed to update tools", "error", err)
}
}()
slog.DebugContext(r.ctx, "updating entity", "entity", entity.ID)
r.mux.Lock()
slog.DebugContext(r.ctx, "lock acquired", "entity", entity.ID)
r.entity = entity
if credentialsUpdate {
if r.consumer != nil {
filters := composeWatcherFilters(r.entity)
r.consumer.SetFilters(filters)
}
slog.DebugContext(r.ctx, "credentials update", "entity", entity.ID)
r.ghcli = r.getClientOrStub()
}
r.mux.Unlock()
slog.DebugContext(r.ctx, "lock released", "entity", entity.ID)
}
func (r *basePoolManager) handleCredentialsUpdate(credentials params.GithubCredentials) {
// when we switch credentials on an entity (like from one app to another or from an app
// to a PAT), we may still get events for the previous credentials as the channel is buffered.
// The watcher will watch for changes to the entity itself, which includes events that
// change the credentials name on the entity, but we also watch for changes to the credentials
// themselves, like an updated PAT token set on existing credentials entity.
// The handleCredentialsUpdate function handles situations where we have changes on the
// credentials entity itself, not on the entity that the credentials are set on.
// For example, we may have a credentials entity called org_pat set on a repo called
// test-repo. This function would handle situations where "org_pat" is updated.
// If "test-repo" is updated with new credentials, that event is handled above in
// handleEntityUpdate.
shouldUpdateTools := r.entity.Credentials.ID == credentials.ID
defer func() {
if !shouldUpdateTools {
return
}
slog.DebugContext(r.ctx, "deferred tools update", "credentials_id", credentials.ID)
if err := r.updateTools(); err != nil {
slog.ErrorContext(r.ctx, "failed to update tools", "error", err)
}
}()
r.mux.Lock()
if !shouldUpdateTools {
slog.InfoContext(r.ctx, "credential ID mismatch; stale event?", "credentials_id", credentials.ID)
r.mux.Unlock()
return
}
slog.DebugContext(r.ctx, "updating credentials", "credentials_id", credentials.ID)
r.entity.Credentials = credentials
r.ghcli = r.getClientOrStub()
r.mux.Unlock()
}
func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) {
dbEntityType := common.DatabaseEntityType(r.entity.EntityType)
switch event.EntityType {
case common.GithubCredentialsEntityType:
credentials, ok := event.Payload.(params.GithubCredentials)
if !ok {
slog.ErrorContext(r.ctx, "failed to cast payload to github credentials")
return
}
r.handleCredentialsUpdate(credentials)
case common.ControllerEntityType:
controllerInfo, ok := event.Payload.(params.ControllerInfo)
if !ok {
slog.ErrorContext(r.ctx, "failed to cast payload to controller info")
return
}
r.handleControllerUpdateEvent(controllerInfo)
case dbEntityType:
entity, ok := event.Payload.(entityGetter)
if !ok {
slog.ErrorContext(r.ctx, "failed to cast payload to entity")
return
}
entityInfo, err := entity.GetEntity()
if err != nil {
slog.ErrorContext(r.ctx, "failed to get entity", "error", err)
return
}
r.handleEntityUpdate(entityInfo)
}
}
func (r *basePoolManager) runWatcher() {
for {
select {
case <-r.quit:
return
case <-r.ctx.Done():
return
case event, ok := <-r.consumer.Watch():
if !ok {
return
}
go r.handleWatcherEvent(event)
}
}
}

View file

@ -197,16 +197,15 @@ func (r *Runner) UpdateRepository(ctx context.Context, repoID string, param para
return params.Repository{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType)
}
slog.InfoContext(ctx, "updating repository", "repo_id", repoID, "param", param)
repo, err := r.store.UpdateRepository(ctx, repoID, param)
if err != nil {
return params.Repository{}, errors.Wrap(err, "updating repo")
}
// Use the admin context in the pool manager. Any access control is already done above when
// updating the store.
poolMgr, err := r.poolManagerCtrl.UpdateRepoPoolManager(r.ctx, repo)
poolMgr, err := r.poolManagerCtrl.GetRepoPoolManager(repo)
if err != nil {
return params.Repository{}, fmt.Errorf("failed to update pool manager: %w", err)
return params.Repository{}, fmt.Errorf("failed to get pool manager: %w", err)
}
repo.PoolManagerStatus = poolMgr.Status()

View file

@ -35,21 +35,20 @@ import (
)
type RepoTestFixtures struct {
AdminContext context.Context
Store dbCommon.Store
StoreRepos map[string]params.Repository
Providers map[string]common.Provider
Credentials map[string]params.GithubCredentials
CreateRepoParams params.CreateRepoParams
CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
UpdatePoolStateParams params.UpdatePoolStateParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
PoolMgrCtrlMock *runnerMocks.PoolManagerController
AdminContext context.Context
Store dbCommon.Store
StoreRepos map[string]params.Repository
Providers map[string]common.Provider
Credentials map[string]params.GithubCredentials
CreateRepoParams params.CreateRepoParams
CreatePoolParams params.CreatePoolParams
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
PoolMgrCtrlMock *runnerMocks.PoolManagerController
}
func init() {
@ -143,9 +142,6 @@ func (s *RepoTestSuite) SetupTest() {
Image: "test-images-updated",
Flavor: "test-flavor-updated",
},
UpdatePoolStateParams: params.UpdatePoolStateParams{
WebhookSecret: "test-update-repo-webhook-secret",
},
ErrMock: fmt.Errorf("mock error"),
ProviderMock: providerMock,
PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()),
@ -327,7 +323,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolMgrFailed() {
}
func (s *RepoTestSuite) TestUpdateRepository() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)
repo, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams)
@ -341,7 +337,7 @@ func (s *RepoTestSuite) TestUpdateRepository() {
}
func (s *RepoTestSuite) TestUpdateRepositoryBalancingType() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)
updateRepoParams := s.Fixtures.UpdateRepoParams
@ -372,21 +368,21 @@ func (s *RepoTestSuite) TestUpdateRepositoryInvalidCreds() {
}
func (s *RepoTestSuite) TestUpdateRepositoryPoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *RepoTestSuite) TestUpdateRepositoryCreateRepoPoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
_, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams)
s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}
func (s *RepoTestSuite) TestCreateRepoPool() {

View file

@ -100,23 +100,16 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params
p.mux.Lock()
defer p.mux.Unlock()
creds, err := p.store.GetGithubCredentials(ctx, repo.CredentialsID, true)
entity, err := repo.GetEntity()
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
return nil, errors.Wrap(err, "getting entity")
}
cfgInternal, err := p.getInternalConfig(ctx, creds, repo.GetBalancerType())
instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret)
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
return nil, errors.Wrap(err, "creating instance token getter")
}
entity := params.GithubEntity{
Owner: repo.Owner,
Name: repo.Name,
ID: repo.ID,
WebhookSecret: repo.WebhookSecret,
EntityType: params.GithubEntityTypeRepository,
}
poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store)
poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store)
if err != nil {
return nil, errors.Wrap(err, "creating repo pool manager")
}
@ -124,36 +117,6 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params
return poolManager, nil
}
func (p *poolManagerCtrl) UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error) {
p.mux.Lock()
defer p.mux.Unlock()
poolMgr, ok := p.repositories[repo.ID]
if !ok {
return nil, errors.Wrapf(runnerErrors.ErrNotFound, "repository %s/%s pool manager not loaded", repo.Owner, repo.Name)
}
creds, err := p.store.GetGithubCredentials(ctx, repo.CredentialsID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
}
internalCfg, err := p.getInternalConfig(ctx, creds, repo.GetBalancerType())
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
}
newState := params.UpdatePoolStateParams{
WebhookSecret: repo.WebhookSecret,
InternalConfig: &internalCfg,
}
if err := poolMgr.RefreshState(newState); err != nil {
return nil, errors.Wrap(err, "updating repo pool manager")
}
return poolMgr, nil
}
func (p *poolManagerCtrl) GetRepoPoolManager(repo params.Repository) (common.PoolManager, error) {
if repoPoolMgr, ok := p.repositories[repo.ID]; ok {
return repoPoolMgr, nil
@ -183,21 +146,16 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O
p.mux.Lock()
defer p.mux.Unlock()
creds, err := p.store.GetGithubCredentials(ctx, org.CredentialsID, true)
entity, err := org.GetEntity()
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
return nil, errors.Wrap(err, "getting entity")
}
cfgInternal, err := p.getInternalConfig(ctx, creds, org.GetBalancerType())
instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret)
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
return nil, errors.Wrap(err, "creating instance token getter")
}
entity := params.GithubEntity{
Owner: org.Name,
ID: org.ID,
WebhookSecret: org.WebhookSecret,
EntityType: params.GithubEntityTypeOrganization,
}
poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store)
poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store)
if err != nil {
return nil, errors.Wrap(err, "creating org pool manager")
}
@ -205,35 +163,6 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O
return poolManager, nil
}
func (p *poolManagerCtrl) UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error) {
p.mux.Lock()
defer p.mux.Unlock()
poolMgr, ok := p.organizations[org.ID]
if !ok {
return nil, errors.Wrapf(runnerErrors.ErrNotFound, "org %s pool manager not loaded", org.Name)
}
creds, err := p.store.GetGithubCredentials(ctx, org.CredentialsID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
}
internalCfg, err := p.getInternalConfig(ctx, creds, org.GetBalancerType())
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
}
newState := params.UpdatePoolStateParams{
WebhookSecret: org.WebhookSecret,
InternalConfig: &internalCfg,
}
if err := poolMgr.RefreshState(newState); err != nil {
return nil, errors.Wrap(err, "updating repo pool manager")
}
return poolMgr, nil
}
func (p *poolManagerCtrl) GetOrgPoolManager(org params.Organization) (common.PoolManager, error) {
if orgPoolMgr, ok := p.organizations[org.ID]; ok {
return orgPoolMgr, nil
@ -263,22 +192,16 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter
p.mux.Lock()
defer p.mux.Unlock()
creds, err := p.store.GetGithubCredentials(ctx, enterprise.CredentialsID, true)
entity, err := enterprise.GetEntity()
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
}
cfgInternal, err := p.getInternalConfig(ctx, creds, enterprise.GetBalancerType())
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
return nil, errors.Wrap(err, "getting entity")
}
entity := params.GithubEntity{
Owner: enterprise.Name,
ID: enterprise.ID,
WebhookSecret: enterprise.WebhookSecret,
EntityType: params.GithubEntityTypeEnterprise,
instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret)
if err != nil {
return nil, errors.Wrap(err, "creating instance token getter")
}
poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store)
poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store)
if err != nil {
return nil, errors.Wrap(err, "creating enterprise pool manager")
}
@ -286,35 +209,6 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter
return poolManager, nil
}
func (p *poolManagerCtrl) UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error) {
p.mux.Lock()
defer p.mux.Unlock()
poolMgr, ok := p.enterprises[enterprise.ID]
if !ok {
return nil, errors.Wrapf(runnerErrors.ErrNotFound, "enterprise %s pool manager not loaded", enterprise.Name)
}
creds, err := p.store.GetGithubCredentials(ctx, enterprise.CredentialsID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching credentials")
}
internalCfg, err := p.getInternalConfig(ctx, creds, enterprise.GetBalancerType())
if err != nil {
return nil, errors.Wrap(err, "fetching internal config")
}
newState := params.UpdatePoolStateParams{
WebhookSecret: enterprise.WebhookSecret,
InternalConfig: &internalCfg,
}
if err := poolMgr.RefreshState(newState); err != nil {
return nil, errors.Wrap(err, "updating repo pool manager")
}
return poolMgr, nil
}
func (p *poolManagerCtrl) GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error) {
if enterprisePoolMgr, ok := p.enterprises[enterprise.ID]; ok {
return enterprisePoolMgr, nil
@ -340,24 +234,6 @@ func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolMan
return p.enterprises, nil
}
func (p *poolManagerCtrl) getInternalConfig(_ context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) {
controllerInfo, err := p.store.ControllerInfo()
if err != nil {
return params.Internal{}, errors.Wrap(err, "fetching controller info")
}
return params.Internal{
ControllerID: controllerInfo.ControllerID.String(),
InstanceCallbackURL: controllerInfo.CallbackURL,
InstanceMetadataURL: controllerInfo.MetadataURL,
BaseWebhookURL: controllerInfo.WebhookURL,
ControllerWebhookURL: controllerInfo.ControllerWebhookURL,
JWTSecret: p.config.JWTAuth.Secret,
PoolBalancerType: poolBalancerType,
GithubCredentialsDetails: creds,
}, nil
}
type Runner struct {
mux sync.Mutex