Add scaleset client

This change moves the github client to a subpackage in utils
and adds the scaleset github client code.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2025-04-04 20:44:57 +00:00
parent cc1470fe08
commit 79b9a1583c
13 changed files with 1432 additions and 432 deletions

471
util/github/client.go Normal file
View file

@ -0,0 +1,471 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package github
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/url"
"github.com/google/go-github/v57/github"
"github.com/pkg/errors"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/metrics"
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/runner/common"
)
type githubClient struct {
*github.ActionsService
org *github.OrganizationsService
repo *github.RepositoriesService
enterprise *github.EnterpriseService
entity params.GithubEntity
cli *github.Client
}
func (g *githubClient) ListEntityHooks(ctx context.Context, opts *github.ListOptions) (ret []*github.Hook, response *github.Response, err error) {
metrics.GithubOperationCount.WithLabelValues(
"ListHooks", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"ListHooks", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, response, err = g.repo.ListHooks(ctx, g.entity.Owner, g.entity.Name, opts)
case params.GithubEntityTypeOrganization:
ret, response, err = g.org.ListHooks(ctx, g.entity.Owner, opts)
default:
return nil, nil, fmt.Errorf("invalid entity type: %s", g.entity.EntityType)
}
return ret, response, err
}
func (g *githubClient) GetEntityHook(ctx context.Context, id int64) (ret *github.Hook, err error) {
metrics.GithubOperationCount.WithLabelValues(
"GetHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"GetHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, _, err = g.repo.GetHook(ctx, g.entity.Owner, g.entity.Name, id)
case params.GithubEntityTypeOrganization:
ret, _, err = g.org.GetHook(ctx, g.entity.Owner, id)
default:
return nil, errors.New("invalid entity type")
}
return ret, err
}
func (g *githubClient) CreateEntityHook(ctx context.Context, hook *github.Hook) (ret *github.Hook, err error) {
metrics.GithubOperationCount.WithLabelValues(
"CreateHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"CreateHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, _, err = g.repo.CreateHook(ctx, g.entity.Owner, g.entity.Name, hook)
case params.GithubEntityTypeOrganization:
ret, _, err = g.org.CreateHook(ctx, g.entity.Owner, hook)
default:
return nil, errors.New("invalid entity type")
}
return ret, err
}
func (g *githubClient) DeleteEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) {
metrics.GithubOperationCount.WithLabelValues(
"DeleteHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"DeleteHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, err = g.repo.DeleteHook(ctx, g.entity.Owner, g.entity.Name, id)
case params.GithubEntityTypeOrganization:
ret, err = g.org.DeleteHook(ctx, g.entity.Owner, id)
default:
return nil, errors.New("invalid entity type")
}
return ret, err
}
func (g *githubClient) PingEntityHook(ctx context.Context, id int64) (ret *github.Response, err error) {
metrics.GithubOperationCount.WithLabelValues(
"PingHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"PingHook", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, err = g.repo.PingHook(ctx, g.entity.Owner, g.entity.Name, id)
case params.GithubEntityTypeOrganization:
ret, err = g.org.PingHook(ctx, g.entity.Owner, id)
default:
return nil, errors.New("invalid entity type")
}
return ret, err
}
func (g *githubClient) ListEntityRunners(ctx context.Context, opts *github.ListOptions) (*github.Runners, *github.Response, error) {
var ret *github.Runners
var response *github.Response
var err error
metrics.GithubOperationCount.WithLabelValues(
"ListEntityRunners", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"ListEntityRunners", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, response, err = g.ListRunners(ctx, g.entity.Owner, g.entity.Name, opts)
case params.GithubEntityTypeOrganization:
ret, response, err = g.ListOrganizationRunners(ctx, g.entity.Owner, opts)
case params.GithubEntityTypeEnterprise:
ret, response, err = g.enterprise.ListRunners(ctx, g.entity.Owner, opts)
default:
return nil, nil, errors.New("invalid entity type")
}
return ret, response, err
}
func (g *githubClient) ListEntityRunnerApplicationDownloads(ctx context.Context) ([]*github.RunnerApplicationDownload, *github.Response, error) {
var ret []*github.RunnerApplicationDownload
var response *github.Response
var err error
metrics.GithubOperationCount.WithLabelValues(
"ListEntityRunnerApplicationDownloads", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"ListEntityRunnerApplicationDownloads", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, response, err = g.ListRunnerApplicationDownloads(ctx, g.entity.Owner, g.entity.Name)
case params.GithubEntityTypeOrganization:
ret, response, err = g.ListOrganizationRunnerApplicationDownloads(ctx, g.entity.Owner)
case params.GithubEntityTypeEnterprise:
ret, response, err = g.enterprise.ListRunnerApplicationDownloads(ctx, g.entity.Owner)
default:
return nil, nil, errors.New("invalid entity type")
}
return ret, response, err
}
func (g *githubClient) RemoveEntityRunner(ctx context.Context, runnerID int64) (*github.Response, error) {
var response *github.Response
var err error
metrics.GithubOperationCount.WithLabelValues(
"RemoveEntityRunner", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"RemoveEntityRunner", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
response, err = g.RemoveRunner(ctx, g.entity.Owner, g.entity.Name, runnerID)
case params.GithubEntityTypeOrganization:
response, err = g.RemoveOrganizationRunner(ctx, g.entity.Owner, runnerID)
case params.GithubEntityTypeEnterprise:
response, err = g.enterprise.RemoveRunner(ctx, g.entity.Owner, runnerID)
default:
return nil, errors.New("invalid entity type")
}
return response, err
}
func (g *githubClient) CreateEntityRegistrationToken(ctx context.Context) (*github.RegistrationToken, *github.Response, error) {
var ret *github.RegistrationToken
var response *github.Response
var err error
metrics.GithubOperationCount.WithLabelValues(
"CreateEntityRegistrationToken", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
defer func() {
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"CreateEntityRegistrationToken", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
}
}()
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, response, err = g.CreateRegistrationToken(ctx, g.entity.Owner, g.entity.Name)
case params.GithubEntityTypeOrganization:
ret, response, err = g.CreateOrganizationRegistrationToken(ctx, g.entity.Owner)
case params.GithubEntityTypeEnterprise:
ret, response, err = g.enterprise.CreateRegistrationToken(ctx, g.entity.Owner)
default:
return nil, nil, errors.New("invalid entity type")
}
return ret, response, err
}
func (g *githubClient) getOrganizationRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) {
opts := github.ListOrgRunnerGroupOptions{
ListOptions: github.ListOptions{
PerPage: 100,
},
}
for {
metrics.GithubOperationCount.WithLabelValues(
"ListOrganizationRunnerGroups", // label: operation
entity.LabelScope(), // label: scope
).Inc()
runnerGroups, ghResp, err := g.ListOrganizationRunnerGroups(ctx, entity.Owner, &opts)
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"ListOrganizationRunnerGroups", // label: operation
entity.LabelScope(), // label: scope
).Inc()
if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized {
return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners")
}
return 0, errors.Wrap(err, "fetching runners")
}
for _, runnerGroup := range runnerGroups.RunnerGroups {
if runnerGroup.Name != nil && *runnerGroup.Name == rgName {
return *runnerGroup.ID, nil
}
}
if ghResp.NextPage == 0 {
break
}
opts.Page = ghResp.NextPage
}
return 0, runnerErrors.NewNotFoundError("runner group %s not found", rgName)
}
func (g *githubClient) getEnterpriseRunnerGroupIDByName(ctx context.Context, entity params.GithubEntity, rgName string) (int64, error) {
opts := github.ListEnterpriseRunnerGroupOptions{
ListOptions: github.ListOptions{
PerPage: 100,
},
}
for {
metrics.GithubOperationCount.WithLabelValues(
"ListRunnerGroups", // label: operation
entity.LabelScope(), // label: scope
).Inc()
runnerGroups, ghResp, err := g.enterprise.ListRunnerGroups(ctx, entity.Owner, &opts)
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"ListRunnerGroups", // label: operation
entity.LabelScope(), // label: scope
).Inc()
if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized {
return 0, errors.Wrap(runnerErrors.ErrUnauthorized, "fetching runners")
}
return 0, errors.Wrap(err, "fetching runners")
}
for _, runnerGroup := range runnerGroups.RunnerGroups {
if runnerGroup.Name != nil && *runnerGroup.Name == rgName {
return *runnerGroup.ID, nil
}
}
if ghResp.NextPage == 0 {
break
}
opts.Page = ghResp.NextPage
}
return 0, runnerErrors.NewNotFoundError("runner group not found")
}
func (g *githubClient) GetEntityJITConfig(ctx context.Context, instance string, pool params.Pool, labels []string) (jitConfigMap map[string]string, runner *github.Runner, err error) {
// If no runner group is set, use the default runner group ID. This is also the default for
// repository level runners.
var rgID int64 = 1
if pool.GitHubRunnerGroup != "" {
switch g.entity.EntityType {
case params.GithubEntityTypeOrganization:
rgID, err = g.getOrganizationRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup)
case params.GithubEntityTypeEnterprise:
rgID, err = g.getEnterpriseRunnerGroupIDByName(ctx, g.entity, pool.GitHubRunnerGroup)
}
if err != nil {
return nil, nil, fmt.Errorf("getting runner group ID: %w", err)
}
}
req := github.GenerateJITConfigRequest{
Name: instance,
RunnerGroupID: rgID,
Labels: labels,
// nolint:golangci-lint,godox
// TODO(gabriel-samfira): Should we make this configurable?
WorkFolder: github.String("_work"),
}
metrics.GithubOperationCount.WithLabelValues(
"GetEntityJITConfig", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
var ret *github.JITRunnerConfig
var response *github.Response
switch g.entity.EntityType {
case params.GithubEntityTypeRepository:
ret, response, err = g.GenerateRepoJITConfig(ctx, g.entity.Owner, g.entity.Name, &req)
case params.GithubEntityTypeOrganization:
ret, response, err = g.GenerateOrgJITConfig(ctx, g.entity.Owner, &req)
case params.GithubEntityTypeEnterprise:
ret, response, err = g.enterprise.GenerateEnterpriseJITConfig(ctx, g.entity.Owner, &req)
}
if err != nil {
metrics.GithubOperationFailedCount.WithLabelValues(
"GetEntityJITConfig", // label: operation
g.entity.LabelScope(), // label: scope
).Inc()
if response != nil && response.StatusCode == http.StatusUnauthorized {
return nil, nil, fmt.Errorf("failed to get JIT config: %w", err)
}
return nil, nil, fmt.Errorf("failed to get JIT config: %w", err)
}
defer func(run *github.Runner) {
if err != nil && run != nil {
_, innerErr := g.RemoveEntityRunner(ctx, run.GetID())
slog.With(slog.Any("error", innerErr)).ErrorContext(
ctx, "failed to remove runner",
"runner_id", run.GetID(), string(g.entity.EntityType), g.entity.String())
}
}(ret.Runner)
decoded, err := base64.StdEncoding.DecodeString(*ret.EncodedJITConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to decode JIT config: %w", err)
}
var jitConfig map[string]string
if err := json.Unmarshal(decoded, &jitConfig); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal JIT config: %w", err)
}
return jitConfig, ret.Runner, nil
}
func (g *githubClient) GetEntity() params.GithubEntity {
return g.entity
}
func (g *githubClient) GithubBaseURL() *url.URL {
return g.cli.BaseURL
}
func GithubClient(ctx context.Context, entity params.GithubEntity) (common.GithubClient, error) {
// func GithubClient(ctx context.Context, entity params.GithubEntity) (common.GithubClient, error) {
httpClient, err := entity.Credentials.GetHTTPClient(ctx)
if err != nil {
return nil, errors.Wrap(err, "fetching http client")
}
ghClient, err := github.NewClient(httpClient).WithEnterpriseURLs(
entity.Credentials.APIBaseURL, entity.Credentials.UploadBaseURL)
if err != nil {
return nil, errors.Wrap(err, "fetching github client")
}
cli := &githubClient{
ActionsService: ghClient.Actions,
org: ghClient.Organizations,
repo: ghClient.Repositories,
enterprise: ghClient.Enterprise,
cli: ghClient,
entity: entity,
}
return cli, nil
}

