diff --git a/database/watcher/watcher.go b/database/watcher/watcher.go index ec81d5bd..2ef1aeee 100644 --- a/database/watcher/watcher.go +++ b/database/watcher/watcher.go @@ -17,7 +17,7 @@ func InitWatcher(ctx context.Context) { if databaseWatcher != nil { return } - ctx = garmUtil.WithContext(ctx, slog.Any("watcher", "database")) + ctx = garmUtil.WithSlogContext(ctx, slog.Any("watcher", "database")) w := &watcher{ producers: make(map[string]*producer), consumers: make(map[string]*consumer), @@ -33,7 +33,7 @@ func RegisterProducer(ctx context.Context, id string) (common.Producer, error) { if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - ctx = garmUtil.WithContext(ctx, slog.Any("producer_id", id)) + ctx = garmUtil.WithSlogContext(ctx, slog.Any("producer_id", id)) return databaseWatcher.RegisterProducer(ctx, id) } @@ -41,7 +41,7 @@ func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadF if databaseWatcher == nil { return nil, common.ErrWatcherNotInitialized } - ctx = garmUtil.WithContext(ctx, slog.Any("consumer_id", id)) + ctx = garmUtil.WithSlogContext(ctx, slog.Any("consumer_id", id)) return databaseWatcher.RegisterConsumer(ctx, id, filters...) } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 7e2a6080..9e86c415 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -41,6 +41,7 @@ import ( "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" garmUtil "github.com/cloudbase/garm/util" + ghClient "github.com/cloudbase/garm/util/github" ) var ( @@ -65,8 +66,8 @@ const ( ) func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, instanceTokenGetter auth.InstanceTokenGetter, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) { - ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType)) - ghc, err := garmUtil.GithubClient(ctx, entity, entity.Credentials) + ctx = garmUtil.WithSlogContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType)) + ghc, err := ghClient.GithubClient(ctx, entity) if err != nil { return nil, errors.Wrap(err, "getting github client") } diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go index b17494d5..29950748 100644 --- a/runner/pool/watcher.go +++ b/runner/pool/watcher.go @@ -9,7 +9,7 @@ import ( "github.com/cloudbase/garm/database/common" "github.com/cloudbase/garm/params" runnerCommon "github.com/cloudbase/garm/runner/common" - garmUtil "github.com/cloudbase/garm/util" + ghClient "github.com/cloudbase/garm/util/github" ) // entityGetter is implemented by all github entities (repositories, organizations and enterprises) @@ -28,7 +28,7 @@ func (r *basePoolManager) handleControllerUpdateEvent(controllerInfo params.Cont func (r *basePoolManager) getClientOrStub() runnerCommon.GithubClient { var err error var ghc runnerCommon.GithubClient - ghc, err = garmUtil.GithubClient(r.ctx, r.entity, r.entity.Credentials) + ghc, err = ghClient.GithubClient(r.ctx, r.entity) if err != nil { slog.WarnContext(r.ctx, "failed to create github client", "error", err) ghc = &stubGithubClient{ diff --git a/util/github/client.go b/util/github/client.go new file mode 100644 index 00000000..800c5b00 --- /dev/null +++ b/util/github/client.go @@ -0,0 +1,471 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + + "github.com/google/go-github/v57/github" + "github.com/pkg/errors" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/metrics" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/runner/common" +) + +type githubClient struct { + *github.ActionsService + org *github.OrganizationsService + repo *github.RepositoriesService + enterprise *github.EnterpriseService + + entity params.GithubEntity + cli *github.Client +} + +func (g *githubClient) ListEntityHooks(ctx context.Context, opts *github.ListOptions) (ret []*github.Hook, response *github.Response, err error) { + metrics.GithubOperationCount.WithLabelValues( + "ListHooks", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "ListHooks", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, response, err = g.repo.ListHooks(ctx, g.entity.Owner, g.entity.Name, opts) + case params.GithubEntityTypeOrganization: + ret, response, err = g.org.ListHooks(ctx, g.entity.Owner, opts) + default: + return nil, nil, fmt.Errorf("invalid entity type: %s", g.entity.EntityType) + } + return ret, response, err +} + +func (g *githubClient) GetEntityHook(ctx context.Context, id int64) (ret *github.Hook, err error) { + metrics.GithubOperationCount.WithLabelValues( + "GetHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "GetHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, _, err = g.repo.GetHook(ctx, g.entity.Owner, g.entity.Name, id) + case params.GithubEntityTypeOrganization: + ret, _, err = g.org.GetHook(ctx, g.entity.Owner, id) + default: + return nil, errors.New("invalid entity type") + } + return ret, err +} + +func (g *githubClient) CreateEntityHook(ctx context.Context, hook *github.Hook) (ret *github.Hook, err error) { + metrics.GithubOperationCount.WithLabelValues( + "CreateHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "CreateHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, _, err = g.repo.CreateHook(ctx, g.entity.Owner, g.entity.Name, hook) + case params.GithubEntityTypeOrganization: + ret, _, err = g.org.CreateHook(ctx, g.entity.Owner, hook) + default: + return nil, errors.New("invalid entity type") + } + return ret, err +} + +func (g *githubClient) DeleteEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) { + metrics.GithubOperationCount.WithLabelValues( + "DeleteHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "DeleteHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, err = g.repo.DeleteHook(ctx, g.entity.Owner, g.entity.Name, id) + case params.GithubEntityTypeOrganization: + ret, err = g.org.DeleteHook(ctx, g.entity.Owner, id) + default: + return nil, errors.New("invalid entity type") + } + return ret, err +} + +func (g *githubClient) PingEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) { + metrics.GithubOperationCount.WithLabelValues( + "PingHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "PingHook", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, err = g.repo.PingHook(ctx, g.entity.Owner, g.entity.Name, id) + case params.GithubEntityTypeOrganization: + ret, err = g.org.PingHook(ctx, g.entity.Owner, id) + default: + return nil, errors.New("invalid entity type") + } + return ret, err +} + +func (g *githubClient) ListEntityRunners(ctx context.Context, opts *github.ListOptions) (*github.Runners, *github.Response, error) { + var ret *github.Runners + var response *github.Response + var err error + + metrics.GithubOperationCount.WithLabelValues( + "ListEntityRunners", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "ListEntityRunners", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, response, err = g.ListRunners(ctx, g.entity.Owner, g.entity.Name, opts) + case params.GithubEntityTypeOrganization: + ret, response, err = g.ListOrganizationRunners(ctx, g.entity.Owner, opts) + case params.GithubEntityTypeEnterprise: + ret, response, err = g.enterprise.ListRunners(ctx, g.entity.Owner, opts) + default: + return nil, nil, errors.New("invalid entity type") + } + + return ret, response, err +} + +func (g *githubClient) ListEntityRunnerApplicationDownloads(ctx context.Context) ([]*github.RunnerApplicationDownload, *github.Response, error) { + var ret []*github.RunnerApplicationDownload + var response *github.Response + var err error + + metrics.GithubOperationCount.WithLabelValues( + "ListEntityRunnerApplicationDownloads", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "ListEntityRunnerApplicationDownloads", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, response, err = g.ListRunnerApplicationDownloads(ctx, g.entity.Owner, g.entity.Name) + case params.GithubEntityTypeOrganization: + ret, response, err = g.ListOrganizationRunnerApplicationDownloads(ctx, g.entity.Owner) + case params.GithubEntityTypeEnterprise: + ret, response, err = g.enterprise.ListRunnerApplicationDownloads(ctx, g.entity.Owner) + default: + return nil, nil, errors.New("invalid entity type") + } + + return ret, response, err +} + +func (g *githubClient) RemoveEntityRunner(ctx context.Context, runnerID int64) (*github.Response, error) { + var response *github.Response + var err error + + metrics.GithubOperationCount.WithLabelValues( + "RemoveEntityRunner", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "RemoveEntityRunner", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + response, err = g.RemoveRunner(ctx, g.entity.Owner, g.entity.Name, runnerID) + case params.GithubEntityTypeOrganization: + response, err = g.RemoveOrganizationRunner(ctx, g.entity.Owner, runnerID) + case params.GithubEntityTypeEnterprise: + response, err = g.enterprise.RemoveRunner(ctx, g.entity.Owner, runnerID) + default: + return nil, errors.New("invalid entity type") + } + + return response, err +} + +func (g *githubClient) CreateEntityRegistrationToken(ctx context.Context) (*github.RegistrationToken, *github.Response, error) { + var ret *github.RegistrationToken + var response *github.Response + var err error + + metrics.GithubOperationCount.WithLabelValues( + "CreateEntityRegistrationToken", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + defer func() { + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "CreateEntityRegistrationToken", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + } + }() + + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, response, err = g.CreateRegistrationToken(ctx, g.entity.Owner, g.entity.Name) + case params.GithubEntityTypeOrganization: + ret, response, err = g.CreateOrganizationRegistrationToken(ctx, g.entity.Owner) + case params.GithubEntityTypeEnterprise: + ret, response, err = g.enterprise.CreateRegistrationToken(ctx, g.entity.Owner) + default: + return nil, nil, errors.New("invalid entity type") + } + + return ret, response, err +} + +func (g *githubClient) getOrganizationRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) { + opts := github.ListOrgRunnerGroupOptions{ + ListOptions: github.ListOptions{ + PerPage: 100, + }, + } + + for { + metrics.GithubOperationCount.WithLabelValues( + "ListOrganizationRunnerGroups", // label: operation + entity.LabelScope(), // label: scope + ).Inc() + runnerGroups, ghResp, err := g.ListOrganizationRunnerGroups(ctx, entity.Owner, &opts) + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "ListOrganizationRunnerGroups", // label: operation + entity.LabelScope(), // label: scope + ).Inc() + if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { + return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners") + } + return 0, errors.Wrap(err, "fetching runners") + } + for _, runnerGroup := range runnerGroups.RunnerGroups { + if runnerGroup.Name != nil && *runnerGroup.Name == rgName { + return *runnerGroup.ID, nil + } + } + if ghResp.NextPage == 0 { + break + } + opts.Page = ghResp.NextPage + } + return 0, runnerErrors.NewNotFoundError("runner group %s not found", rgName) +} + +func (g *githubClient) getEnterpriseRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) { + opts := github.ListEnterpriseRunnerGroupOptions{ + ListOptions: github.ListOptions{ + PerPage: 100, + }, + } + + for { + metrics.GithubOperationCount.WithLabelValues( + "ListRunnerGroups", // label: operation + entity.LabelScope(), // label: scope + ).Inc() + runnerGroups, ghResp, err := g.enterprise.ListRunnerGroups(ctx, entity.Owner, &opts) + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "ListRunnerGroups", // label: operation + entity.LabelScope(), // label: scope + ).Inc() + if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { + return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners") + } + return 0, errors.Wrap(err, "fetching runners") + } + for _, runnerGroup := range runnerGroups.RunnerGroups { + if runnerGroup.Name != nil && *runnerGroup.Name == rgName { + return *runnerGroup.ID, nil + } + } + if ghResp.NextPage == 0 { + break + } + opts.Page = ghResp.NextPage + } + return 0, runnerErrors.NewNotFoundError("runner group not found") +} + +func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) { + // If no runner group is set, use the default runner group ID. This is also the default for + // repository level runners. + var rgID int64 = 1 + + if pool.GitHubRunnerGroup != "" { + switch g.entity.EntityType { + case params.GithubEntityTypeOrganization: + rgID, err = g.getOrganizationRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) + case params.GithubEntityTypeEnterprise: + rgID, err = g.getEnterpriseRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) + } + + if err != nil { + return nil, nil, fmt.Errorf("getting runner group ID: %w", err) + } + } + + req := github.GenerateJITConfigRequest{ + Name: instance, + RunnerGroupID: rgID, + Labels: labels, + // nolint:golangci-lint,godox + // TODO(gabriel-samfira): Should we make this configurable? + WorkFolder: github.String("_work"), + } + + metrics.GithubOperationCount.WithLabelValues( + "GetEntityJITConfig", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + + var ret *github.JITRunnerConfig + var response *github.Response + + switch g.entity.EntityType { + case params.GithubEntityTypeRepository: + ret, response, err = g.GenerateRepoJITConfig(ctx, g.entity.Owner, g.entity.Name, &req) + case params.GithubEntityTypeOrganization: + ret, response, err = g.GenerateOrgJITConfig(ctx, g.entity.Owner, &req) + case params.GithubEntityTypeEnterprise: + ret, response, err = g.enterprise.GenerateEnterpriseJITConfig(ctx, g.entity.Owner, &req) + } + if err != nil { + metrics.GithubOperationFailedCount.WithLabelValues( + "GetEntityJITConfig", // label: operation + g.entity.LabelScope(), // label: scope + ).Inc() + if response != nil && response.StatusCode == http.StatusUnauthorized { + return nil, nil, fmt.Errorf("failed to get JIT config: %w", err) + } + return nil, nil, fmt.Errorf("failed to get JIT config: %w", err) + } + + defer func(run *github.Runner) { + if err != nil && run != nil { + _, innerErr := g.RemoveEntityRunner(ctx, run.GetID()) + slog.With(slog.Any("error", innerErr)).ErrorContext( + ctx, "failed to remove runner", + "runner_id", run.GetID(), string(g.entity.EntityType), g.entity.String()) + } + }(ret.Runner) + + decoded, err := base64.StdEncoding.DecodeString(*ret.EncodedJITConfig) + if err != nil { + return nil, nil, fmt.Errorf("failed to decode JIT config: %w", err) + } + + var jitConfig map[string]string + if err := json.Unmarshal(decoded, &jitConfig); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal JIT config: %w", err) + } + + return jitConfig, ret.Runner, nil +} + +func (g *githubClient) GetEntity() params.GithubEntity { + return g.entity +} + +func (g *githubClient) GithubBaseURL() *url.URL { + return g.cli.BaseURL +} + +func GithubClient(ctx context.Context, entity params.GithubEntity) (common.GithubClient, error) { + // func GithubClient(ctx context.Context, entity params.GithubEntity) (common.GithubClient, error) { + httpClient, err := entity.Credentials.GetHTTPClient(ctx) + if err != nil { + return nil, errors.Wrap(err, "fetching http client") + } + + ghClient, err := github.NewClient(httpClient).WithEnterpriseURLs( + entity.Credentials.APIBaseURL, entity.Credentials.UploadBaseURL) + if err != nil { + return nil, errors.Wrap(err, "fetching github client") + } + + cli := &githubClient{ + ActionsService: ghClient.Actions, + org: ghClient.Organizations, + repo: ghClient.Repositories, + enterprise: ghClient.Enterprise, + cli: ghClient, + entity: entity, + } + + return cli, nil +} diff --git a/util/github/scalesets/client.go b/util/github/scalesets/client.go new file mode 100644 index 00000000..f0b2deac --- /dev/null +++ b/util/github/scalesets/client.go @@ -0,0 +1,95 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "fmt" + "io" + "net/http" + "sync" + + "github.com/google/go-github/v57/github" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/params" + "github.com/cloudbase/garm/runner/common" +) + +func NewClient(cli common.GithubClient) (*ScaleSetClient, error) { + return &ScaleSetClient{ + ghCli: cli, + httpClient: &http.Client{}, + }, nil +} + +type ScaleSetClient struct { + ghCli common.GithubClient + httpClient *http.Client + + // scale sets are aparently available through the same security + // contex that a normal runner would use. We connect to the same + // API endpoint a runner would connect to, in order to fetch jobs. + // To do this, we use a runner registration token. + runnerRegistrationToken *github.RegistrationToken + // actionsServiceInfo holds the pipeline URL and the JWT token to + // access it. The pipeline URL is the base URL where we can access + // the scale set endpoints. + actionsServiceInfo *params.ActionsServiceAdminInfoResponse + + mux sync.Mutex +} + +func (s *ScaleSetClient) SetGithubClient(cli common.GithubClient) { + s.mux.Lock() + defer s.mux.Unlock() + s.ghCli = cli +} + +func (s *ScaleSetClient) Do(req *http.Request) (*http.Response, error) { + if s.httpClient == nil { + return nil, fmt.Errorf("http client is not initialized") + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to dispatch HTTP request: %w", err) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil + } + + var body []byte + if resp != nil { + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read body: %w", err) + } + } + + switch resp.StatusCode { + case 404: + return nil, runnerErrors.NewNotFoundError("resource %s not found: %q", req.URL.String(), string(body)) + case 400: + return nil, runnerErrors.NewBadRequestError("bad request while calling %s: %q", req.URL.String(), string(body)) + case 409: + return nil, runnerErrors.NewConflictError("conflict while calling %s: %q", req.URL.String(), string(body)) + case 401, 403: + return nil, runnerErrors.ErrUnauthorized + default: + return nil, fmt.Errorf("request to %s failed with status code %d: %q", req.URL.String(), resp.StatusCode, string(body)) + } +} diff --git a/util/github/scalesets/jobs.go b/util/github/scalesets/jobs.go new file mode 100644 index 00000000..b087ad63 --- /dev/null +++ b/util/github/scalesets/jobs.go @@ -0,0 +1,88 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/cloudbase/garm/params" +) + +type acquireJobsResult struct { + Count int `json:"count"` + Value []int64 `json:"value"` +} + +func (s *ScaleSetClient) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) { + u := fmt.Sprintf("%s/%d/acquirejobs?api-version=6.0-preview", scaleSetEndpoint, runnerScaleSetId) + + body, err := json.Marshal(requestIds) + if err != nil { + return nil, err + } + + req, err := s.newActionsRequest(ctx, http.MethodPost, u, bytes.NewBuffer(body)) + if err != nil { + return nil, fmt.Errorf("failed to construct request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken)) + + resp, err := s.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + var acquiredJobs acquireJobsResult + err = json.NewDecoder(resp.Body).Decode(&acquiredJobs) + if err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return acquiredJobs.Value, nil +} + +func (s *ScaleSetClient) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (params.AcquirableJobList, error) { + path := fmt.Sprintf("%d/acquirablejobs", runnerScaleSetId) + + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.AcquirableJobList{}, fmt.Errorf("failed to construct request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return params.AcquirableJobList{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNoContent { + return params.AcquirableJobList{Count: 0, Jobs: []params.AcquirableJob{}}, nil + } + + var acquirableJobList params.AcquirableJobList + err = json.NewDecoder(resp.Body).Decode(&acquirableJobList) + if err != nil { + return params.AcquirableJobList{}, fmt.Errorf("failed to decode response: %w", err) + } + + return acquirableJobList, nil +} diff --git a/util/github/scalesets/message_sessions.go b/util/github/scalesets/message_sessions.go new file mode 100644 index 00000000..ae70239e --- /dev/null +++ b/util/github/scalesets/message_sessions.go @@ -0,0 +1,265 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "math/rand/v2" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/params" +) + +const maxCapacityHeader = "X-ScaleSetMaxCapacity" + +func NewMessageSession(ctx context.Context, cli *ScaleSetClient, session *params.RunnerScaleSetSession) (*MessageSession, error) { + sess := &MessageSession{ + ssCli: cli, + session: session, + ctx: ctx, + done: make(chan struct{}), + closed: false, + } + go sess.loop() + return sess, nil +} + +type MessageSession struct { + ssCli *ScaleSetClient + session *params.RunnerScaleSetSession + ctx context.Context + + done chan struct{} + closed bool + lastErr error + + mux sync.Mutex +} + +func (m *MessageSession) Close() error { + m.mux.Lock() + defer m.mux.Unlock() + if m.closed { + return nil + } + close(m.done) + m.closed = true + return nil +} + +func (m *MessageSession) LastError() error { + return m.lastErr +} + +func (m *MessageSession) loop() { + timer := time.NewTimer(1 * time.Minute) + defer timer.Stop() + if m.closed { + return + } + for { + select { + case <-m.ctx.Done(): + return + case <-m.done: + return + case <-timer.C: + if err := m.maybeRefreshToken(m.ctx); err != nil { + // We endlessly retry. If it's a transient error, it should eventually + // work, if it's credentials issues, users can update them. + slog.With(slog.Any("error", err)).ErrorContext(m.ctx, "failed to refresh message queue token") + m.lastErr = err + } + } + } +} + +func (m *MessageSession) SessionsRelativeURL() (string, error) { + if m.session == nil { + return "", fmt.Errorf("session is nil") + } + if m.session.RunnerScaleSet == nil { + return "", fmt.Errorf("runner scale set is nil") + } + relativePath := fmt.Sprintf("%s/%d/sessions/%s", scaleSetEndpoint, m.session.RunnerScaleSet.Id, m.session.SessionId.String()) + return relativePath, nil +} + +func (m *MessageSession) Refresh(ctx context.Context) error { + m.mux.Lock() + defer m.mux.Unlock() + + relPath, err := m.SessionsRelativeURL() + if err != nil { + return fmt.Errorf("failed to get session URL: %w", err) + } + req, err := m.ssCli.newActionsRequest(ctx, http.MethodPatch, relPath, nil) + if err != nil { + return fmt.Errorf("failed to create message delete request: %w", err) + } + resp, err := m.ssCli.Do(req) + if err != nil { + return fmt.Errorf("failed to delete message session: %w", err) + } + + var refreshedSession params.RunnerScaleSetSession + if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + + m.session = &refreshedSession + return nil +} + +func (m *MessageSession) maybeRefreshToken(ctx context.Context) error { + if m.session == nil { + return fmt.Errorf("session is nil") + } + // add some jitter + jitter := time.Duration(rand.IntN(10000)) * time.Millisecond + if m.session.ExpiresIn(2*time.Minute + jitter) { + if err := m.Refresh(ctx); err != nil { + return fmt.Errorf("failed to refresh message queue token: %w", err) + } + } + return nil +} + +func (m *MessageSession) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity uint) (params.RunnerScaleSetMessage, error) { + u, err := url.Parse(m.session.MessageQueueUrl) + if err != nil { + return params.RunnerScaleSetMessage{}, err + } + + if lastMessageId > 0 { + q := u.Query() + q.Set("lastMessageId", strconv.FormatInt(lastMessageId, 10)) + u.RawQuery = q.Encode() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json; api-version=6.0-preview") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken)) + req.Header.Set(maxCapacityHeader, fmt.Sprintf("%d", maxCapacity)) + + resp, err := m.ssCli.Do(req) + if err != nil { + return params.RunnerScaleSetMessage{}, fmt.Errorf("request to %s failed: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusAccepted { + return params.RunnerScaleSetMessage{}, nil + } + + var message params.RunnerScaleSetMessage + if err := json.NewDecoder(resp.Body).Decode(&message); err != nil { + return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to decode response: %w", err) + } + return message, nil +} + +func (m *MessageSession) DeleteMessage(ctx context.Context, messageId int64) error { + u, err := url.Parse(m.session.MessageQueueUrl) + if err != nil { + return err + } + + u.Path = fmt.Sprintf("%s/%d", u.Path, messageId) + + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u.String(), nil) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken)) + + resp, err := m.ssCli.Do(req) + if err != nil { + return err + } + + resp.Body.Close() + return nil +} + +func (s *ScaleSetClient) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*MessageSession, error) { + path := fmt.Sprintf("%s/%d/sessions", scaleSetEndpoint, runnerScaleSetId) + + newSession := params.RunnerScaleSetSession{ + OwnerName: owner, + } + + requestData, err := json.Marshal(newSession) + if err != nil { + return nil, fmt.Errorf("failed to marshal session data: %w", err) + } + + req, err := s.newActionsRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request to %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + var createdSession params.RunnerScaleSetSession + if err := json.NewDecoder(resp.Body).Decode(&createdSession); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &MessageSession{ + ssCli: s, + session: &createdSession, + }, nil +} + +func (s *ScaleSetClient) DeleteMessageSession(ctx context.Context, session *MessageSession) error { + path, err := session.SessionsRelativeURL() + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + + req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return fmt.Errorf("failed to create message delete request: %w", err) + } + + _, err = s.Do(req) + if err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return fmt.Errorf("failed to delete message session: %w", err) + } + } + return nil +} diff --git a/util/github/scalesets/runners.go b/util/github/scalesets/runners.go new file mode 100644 index 00000000..2d1519dc --- /dev/null +++ b/util/github/scalesets/runners.go @@ -0,0 +1,129 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/params" +) + +type scaleSetJitRunnerConfig struct { + Name string `json:"name"` + WorkFolder string `json:"workFolder"` +} + +func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSet params.RunnerScaleSet) (params.RunnerScaleSetJitRunnerConfig, error) { + runnerSettings := scaleSetJitRunnerConfig{ + Name: runnerName, + WorkFolder: "_work", + } + + body, err := json.Marshal(runnerSettings) + if err != nil { + return params.RunnerScaleSetJitRunnerConfig{}, err + } + + req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSet.RunnerJitConfigUrl, bytes.NewBuffer(body)) + if err != nil { + return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + var runnerJitConfig params.RunnerScaleSetJitRunnerConfig + if err := json.NewDecoder(resp.Body).Decode(&runnerJitConfig); err != nil { + return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to decode response: %w", err) + } + return runnerJitConfig, nil +} + +func (s *ScaleSetClient) GetRunner(ctx context.Context, runnerId int64) (params.RunnerReference, error) { + path := fmt.Sprintf("%s/%d", runnerEndpoint, runnerId) + + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.RunnerReference{}, fmt.Errorf("failed to construct request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerReference{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + var runnerReference params.RunnerReference + if err := json.NewDecoder(resp.Body).Decode(&runnerReference); err != nil { + return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err) + } + + return runnerReference, nil +} + +func (s *ScaleSetClient) GetRunnerByName(ctx context.Context, runnerName string) (params.RunnerReference, error) { + path := fmt.Sprintf("%s?agentName=%s", runnerEndpoint, runnerName) + + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.RunnerReference{}, fmt.Errorf("failed to construct request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerReference{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + defer resp.Body.Close() + + var runnerList params.RunnerReferenceList + if err := json.NewDecoder(resp.Body).Decode(&runnerList); err != nil { + return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err) + } + + if runnerList.Count == 0 { + return params.RunnerReference{}, fmt.Errorf("could not find runner with name %q: %w", runnerName, runnerErrors.ErrNotFound) + } + + if runnerList.Count > 1 { + return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err) + } + + return runnerList.RunnerReferences[0], nil +} + +func (s *ScaleSetClient) RemoveRunner(ctx context.Context, runnerId int64) error { + path := fmt.Sprintf("%s/%d", runnerEndpoint, runnerId) + + req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return fmt.Errorf("failed to construct request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return fmt.Errorf("request failed for %s: %w", req.URL.String(), err) + } + + resp.Body.Close() + return nil +} diff --git a/util/github/scalesets/scalesets.go b/util/github/scalesets/scalesets.go new file mode 100644 index 00000000..7c70daec --- /dev/null +++ b/util/github/scalesets/scalesets.go @@ -0,0 +1,204 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httputil" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/params" +) + +const ( + runnerEndpoint = "_apis/distributedtask/pools/0/agents" + scaleSetEndpoint = "_apis/runtime/runnerscalesets" +) + +const ( + HeaderActionsActivityID = "ActivityId" + HeaderGitHubRequestID = "X-GitHub-Request-Id" +) + +func (s *ScaleSetClient) GetRunnerScaleSetByNameAndRunnerGroup(ctx context.Context, runnerGroupId int, name string) (params.RunnerScaleSet, error) { + path := fmt.Sprintf("%s?runnerGroupId=%d&name=%s", scaleSetEndpoint, runnerGroupId, name) + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.RunnerScaleSet{}, err + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerScaleSet{}, err + } + + var runnerScaleSetList *params.RunnerScaleSetsResponse + if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err) + } + if runnerScaleSetList.Count == 0 { + return params.RunnerScaleSet{}, runnerErrors.NewNotFoundError("runner scale set with name %s and runner group ID %d was not found", name, runnerGroupId) + } + + // Runner scale sets must have a uniqe name. Attempting to create a runner scale set with the same name as + // an existing scale set will result in a Bad Request (400) error. + return runnerScaleSetList.RunnerScaleSets[0], nil +} + +func (s *ScaleSetClient) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (params.RunnerScaleSet, error) { + path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId) + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.RunnerScaleSet{}, err + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to get runner scaleset with ID %d: %w", runnerScaleSetId, err) + } + + var runnerScaleSet params.RunnerScaleSet + if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSet); err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err) + } + return runnerScaleSet, nil +} + +// ListRunnerScaleSets lists all runner scale sets in a github entity. +func (s *ScaleSetClient) ListRunnerScaleSets(ctx context.Context) (*params.RunnerScaleSetsResponse, error) { + req, err := s.newActionsRequest(ctx, http.MethodGet, scaleSetEndpoint, nil) + if err != nil { + return nil, err + } + data, err := httputil.DumpRequest(req, false) + if err == nil { + fmt.Println(string(data)) + } + resp, err := s.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to list runner scale sets: %w", err) + } + + var runnerScaleSetList params.RunnerScaleSetsResponse + if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &runnerScaleSetList, nil +} + +// CreateRunnerScaleSet creates a new runner scale set in the target GitHub entity. +func (s *ScaleSetClient) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *params.RunnerScaleSet) (params.RunnerScaleSet, error) { + body, err := json.Marshal(runnerScaleSet) + if err != nil { + return params.RunnerScaleSet{}, err + } + + req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSetEndpoint, bytes.NewReader(body)) + if err != nil { + return params.RunnerScaleSet{}, err + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to create runner scale set: %w", err) + } + + var createdRunnerScaleSet params.RunnerScaleSet + if err := json.NewDecoder(resp.Body).Decode(&createdRunnerScaleSet); err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err) + } + return createdRunnerScaleSet, nil +} + +func (s *ScaleSetClient) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet params.RunnerScaleSet) (params.RunnerScaleSet, error) { + path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId) + + body, err := json.Marshal(runnerScaleSet) + if err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := s.newActionsRequest(ctx, http.MethodPatch, path, bytes.NewReader(body)) + if err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to make request: %w", err) + } + + var ret params.RunnerScaleSet + if err := json.NewDecoder(resp.Body).Decode(&ret); err != nil { + return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err) + } + return ret, nil +} + +func (s *ScaleSetClient) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error { + path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId) + req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return err + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to delete scale set with code %d", resp.StatusCode) + } + + resp.Body.Close() + return nil +} + +func (s *ScaleSetClient) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (params.RunnerGroup, error) { + path := fmt.Sprintf("_apis/runtime/runnergroups/?groupName=%s", runnerGroup) + req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return params.RunnerGroup{}, err + } + + resp, err := s.Do(req) + if err != nil { + return params.RunnerGroup{}, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + var runnerGroupList params.RunnerGroupList + err = json.NewDecoder(resp.Body).Decode(&runnerGroupList) + if err != nil { + return params.RunnerGroup{}, fmt.Errorf("failed to decode response: %w", err) + } + + if runnerGroupList.Count == 0 { + return params.RunnerGroup{}, runnerErrors.NewNotFoundError("runner group %s does not exist", runnerGroup) + } + + if runnerGroupList.Count > 1 { + return params.RunnerGroup{}, runnerErrors.NewConflictError("multiple runner groups exist with the same name (%s)", runnerGroup) + } + + return runnerGroupList.RunnerGroups[0], nil +} diff --git a/util/github/scalesets/token.go b/util/github/scalesets/token.go new file mode 100644 index 00000000..47aa764f --- /dev/null +++ b/util/github/scalesets/token.go @@ -0,0 +1,105 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/cloudbase/garm/params" +) + +func (s *ScaleSetClient) getActionServiceInfo(ctx context.Context) (params.ActionsServiceAdminInfoResponse, error) { + regPath := "/actions/runner-registration" + baseURL := s.ghCli.GithubBaseURL() + url, err := baseURL.Parse(regPath) + if err != nil { + return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to parse url: %w", err) + } + + entity := s.ghCli.GetEntity() + body := params.ActionsServiceAdminInfoRequest{ + URL: entity.GithubURL(), + RunnerEvent: "register", + } + + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + + if err := enc.Encode(body); err != nil { + return params.ActionsServiceAdminInfoResponse{}, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), buf) + if err != nil { + return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("RemoteAuth %s", *s.runnerRegistrationToken.Token)) + + resp, err := s.Do(req) + if err != nil { + return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to get actions service admin info: %w", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to read response body: %w", err) + } + data = bytes.TrimPrefix(data, []byte("\xef\xbb\xbf")) + + var info params.ActionsServiceAdminInfoResponse + if err := json.Unmarshal(data, &info); err != nil { + return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to decode response: %w", err) + } + + return info, nil +} + +func (s *ScaleSetClient) ensureAdminInfo(ctx context.Context) error { + s.mux.Lock() + defer s.mux.Unlock() + + var expiresAt time.Time + if s.runnerRegistrationToken != nil { + expiresAt = s.runnerRegistrationToken.GetExpiresAt().Time + } + + now := time.Now().UTC().Add(2 * time.Minute) + if now.After(expiresAt) || s.runnerRegistrationToken == nil { + token, _, err := s.ghCli.CreateEntityRegistrationToken(ctx) + if err != nil { + return fmt.Errorf("failed to fetch runner registration token: %w", err) + } + s.runnerRegistrationToken = token + } + + if s.actionsServiceInfo == nil || s.actionsServiceInfo.ExpiresIn(2*time.Minute) { + info, err := s.getActionServiceInfo(ctx) + if err != nil { + return fmt.Errorf("failed to get action service info: %w", err) + } + s.actionsServiceInfo = &info + } + + return nil +} diff --git a/util/github/scalesets/util.go b/util/github/scalesets/util.go new file mode 100644 index 00000000..4f79098b --- /dev/null +++ b/util/github/scalesets/util.go @@ -0,0 +1,54 @@ +// Copyright 2024 Cloudbase Solutions SRL +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package scalesets + +import ( + "context" + "fmt" + "io" + "net/http" +) + +func (s *ScaleSetClient) newActionsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) { + if err := s.ensureAdminInfo(ctx); err != nil { + return nil, fmt.Errorf("failed to update token: %w", err) + } + + actionsUri, err := s.actionsServiceInfo.GetURL() + if err != nil { + return nil, fmt.Errorf("failed to get pipeline URL: %w", err) + } + + uri, err := actionsUri.Parse(path) + if err != nil { + return nil, fmt.Errorf("failed to parse path: %w", err) + } + + q := uri.Query() + if q.Get("api-version") == "" { + q.Set("api-version", "6.0-preview") + } + uri.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, method, uri.String(), body) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.actionsServiceInfo.Token)) + + return req, nil +} diff --git a/util/logging.go b/util/logging.go index ac35863b..bb7b0562 100644 --- a/util/logging.go +++ b/util/logging.go @@ -25,6 +25,6 @@ func (h ContextHandler) Handle(ctx context.Context, r slog.Record) error { return h.Handler.Handle(ctx, r) } -func WithContext(ctx context.Context, attrs ...slog.Attr) context.Context { +func WithSlogContext(ctx context.Context, attrs ...slog.Attr) context.Context { return context.WithValue(ctx, slogCtxFields, attrs) } diff --git a/util/util.go b/util/util.go index eb390743..da1264d2 100644 --- a/util/util.go +++ b/util/util.go @@ -16,442 +16,30 @@ package util import ( "context" - "encoding/base64" - "encoding/json" - "fmt" - "log/slog" "net/http" - "github.com/google/go-github/v57/github" "github.com/pkg/errors" runnerErrors "github.com/cloudbase/garm-provider-common/errors" - "github.com/cloudbase/garm/metrics" - "github.com/cloudbase/garm/params" + commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm/runner/common" ) -type githubClient struct { - *github.ActionsService - org *github.OrganizationsService - repo *github.RepositoriesService - enterprise *github.EnterpriseService - - entity params.GithubEntity -} - -func (g *githubClient) ListEntityHooks(ctx context.Context, opts *github.ListOptions) (ret []*github.Hook, response *github.Response, err error) { - metrics.GithubOperationCount.WithLabelValues( - "ListHooks", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "ListHooks", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, response, err = g.repo.ListHooks(ctx, g.entity.Owner, g.entity.Name, opts) - case params.GithubEntityTypeOrganization: - ret, response, err = g.org.ListHooks(ctx, g.entity.Owner, opts) - default: - return nil, nil, fmt.Errorf("invalid entity type: %s", g.entity.EntityType) - } - return ret, response, err -} - -func (g *githubClient) GetEntityHook(ctx context.Context, id int64) (ret *github.Hook, err error) { - metrics.GithubOperationCount.WithLabelValues( - "GetHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "GetHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, _, err = g.repo.GetHook(ctx, g.entity.Owner, g.entity.Name, id) - case params.GithubEntityTypeOrganization: - ret, _, err = g.org.GetHook(ctx, g.entity.Owner, id) - default: - return nil, errors.New("invalid entity type") - } - return ret, err -} - -func (g *githubClient) CreateEntityHook(ctx context.Context, hook *github.Hook) (ret *github.Hook, err error) { - metrics.GithubOperationCount.WithLabelValues( - "CreateHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "CreateHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, _, err = g.repo.CreateHook(ctx, g.entity.Owner, g.entity.Name, hook) - case params.GithubEntityTypeOrganization: - ret, _, err = g.org.CreateHook(ctx, g.entity.Owner, hook) - default: - return nil, errors.New("invalid entity type") - } - return ret, err -} - -func (g *githubClient) DeleteEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) { - metrics.GithubOperationCount.WithLabelValues( - "DeleteHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "DeleteHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, err = g.repo.DeleteHook(ctx, g.entity.Owner, g.entity.Name, id) - case params.GithubEntityTypeOrganization: - ret, err = g.org.DeleteHook(ctx, g.entity.Owner, id) - default: - return nil, errors.New("invalid entity type") - } - return ret, err -} - -func (g *githubClient) PingEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) { - metrics.GithubOperationCount.WithLabelValues( - "PingHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "PingHook", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, err = g.repo.PingHook(ctx, g.entity.Owner, g.entity.Name, id) - case params.GithubEntityTypeOrganization: - ret, err = g.org.PingHook(ctx, g.entity.Owner, id) - default: - return nil, errors.New("invalid entity type") - } - return ret, err -} - -func (g *githubClient) ListEntityRunners(ctx context.Context, opts *github.ListOptions) (*github.Runners, *github.Response, error) { - var ret *github.Runners - var response *github.Response - var err error - - metrics.GithubOperationCount.WithLabelValues( - "ListEntityRunners", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "ListEntityRunners", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, response, err = g.ListRunners(ctx, g.entity.Owner, g.entity.Name, opts) - case params.GithubEntityTypeOrganization: - ret, response, err = g.ListOrganizationRunners(ctx, g.entity.Owner, opts) - case params.GithubEntityTypeEnterprise: - ret, response, err = g.enterprise.ListRunners(ctx, g.entity.Owner, opts) - default: - return nil, nil, errors.New("invalid entity type") - } - - return ret, response, err -} - -func (g *githubClient) ListEntityRunnerApplicationDownloads(ctx context.Context) ([]*github.RunnerApplicationDownload, *github.Response, error) { - var ret []*github.RunnerApplicationDownload - var response *github.Response - var err error - - metrics.GithubOperationCount.WithLabelValues( - "ListEntityRunnerApplicationDownloads", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "ListEntityRunnerApplicationDownloads", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, response, err = g.ListRunnerApplicationDownloads(ctx, g.entity.Owner, g.entity.Name) - case params.GithubEntityTypeOrganization: - ret, response, err = g.ListOrganizationRunnerApplicationDownloads(ctx, g.entity.Owner) - case params.GithubEntityTypeEnterprise: - ret, response, err = g.enterprise.ListRunnerApplicationDownloads(ctx, g.entity.Owner) - default: - return nil, nil, errors.New("invalid entity type") - } - - return ret, response, err -} - -func (g *githubClient) RemoveEntityRunner(ctx context.Context, runnerID int64) (*github.Response, error) { - var response *github.Response - var err error - - metrics.GithubOperationCount.WithLabelValues( - "RemoveEntityRunner", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "RemoveEntityRunner", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - response, err = g.RemoveRunner(ctx, g.entity.Owner, g.entity.Name, runnerID) - case params.GithubEntityTypeOrganization: - response, err = g.RemoveOrganizationRunner(ctx, g.entity.Owner, runnerID) - case params.GithubEntityTypeEnterprise: - response, err = g.enterprise.RemoveRunner(ctx, g.entity.Owner, runnerID) - default: - return nil, errors.New("invalid entity type") - } - - return response, err -} - -func (g *githubClient) CreateEntityRegistrationToken(ctx context.Context) (*github.RegistrationToken, *github.Response, error) { - var ret *github.RegistrationToken - var response *github.Response - var err error - - metrics.GithubOperationCount.WithLabelValues( - "CreateEntityRegistrationToken", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - defer func() { - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "CreateEntityRegistrationToken", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - } - }() - - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, response, err = g.CreateRegistrationToken(ctx, g.entity.Owner, g.entity.Name) - case params.GithubEntityTypeOrganization: - ret, response, err = g.CreateOrganizationRegistrationToken(ctx, g.entity.Owner) - case params.GithubEntityTypeEnterprise: - ret, response, err = g.enterprise.CreateRegistrationToken(ctx, g.entity.Owner) - default: - return nil, nil, errors.New("invalid entity type") - } - - return ret, response, err -} - -func (g *githubClient) getOrganizationRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) { - opts := github.ListOrgRunnerGroupOptions{ - ListOptions: github.ListOptions{ - PerPage: 100, - }, - } - - for { - metrics.GithubOperationCount.WithLabelValues( - "ListOrganizationRunnerGroups", // label: operation - entity.LabelScope(), // label: scope - ).Inc() - runnerGroups, ghResp, err := g.ListOrganizationRunnerGroups(ctx, entity.Owner, &opts) - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "ListOrganizationRunnerGroups", // label: operation - entity.LabelScope(), // label: scope - ).Inc() - if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { - return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners") - } - return 0, errors.Wrap(err, "fetching runners") - } - for _, runnerGroup := range runnerGroups.RunnerGroups { - if runnerGroup.Name != nil && *runnerGroup.Name == rgName { - return *runnerGroup.ID, nil - } - } - if ghResp.NextPage == 0 { - break - } - opts.Page = ghResp.NextPage - } - return 0, runnerErrors.NewNotFoundError("runner group %s not found", rgName) -} - -func (g *githubClient) getEnterpriseRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) { - opts := github.ListEnterpriseRunnerGroupOptions{ - ListOptions: github.ListOptions{ - PerPage: 100, - }, - } - - for { - metrics.GithubOperationCount.WithLabelValues( - "ListRunnerGroups", // label: operation - entity.LabelScope(), // label: scope - ).Inc() - runnerGroups, ghResp, err := g.enterprise.ListRunnerGroups(ctx, entity.Owner, &opts) - if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "ListRunnerGroups", // label: operation - entity.LabelScope(), // label: scope - ).Inc() - if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { - return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners") - } - return 0, errors.Wrap(err, "fetching runners") - } - for _, runnerGroup := range runnerGroups.RunnerGroups { - if runnerGroup.Name != nil && *runnerGroup.Name == rgName { - return *runnerGroup.ID, nil - } - } - if ghResp.NextPage == 0 { - break - } - opts.Page = ghResp.NextPage - } - return 0, runnerErrors.NewNotFoundError("runner group not found") -} - -func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) { - // If no runner group is set, use the default runner group ID. This is also the default for - // repository level runners. - var rgID int64 = 1 - - if pool.GitHubRunnerGroup != "" { - switch g.entity.EntityType { - case params.GithubEntityTypeOrganization: - rgID, err = g.getOrganizationRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) - case params.GithubEntityTypeEnterprise: - rgID, err = g.getEnterpriseRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup) - } - - if err != nil { - return nil, nil, fmt.Errorf("getting runner group ID: %w", err) - } - } - - req := github.GenerateJITConfigRequest{ - Name: instance, - RunnerGroupID: rgID, - Labels: labels, - // nolint:golangci-lint,godox - // TODO(gabriel-samfira): Should we make this configurable? - WorkFolder: github.String("_work"), - } - - metrics.GithubOperationCount.WithLabelValues( - "GetEntityJITConfig", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - - var ret *github.JITRunnerConfig - var response *github.Response - - switch g.entity.EntityType { - case params.GithubEntityTypeRepository: - ret, response, err = g.GenerateRepoJITConfig(ctx, g.entity.Owner, g.entity.Name, &req) - case params.GithubEntityTypeOrganization: - ret, response, err = g.GenerateOrgJITConfig(ctx, g.entity.Owner, &req) - case params.GithubEntityTypeEnterprise: - ret, response, err = g.enterprise.GenerateEnterpriseJITConfig(ctx, g.entity.Owner, &req) - } +func FetchTools(ctx context.Context, cli common.GithubClient) ([]commonParams.RunnerApplicationDownload, error) { + tools, ghResp, err := cli.ListEntityRunnerApplicationDownloads(ctx) if err != nil { - metrics.GithubOperationFailedCount.WithLabelValues( - "GetEntityJITConfig", // label: operation - g.entity.LabelScope(), // label: scope - ).Inc() - if response != nil && response.StatusCode == http.StatusUnauthorized { - return nil, nil, fmt.Errorf("failed to get JIT config: %w", err) + if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { + return nil, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching tools") } - return nil, nil, fmt.Errorf("failed to get JIT config: %w", err) + return nil, errors.Wrap(err, "fetching runner tools") } - defer func(run *github.Runner) { - if err != nil && run != nil { - _, innerErr := g.RemoveEntityRunner(ctx, run.GetID()) - slog.With(slog.Any("error", innerErr)).ErrorContext( - ctx, "failed to remove runner", - "runner_id", run.GetID(), string(g.entity.EntityType), g.entity.String()) + ret := []commonParams.RunnerApplicationDownload{} + for _, tool := range tools { + if tool == nil { + continue } - }(ret.Runner) - - decoded, err := base64.StdEncoding.DecodeString(*ret.EncodedJITConfig) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode JIT config: %w", err) + ret = append(ret, commonParams.RunnerApplicationDownload(*tool)) } - - var jitConfig map[string]string - if err := json.Unmarshal(decoded, &jitConfig); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal JIT config: %w", err) - } - - return jitConfig, ret.Runner, nil -} - -func GithubClient(ctx context.Context, entity params.GithubEntity, credsDetails params.GithubCredentials) (common.GithubClient, error) { - httpClient, err := credsDetails.GetHTTPClient(ctx) - if err != nil { - return nil, errors.Wrap(err, "fetching http client") - } - - ghClient, err := github.NewClient(httpClient).WithEnterpriseURLs(credsDetails.APIBaseURL, credsDetails.UploadBaseURL) - if err != nil { - return nil, errors.Wrap(err, "fetching github client") - } - - cli := &githubClient{ - ActionsService: ghClient.Actions, - org: ghClient.Organizations, - repo: ghClient.Repositories, - enterprise: ghClient.Enterprise, - entity: entity, - } - return cli, nil + return ret, nil }