From 3e60a48ca8e8b4bad045d6a224e6d121b0811a53 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Wed, 17 Apr 2024 08:05:06 +0000 Subject: [PATCH] Preload credentials endpoint and remove extra code Signed-off-by: Gabriel Adrian Samfira --- Makefile | 2 +- database/sql/enterprise.go | 3 ++- database/sql/github.go | 14 ++++++++++++++ database/sql/organizations.go | 3 ++- database/sql/repositories.go | 3 ++- params/params.go | 9 ++++----- runner/runner.go | 31 ++++++++++--------------------- util/util.go | 10 ++++++---- 8 files changed, 41 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index ce21ce4b..3dcab4a9 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ export SHELLOPTS:=$(if $(SHELLOPTS),$(SHELLOPTS):)pipefail:errexit .ONESHELL: -GEN_PASSWORD=$(shell (apg -n1 -m32)) +GEN_PASSWORD=$(shell (/usr/bin/apg -n1 -m32)) IMAGE_TAG = garm-build USER_ID=$(shell ((docker --version | grep -q podman) && echo "0" || id -u)) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index e8efcf8b..9fa3c73c 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -96,7 +96,7 @@ func (s *sqlDatabase) GetEnterpriseByID(ctx context.Context, enterpriseID string func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, error) { var enterprises []Enterprise - q := s.conn.Preload("Credentials").Find(&enterprises) + q := s.conn.Preload("Credentials").Preload("Credentials.Endpoint").Find(&enterprises) if q.Error != nil { return []params.Enterprise{}, errors.Wrap(q.Error, "fetching enterprises") } @@ -183,6 +183,7 @@ func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, q := s.conn.Where("name = ? COLLATE NOCASE", name). Preload("Credentials"). + Preload("Credentials.Endpoint"). Preload("Endpoint"). First(&enterprise) if q.Error != nil { diff --git a/database/sql/github.go b/database/sql/github.go index 300c0ef4..08e22f62 100644 --- a/database/sql/github.go +++ b/database/sql/github.go @@ -1,3 +1,17 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + package sql import ( diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 0019ad39..419f81b5 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -87,7 +87,7 @@ func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params. func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organization, error) { var orgs []Organization - q := s.conn.Preload("Credentials").Find(&orgs) + q := s.conn.Preload("Credentials").Preload("Credentials.Endpoint").Find(&orgs) if q.Error != nil { return []params.Organization{}, errors.Wrap(q.Error, "fetching org from database") } @@ -213,6 +213,7 @@ func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, erro q := s.conn.Where("name = ? COLLATE NOCASE", name). Preload("Credentials"). + Preload("Credentials.Endpoint"). Preload("Endpoint"). First(&org) if q.Error != nil { diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 18284e1b..26e0d0bc 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -85,7 +85,7 @@ func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (pa func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, error) { var repos []Repository - q := s.conn.Preload("Credentials").Find(&repos) + q := s.conn.Preload("Credentials").Preload("Credentials.Endpoint").Find(&repos) if q.Error != nil { return []params.Repository{}, errors.Wrap(q.Error, "fetching user from database") } @@ -186,6 +186,7 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository q := s.conn.Where("name = ? COLLATE NOCASE and owner = ? COLLATE NOCASE", name, owner). Preload("Credentials"). + Preload("Credentials.Endpoint"). Preload("Endpoint"). First(&repo) diff --git a/params/params.go b/params/params.go index e6bf8fb6..8409855d 100644 --- a/params/params.go +++ b/params/params.go @@ -566,8 +566,7 @@ type GithubCredentials struct { Enterprises []Enterprise `json:"enterprises,omitempty"` Endpoint string `json:"endpoint"` - CredentialsPayload []byte `json:"-"` - HTTPClient *http.Client `json:"-"` + CredentialsPayload []byte `json:"-"` } func (g GithubCredentials) GetHTTPClient(ctx context.Context) (*http.Client, error) { @@ -579,11 +578,11 @@ func (g GithubCredentials) GetHTTPClient(ctx context.Context) (*http.Client, err return nil, fmt.Errorf("failed to parse CA cert") } } - // nolint:golangci-lint,gosec,godox - // TODO: set TLS MinVersion + httpTransport := &http.Transport{ TLSClientConfig: &tls.Config{ - RootCAs: roots, + RootCAs: roots, + MinVersion: tls.VersionTLS12, }, } diff --git a/runner/runner.go b/runner/runner.go index 78e298d3..bc2c3676 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -316,32 +316,21 @@ func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolMan return p.enterprises, nil } -func (p *poolManagerCtrl) getInternalConfig(ctx context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) { +func (p *poolManagerCtrl) getInternalConfig(_ context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) { var controllerWebhookURL string if p.config.Default.WebhookURL != "" { controllerWebhookURL = fmt.Sprintf("%s/%s", p.config.Default.WebhookURL, p.controllerID) } - httpClient, err := creds.GetHTTPClient(ctx) - if err != nil { - return params.Internal{}, fmt.Errorf("fetching http client for creds: %w", err) - } + return params.Internal{ - ControllerID: p.controllerID, - InstanceCallbackURL: p.config.Default.CallbackURL, - InstanceMetadataURL: p.config.Default.MetadataURL, - BaseWebhookURL: p.config.Default.WebhookURL, - ControllerWebhookURL: controllerWebhookURL, - JWTSecret: p.config.JWTAuth.Secret, - PoolBalancerType: poolBalancerType, - GithubCredentialsDetails: params.GithubCredentials{ - Name: creds.Name, - Description: creds.Description, - BaseURL: creds.BaseURL, - APIBaseURL: creds.APIBaseURL, - UploadBaseURL: creds.UploadBaseURL, - CABundle: creds.CABundle, - HTTPClient: httpClient, - }, + ControllerID: p.controllerID, + InstanceCallbackURL: p.config.Default.CallbackURL, + InstanceMetadataURL: p.config.Default.MetadataURL, + BaseWebhookURL: p.config.Default.WebhookURL, + ControllerWebhookURL: controllerWebhookURL, + JWTSecret: p.config.JWTAuth.Secret, + PoolBalancerType: poolBalancerType, + GithubCredentialsDetails: creds, }, nil } diff --git a/util/util.go b/util/util.go index a75be33f..56f2150a 100644 --- a/util/util.go +++ b/util/util.go @@ -435,11 +435,13 @@ func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, return jitConfig, ret.Runner, nil } -func GithubClient(_ context.Context, entity params.GithubEntity, credsDetails params.GithubCredentials) (common.GithubClient, error) { - if credsDetails.HTTPClient == nil { - return nil, errors.New("http client is nil") +func GithubClient(ctx context.Context, entity params.GithubEntity, credsDetails params.GithubCredentials) (common.GithubClient, error) { + httpClient, err := credsDetails.GetHTTPClient(ctx) + if err != nil { + return nil, errors.Wrap(err, "fetching http client") } - ghClient, err := github.NewClient(credsDetails.HTTPClient).WithEnterpriseURLs(credsDetails.APIBaseURL, credsDetails.UploadBaseURL) + + ghClient, err := github.NewClient(httpClient).WithEnterpriseURLs(credsDetails.APIBaseURL, credsDetails.UploadBaseURL) if err != nil { return nil, errors.Wrap(err, "fetching github client") }