View file

@ -0,0 +1,95 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"fmt"
"io"
"net/http"
"sync"
"github.com/google/go-github/v57/github"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/runner/common"
)
func NewClient(cli common.GithubClient) (*ScaleSetClient, error) {
return &ScaleSetClient{
ghCli: cli,
httpClient: &http.Client{},
}, nil
}
type ScaleSetClient struct {
ghCli common.GithubClient
httpClient *http.Client
// scale sets are aparently available through the same security
// contex that a normal runner would use. We connect to the same
// API endpoint a runner would connect to, in order to fetch jobs.
// To do this, we use a runner registration token.
runnerRegistrationToken *github.RegistrationToken
// actionsServiceInfo holds the pipeline URL and the JWT token to
// access it. The pipeline URL is the base URL where we can access
// the scale set endpoints.
actionsServiceInfo *params.ActionsServiceAdminInfoResponse
mux sync.Mutex
}
func (s *ScaleSetClient) SetGithubClient(cli common.GithubClient) {
s.mux.Lock()
defer s.mux.Unlock()
s.ghCli = cli
}
func (s *ScaleSetClient) Do(req *http.Request) (*http.Response, error) {
if s.httpClient == nil {
return nil, fmt.Errorf("http client is not initialized")
}
resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to dispatch HTTP request: %w", err)
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, nil
}
var body []byte
if resp != nil {
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body: %w", err)
}
}
switch resp.StatusCode {
case 404:
return nil, runnerErrors.NewNotFoundError("resource %s not found: %q", req.URL.String(), string(body))
case 400:
return nil, runnerErrors.NewBadRequestError("bad request while calling %s: %q", req.URL.String(), string(body))
case 409:
return nil, runnerErrors.NewConflictError("conflict while calling %s: %q", req.URL.String(), string(body))
case 401, 403:
return nil, runnerErrors.ErrUnauthorized
default:
return nil, fmt.Errorf("request to %s failed with status code %d: %q", req.URL.String(), resp.StatusCode, string(body))
}
}

View file

@ -0,0 +1,88 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/cloudbase/garm/params"
)
type acquireJobsResult struct {
Count int `json:"count"`
Value []int64 `json:"value"`
}
func (s *ScaleSetClient) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) {
u := fmt.Sprintf("%s/%d/acquirejobs?api-version=6.0-preview", scaleSetEndpoint, runnerScaleSetId)
body, err := json.Marshal(requestIds)
if err != nil {
return nil, err
}
req, err := s.newActionsRequest(ctx, http.MethodPost, u, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to construct request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken))
resp, err := s.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var acquiredJobs acquireJobsResult
err = json.NewDecoder(resp.Body).Decode(&acquiredJobs)
if err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return acquiredJobs.Value, nil
}
func (s *ScaleSetClient) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (params.AcquirableJobList, error) {
path := fmt.Sprintf("%d/acquirablejobs", runnerScaleSetId)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.AcquirableJobList{}, fmt.Errorf("failed to construct request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.AcquirableJobList{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return params.AcquirableJobList{Count: 0, Jobs: []params.AcquirableJob{}}, nil
}
var acquirableJobList params.AcquirableJobList
err = json.NewDecoder(resp.Body).Decode(&acquirableJobList)
if err != nil {
return params.AcquirableJobList{}, fmt.Errorf("failed to decode response: %w", err)
}
return acquirableJobList, nil
}

View file

@ -0,0 +1,265 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand/v2"
"net/http"
"net/url"
"strconv"
"sync"
"time"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
)
const maxCapacityHeader = "X-ScaleSetMaxCapacity"
func NewMessageSession(ctx context.Context, cli *ScaleSetClient, session *params.RunnerScaleSetSession) (*MessageSession, error) {
sess := &MessageSession{
ssCli: cli,
session: session,
ctx: ctx,
done: make(chan struct{}),
closed: false,
}
go sess.loop()
return sess, nil
}
type MessageSession struct {
ssCli *ScaleSetClient
session *params.RunnerScaleSetSession
ctx context.Context
done chan struct{}
closed bool
lastErr error
mux sync.Mutex
}
func (m *MessageSession) Close() error {
m.mux.Lock()
defer m.mux.Unlock()
if m.closed {
return nil
}
close(m.done)
m.closed = true
return nil
}
func (m *MessageSession) LastError() error {
return m.lastErr
}
func (m *MessageSession) loop() {
timer := time.NewTimer(1 * time.Minute)
defer timer.Stop()
if m.closed {
return
}
for {
select {
case <-m.ctx.Done():
return
case <-m.done:
return
case <-timer.C:
if err := m.maybeRefreshToken(m.ctx); err != nil {
// We endlessly retry. If it's a transient error, it should eventually
// work, if it's credentials issues, users can update them.
slog.With(slog.Any("error", err)).ErrorContext(m.ctx, "failed to refresh message queue token")
m.lastErr = err
}
}
}
}
func (m *MessageSession) SessionsRelativeURL() (string, error) {
if m.session == nil {
return "", fmt.Errorf("session is nil")
}
if m.session.RunnerScaleSet == nil {
return "", fmt.Errorf("runner scale set is nil")
}
relativePath := fmt.Sprintf("%s/%d/sessions/%s", scaleSetEndpoint, m.session.RunnerScaleSet.Id, m.session.SessionId.String())
return relativePath, nil
}
func (m *MessageSession) Refresh(ctx context.Context) error {
m.mux.Lock()
defer m.mux.Unlock()
relPath, err := m.SessionsRelativeURL()
if err != nil {
return fmt.Errorf("failed to get session URL: %w", err)
}
req, err := m.ssCli.newActionsRequest(ctx, http.MethodPatch, relPath, nil)
if err != nil {
return fmt.Errorf("failed to create message delete request: %w", err)
}
resp, err := m.ssCli.Do(req)
if err != nil {
return fmt.Errorf("failed to delete message session: %w", err)
}
var refreshedSession params.RunnerScaleSetSession
if err := json.NewDecoder(resp.Body).Decode(&refreshedSession); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
m.session = &refreshedSession
return nil
}
func (m *MessageSession) maybeRefreshToken(ctx context.Context) error {
if m.session == nil {
return fmt.Errorf("session is nil")
}
// add some jitter
jitter := time.Duration(rand.IntN(10000)) * time.Millisecond
if m.session.ExpiresIn(2*time.Minute + jitter) {
if err := m.Refresh(ctx); err != nil {
return fmt.Errorf("failed to refresh message queue token: %w", err)
}
}
return nil
}
func (m *MessageSession) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity uint) (params.RunnerScaleSetMessage, error) {
u, err := url.Parse(m.session.MessageQueueUrl)
if err != nil {
return params.RunnerScaleSetMessage{}, err
}
if lastMessageId > 0 {
q := u.Query()
q.Set("lastMessageId", strconv.FormatInt(lastMessageId, 10))
u.RawQuery = q.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json; api-version=6.0-preview")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
req.Header.Set(maxCapacityHeader, fmt.Sprintf("%d", maxCapacity))
resp, err := m.ssCli.Do(req)
if err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("request to %s failed: %w", req.URL.String(), err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusAccepted {
return params.RunnerScaleSetMessage{}, nil
}
var message params.RunnerScaleSetMessage
if err := json.NewDecoder(resp.Body).Decode(&message); err != nil {
return params.RunnerScaleSetMessage{}, fmt.Errorf("failed to decode response: %w", err)
}
return message, nil
}
func (m *MessageSession) DeleteMessage(ctx context.Context, messageId int64) error {
u, err := url.Parse(m.session.MessageQueueUrl)
if err != nil {
return err
}
u.Path = fmt.Sprintf("%s/%d", u.Path, messageId)
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u.String(), nil)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", m.session.MessageQueueAccessToken))
resp, err := m.ssCli.Do(req)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (s *ScaleSetClient) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*MessageSession, error) {
path := fmt.Sprintf("%s/%d/sessions", scaleSetEndpoint, runnerScaleSetId)
newSession := params.RunnerScaleSetSession{
OwnerName: owner,
}
requestData, err := json.Marshal(newSession)
if err != nil {
return nil, fmt.Errorf("failed to marshal session data: %w", err)
}
req, err := s.newActionsRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request to %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var createdSession params.RunnerScaleSetSession
if err := json.NewDecoder(resp.Body).Decode(&createdSession); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &MessageSession{
ssCli: s,
session: &createdSession,
}, nil
}
func (s *ScaleSetClient) DeleteMessageSession(ctx context.Context, session *MessageSession) error {
path, err := session.SessionsRelativeURL()
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return fmt.Errorf("failed to create message delete request: %w", err)
}
_, err = s.Do(req)
if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return fmt.Errorf("failed to delete message session: %w", err)
}
}
return nil
}

View file

@ -0,0 +1,129 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
)
type scaleSetJitRunnerConfig struct {
Name string `json:"name"`
WorkFolder string `json:"workFolder"`
}
func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSet params.RunnerScaleSet) (params.RunnerScaleSetJitRunnerConfig, error) {
runnerSettings := scaleSetJitRunnerConfig{
Name: runnerName,
WorkFolder: "_work",
}
body, err := json.Marshal(runnerSettings)
if err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, err
}
req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSet.RunnerJitConfigUrl, bytes.NewBuffer(body))
if err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var runnerJitConfig params.RunnerScaleSetJitRunnerConfig
if err := json.NewDecoder(resp.Body).Decode(&runnerJitConfig); err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to decode response: %w", err)
}
return runnerJitConfig, nil
}
func (s *ScaleSetClient) GetRunner(ctx context.Context, runnerId int64) (params.RunnerReference, error) {
path := fmt.Sprintf("%s/%d", runnerEndpoint, runnerId)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.RunnerReference{}, fmt.Errorf("failed to construct request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerReference{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var runnerReference params.RunnerReference
if err := json.NewDecoder(resp.Body).Decode(&runnerReference); err != nil {
return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err)
}
return runnerReference, nil
}
func (s *ScaleSetClient) GetRunnerByName(ctx context.Context, runnerName string) (params.RunnerReference, error) {
path := fmt.Sprintf("%s?agentName=%s", runnerEndpoint, runnerName)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.RunnerReference{}, fmt.Errorf("failed to construct request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerReference{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var runnerList params.RunnerReferenceList
if err := json.NewDecoder(resp.Body).Decode(&runnerList); err != nil {
return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err)
}
if runnerList.Count == 0 {
return params.RunnerReference{}, fmt.Errorf("could not find runner with name %q: %w", runnerName, runnerErrors.ErrNotFound)
}
if runnerList.Count > 1 {
return params.RunnerReference{}, fmt.Errorf("failed to decode response: %w", err)
}
return runnerList.RunnerReferences[0], nil
}
func (s *ScaleSetClient) RemoveRunner(ctx context.Context, runnerId int64) error {
path := fmt.Sprintf("%s/%d", runnerEndpoint, runnerId)
req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return fmt.Errorf("failed to construct request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
resp.Body.Close()
return nil
}

View file

@ -0,0 +1,204 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httputil"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
)
const (
runnerEndpoint = "_apis/distributedtask/pools/0/agents"
scaleSetEndpoint = "_apis/runtime/runnerscalesets"
)
const (
HeaderActionsActivityID = "ActivityId"
HeaderGitHubRequestID = "X-GitHub-Request-Id"
)
func (s *ScaleSetClient) GetRunnerScaleSetByNameAndRunnerGroup(ctx context.Context, runnerGroupId int, name string) (params.RunnerScaleSet, error) {
path := fmt.Sprintf("%s?runnerGroupId=%d&name=%s", scaleSetEndpoint, runnerGroupId, name)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.RunnerScaleSet{}, err
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerScaleSet{}, err
}
var runnerScaleSetList *params.RunnerScaleSetsResponse
if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err)
}
if runnerScaleSetList.Count == 0 {
return params.RunnerScaleSet{}, runnerErrors.NewNotFoundError("runner scale set with name %s and runner group ID %d was not found", name, runnerGroupId)
}
// Runner scale sets must have a uniqe name. Attempting to create a runner scale set with the same name as
// an existing scale set will result in a Bad Request (400) error.
return runnerScaleSetList.RunnerScaleSets[0], nil
}
func (s *ScaleSetClient) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (params.RunnerScaleSet, error) {
path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.RunnerScaleSet{}, err
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to get runner scaleset with ID %d: %w", runnerScaleSetId, err)
}
var runnerScaleSet params.RunnerScaleSet
if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSet); err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err)
}
return runnerScaleSet, nil
}
// ListRunnerScaleSets lists all runner scale sets in a github entity.
func (s *ScaleSetClient) ListRunnerScaleSets(ctx context.Context) (*params.RunnerScaleSetsResponse, error) {
req, err := s.newActionsRequest(ctx, http.MethodGet, scaleSetEndpoint, nil)
if err != nil {
return nil, err
}
data, err := httputil.DumpRequest(req, false)
if err == nil {
fmt.Println(string(data))
}
resp, err := s.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to list runner scale sets: %w", err)
}
var runnerScaleSetList params.RunnerScaleSetsResponse
if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &runnerScaleSetList, nil
}
// CreateRunnerScaleSet creates a new runner scale set in the target GitHub entity.
func (s *ScaleSetClient) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *params.RunnerScaleSet) (params.RunnerScaleSet, error) {
body, err := json.Marshal(runnerScaleSet)
if err != nil {
return params.RunnerScaleSet{}, err
}
req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSetEndpoint, bytes.NewReader(body))
if err != nil {
return params.RunnerScaleSet{}, err
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to create runner scale set: %w", err)
}
var createdRunnerScaleSet params.RunnerScaleSet
if err := json.NewDecoder(resp.Body).Decode(&createdRunnerScaleSet); err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err)
}
return createdRunnerScaleSet, nil
}
func (s *ScaleSetClient) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet params.RunnerScaleSet) (params.RunnerScaleSet, error) {
path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId)
body, err := json.Marshal(runnerScaleSet)
if err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := s.newActionsRequest(ctx, http.MethodPatch, path, bytes.NewReader(body))
if err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to make request: %w", err)
}
var ret params.RunnerScaleSet
if err := json.NewDecoder(resp.Body).Decode(&ret); err != nil {
return params.RunnerScaleSet{}, fmt.Errorf("failed to decode response: %w", err)
}
return ret, nil
}
func (s *ScaleSetClient) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error {
path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId)
req, err := s.newActionsRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return err
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("failed to delete scale set with code %d", resp.StatusCode)
}
resp.Body.Close()
return nil
}
func (s *ScaleSetClient) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (params.RunnerGroup, error) {
path := fmt.Sprintf("_apis/runtime/runnergroups/?groupName=%s", runnerGroup)
req, err := s.newActionsRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return params.RunnerGroup{}, err
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerGroup{}, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
var runnerGroupList params.RunnerGroupList
err = json.NewDecoder(resp.Body).Decode(&runnerGroupList)
if err != nil {
return params.RunnerGroup{}, fmt.Errorf("failed to decode response: %w", err)
}
if runnerGroupList.Count == 0 {
return params.RunnerGroup{}, runnerErrors.NewNotFoundError("runner group %s does not exist", runnerGroup)
}
if runnerGroupList.Count > 1 {
return params.RunnerGroup{}, runnerErrors.NewConflictError("multiple runner groups exist with the same name (%s)", runnerGroup)
}
return runnerGroupList.RunnerGroups[0], nil
}

View file

@ -0,0 +1,105 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/cloudbase/garm/params"
)
func (s *ScaleSetClient) getActionServiceInfo(ctx context.Context) (params.ActionsServiceAdminInfoResponse, error) {
regPath := "/actions/runner-registration"
baseURL := s.ghCli.GithubBaseURL()
url, err := baseURL.Parse(regPath)
if err != nil {
return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to parse url: %w", err)
}
entity := s.ghCli.GetEntity()
body := params.ActionsServiceAdminInfoRequest{
URL: entity.GithubURL(),
RunnerEvent: "register",
}
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
if err := enc.Encode(body); err != nil {
return params.ActionsServiceAdminInfoResponse{}, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), buf)
if err != nil {
return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("RemoteAuth %s", *s.runnerRegistrationToken.Token))
resp, err := s.Do(req)
if err != nil {
return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to get actions service admin info: %w", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to read response body: %w", err)
}
data = bytes.TrimPrefix(data, []byte("\xef\xbb\xbf"))
var info params.ActionsServiceAdminInfoResponse
if err := json.Unmarshal(data, &info); err != nil {
return params.ActionsServiceAdminInfoResponse{}, fmt.Errorf("failed to decode response: %w", err)
}
return info, nil
}
func (s *ScaleSetClient) ensureAdminInfo(ctx context.Context) error {
s.mux.Lock()
defer s.mux.Unlock()
var expiresAt time.Time
if s.runnerRegistrationToken != nil {
expiresAt = s.runnerRegistrationToken.GetExpiresAt().Time
}
now := time.Now().UTC().Add(2 * time.Minute)
if now.After(expiresAt) || s.runnerRegistrationToken == nil {
token, _, err := s.ghCli.CreateEntityRegistrationToken(ctx)
if err != nil {
return fmt.Errorf("failed to fetch runner registration token: %w", err)
}
s.runnerRegistrationToken = token
}
if s.actionsServiceInfo == nil || s.actionsServiceInfo.ExpiresIn(2*time.Minute) {
info, err := s.getActionServiceInfo(ctx)
if err != nil {
return fmt.Errorf("failed to get action service info: %w", err)
}
s.actionsServiceInfo = &info
}
return nil
}

View file

@ -0,0 +1,54 @@
// Copyright 2024 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package scalesets
import (
"context"
"fmt"
"io"
"net/http"
)
func (s *ScaleSetClient) newActionsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
if err := s.ensureAdminInfo(ctx); err != nil {
return nil, fmt.Errorf("failed to update token: %w", err)
}
actionsUri, err := s.actionsServiceInfo.GetURL()
if err != nil {
return nil, fmt.Errorf("failed to get pipeline URL: %w", err)
}
uri, err := actionsUri.Parse(path)
if err != nil {
return nil, fmt.Errorf("failed to parse path: %w", err)
}
q := uri.Query()
if q.Get("api-version") == "" {
q.Set("api-version", "6.0-preview")
}
uri.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, method, uri.String(), body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.actionsServiceInfo.Token))
return req, nil
}