Handle scale up and down; add provider worker

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
Gabriel Adrian Samfira 2025-04-20 17:39:52 +00:00
parent 7376a5fe74
commit 020210d6ad
14 changed files with 372 additions and 39 deletions

View file

@@ -46,6 +46,7 @@ import (
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/runner" //nolint:typecheck
runnerMetrics "github.com/cloudbase/garm/runner/metrics"
"github.com/cloudbase/garm/runner/providers"
garmUtil "github.com/cloudbase/garm/util"
"github.com/cloudbase/garm/util/appdefaults"
"github.com/cloudbase/garm/websocket"
@@ -62,16 +63,17 @@ var signals = []os.Signal{
syscall.SIGTERM,
}
func maybeInitController(db common.Store) error {
if _, err := db.ControllerInfo(); err == nil {
return nil
func maybeInitController(db common.Store) (params.ControllerInfo, error) {
if info, err := db.ControllerInfo(); err == nil {
return info, nil
}
if _, err := db.InitController(); err != nil {
return errors.Wrap(err, "initializing controller")
info, err := db.InitController()
if err != nil {
return params.ControllerInfo{}, errors.Wrap(err, "initializing controller")
}
return nil
return info, nil
}
func setupLogging(ctx context.Context, logCfg config.Logging, hub *websocket.Hub) {
@@ -212,7 +214,8 @@ func main() {
log.Fatal(err)
}
if err := maybeInitController(db); err != nil {
controllerInfo, err := maybeInitController(db)
if err != nil {
log.Fatal(err)
}
@@ -231,7 +234,12 @@ func main() {
log.Fatal(err)
}
entityController, err := entity.NewController(ctx, db, *cfg)
providers, err := providers.LoadProvidersFromConfig(ctx, *cfg, controllerInfo.ControllerID.String())
if err != nil {
log.Fatalf("loading providers: %+v", err)
}
entityController, err := entity.NewController(ctx, db, providers)
if err != nil {
log.Fatalf("failed to create entity controller: %+v", err)
}
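Note: the hunks above only show the loaded providers being handed to the entity controller. A minimal sketch of how the new provider worker could also be wired into main() — the call site and variable names are assumptions, since that wiring is not visible in this diff:

// Hypothetical wiring, not shown in this commit's hunks.
providerWorker, err := provider.NewWorker(ctx, db, providers)
if err != nil {
	log.Fatalf("creating provider worker: %+v", err)
}
if err := providerWorker.Start(); err != nil {
	log.Fatalf("starting provider worker: %+v", err)
}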

View file

@@ -92,6 +92,7 @@ type UserStore interface {
type InstanceStore interface {
CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error)
DeleteInstance(ctx context.Context, poolID string, instanceName string) error
DeleteInstanceByName(ctx context.Context, instanceName string) error
UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error)
// Probably a bad idea without some kind of filter or at least pagination

View file

@@ -177,6 +177,39 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN
return nil
}
func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName string) error {
instance, err := s.getInstanceByName(ctx, instanceName)
if err != nil {
return errors.Wrap(err, "deleting instance")
}
defer func() {
if err == nil {
var providerID string
if instance.ProviderID != nil {
providerID = *instance.ProviderID
}
if notifyErr := s.sendNotify(common.InstanceEntityType, common.DeleteOperation, params.Instance{
ID: instance.ID.String(),
Name: instance.Name,
ProviderID: providerID,
AgentID: instance.AgentID,
PoolID: instance.PoolID.String(),
}); notifyErr != nil {
slog.With(slog.Any("error", notifyErr)).Error("failed to send notify")
}
}
}()
if q := s.conn.Unscoped().Delete(&instance); q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return nil
}
// Assign the delete error to err so the deferred notify above is skipped on failure.
err = errors.Wrap(q.Error, "deleting instance")
return err
}
return nil
}
func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error {
instance, err := s.getInstanceByName(ctx, instanceName)
if err != nil {
@@ -293,7 +326,7 @@ func (s *sqlDatabase) ListPoolInstances(_ context.Context, poolID string) ([]par
func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, error) {
var instances []Instance
q := s.conn.Model(&Instance{}).Preload("Job", "Pool", "ScaleSet").Find(&instances)
q := s.conn.Model(&Instance{}).Preload("Job").Find(&instances)
if q.Error != nil {
return nil, errors.Wrap(q.Error, "fetching instances")
}
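The Preload change above is worth a note: GORM's Preload(query string, args ...interface{}) accepts a single association per call, and any extra arguments are treated as query conditions, so Preload("Job", "Pool", "ScaleSet") did not eager-load three associations. If all three were actually wanted, chained calls would express that — a sketch, assuming that intent rather than the narrowed Job-only load this commit settles on:

q := s.conn.Model(&Instance{}).
	Preload("Job").
	Preload("Pool").
	Preload("ScaleSet").
	Find(&instances)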

View file

@@ -277,7 +277,7 @@ type Instance struct {
GitHubRunnerGroup string
AditionalLabels datatypes.JSON
PoolID uuid.UUID
PoolID *uuid.UUID
Pool Pool `gorm:"foreignKey:PoolID"`
ScaleSetFkID *uint

View file

@@ -51,7 +51,7 @@ func (s *sqlDatabase) CreateScaleSetInstance(_ context.Context, scaleSetID uint,
func (s *sqlDatabase) ListScaleSetInstances(_ context.Context, scalesetID uint) ([]params.Instance, error) {
var instances []Instance
query := s.conn.Model(&Instance{}).Preload("Job", "ScaleSet").Where("scale_set_fk_id = ?", scalesetID)
query := s.conn.Model(&Instance{}).Preload("Job").Where("scale_set_fk_id = ?", scalesetID)
if err := query.Find(&instances); err.Error != nil {
return nil, errors.Wrap(err.Error, "fetching instances")

View file

@@ -79,7 +79,7 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) (params.Instance, e
ret.RunnerBootstrapTimeout = instance.ScaleSet.RunnerBootstrapTimeout
}
if instance.PoolID != uuid.Nil {
if instance.PoolID != nil {
ret.PoolID = instance.PoolID.String()
ret.ProviderName = instance.Pool.ProviderName
ret.RunnerBootstrapTimeout = instance.Pool.RunnerBootstrapTimeout

View file

@@ -1,6 +1,8 @@
package watcher
import (
commonParams "github.com/cloudbase/garm-provider-common/params"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/params"
)
@@ -281,3 +283,24 @@ func WithEntityTypeAndCallbackFilter(entityType dbCommon.DatabaseEntityType, cal
return ok
}
}
func WithInstanceStatusFilter(statuses ...commonParams.InstanceStatus) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if payload.EntityType != dbCommon.InstanceEntityType {
return false
}
instance, ok := payload.Payload.(params.Instance)
if !ok {
return false
}
if len(statuses) == 0 {
return false
}
for _, status := range statuses {
if instance.Status == status {
return true
}
}
return false
}
}
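A minimal usage sketch for the new filter, mirroring how the provider worker registers it later in this commit; the consumer ID here is illustrative:

consumer, err := watcher.RegisterConsumer(
	ctx, "example-consumer",
	watcher.WithInstanceStatusFilter(
		commonParams.InstancePendingDelete,
		commonParams.InstancePendingForceDelete,
	),
)
if err != nil {
	return fmt.Errorf("registering consumer: %w", err)
}
defer consumer.Close()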

View file

@@ -419,18 +419,18 @@ func (r RunnerScaleSetMessage) GetJobsFromBody() ([]ScaleSetJobMessage, error) {
}
type RunnerReference struct {
ID int `json:"id"`
Name string `json:"name"`
RunnerScaleSetID int `json:"runnerScaleSetId"`
CreatedOn time.Time `json:"createdOn"`
RunnerGroupID uint64 `json:"runnerGroupId"`
RunnerGroupName string `json:"runnerGroupName"`
Version string `json:"version"`
Enabled bool `json:"enabled"`
Ephemeral bool `json:"ephemeral"`
Status RunnerStatus `json:"status"`
DisableUpdate bool `json:"disableUpdate"`
ProvisioningState string `json:"provisioningState"`
ID int64 `json:"id"`
Name string `json:"name"`
RunnerScaleSetID int `json:"runnerScaleSetId"`
CreatedOn interface{} `json:"createdOn"`
RunnerGroupID uint64 `json:"runnerGroupId"`
RunnerGroupName string `json:"runnerGroupName"`
Version string `json:"version"`
Enabled bool `json:"enabled"`
Ephemeral bool `json:"ephemeral"`
Status interface{} `json:"status"`
DisableUpdate bool `json:"disableUpdate"`
ProvisioningState string `json:"provisioningState"`
}
type RunnerScaleSetJitRunnerConfig struct {

View file

@@ -30,7 +30,7 @@ type scaleSetJitRunnerConfig struct {
WorkFolder string `json:"workFolder"`
}
func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSet params.RunnerScaleSet) (params.RunnerScaleSetJitRunnerConfig, error) {
func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string, scaleSetID int) (params.RunnerScaleSetJitRunnerConfig, error) {
runnerSettings := scaleSetJitRunnerConfig{
Name: runnerName,
WorkFolder: "_work",
@@ -41,7 +41,14 @@ func (s *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName
return params.RunnerScaleSetJitRunnerConfig{}, err
}
req, err := s.newActionsRequest(ctx, http.MethodPost, scaleSet.RunnerJitConfigURL, bytes.NewBuffer(body))
serviceUrl, err := s.actionsServiceInfo.GetURL()
if err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to get pipeline URL: %w", err)
}
jitConfigPath := fmt.Sprintf("/%s/%d/generatejitconfig", scaleSetEndpoint, scaleSetID)
jitConfigURL := serviceUrl.JoinPath(jitConfigPath)
req, err := s.newActionsRequest(ctx, http.MethodPost, jitConfigURL.String(), bytes.NewBuffer(body))
if err != nil {
return params.RunnerScaleSetJitRunnerConfig{}, fmt.Errorf("failed to create request: %w", err)
}
@@ -81,6 +88,26 @@ func (s *ScaleSetClient) GetRunner(ctx context.Context, runnerID int64) (params.
return runnerReference, nil
}
func (s *ScaleSetClient) ListAllRunners(ctx context.Context) (params.RunnerReferenceList, error) {
req, err := s.newActionsRequest(ctx, http.MethodGet, runnerEndpoint, nil)
if err != nil {
return params.RunnerReferenceList{}, fmt.Errorf("failed to construct request: %w", err)
}
resp, err := s.Do(req)
if err != nil {
return params.RunnerReferenceList{}, fmt.Errorf("request failed for %s: %w", req.URL.String(), err)
}
defer resp.Body.Close()
var runnerList params.RunnerReferenceList
if err := json.NewDecoder(resp.Body).Decode(&runnerList); err != nil {
return params.RunnerReferenceList{}, fmt.Errorf("failed to decode response: %w", err)
}
return runnerList, nil
}
func (s *ScaleSetClient) GetRunnerByName(ctx context.Context, runnerName string) (params.RunnerReference, error) {
path := fmt.Sprintf("%s?agentName=%s", runnerEndpoint, runnerName)

View file

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
)
@@ -50,5 +51,7 @@ func (s *ScaleSetClient) newActionsRequest(ctx context.Context, method, path str
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.actionsServiceInfo.Token))
slog.DebugContext(ctx, "newActionsRequest", "method", method, "url", uri.String(), "body", body, "headers", req.Header)
return req, nil
}

View file

@@ -7,31 +7,19 @@ import (
"sync"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/config"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/runner/common"
"github.com/cloudbase/garm/runner/providers"
garmUtil "github.com/cloudbase/garm/util"
)
func NewController(ctx context.Context, store dbCommon.Store, cfg config.Config) (*Controller, error) {
func NewController(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider) (*Controller, error) {
consumerID := "entity-controller"
ctrlID, err := store.ControllerInfo()
if err != nil {
return nil, fmt.Errorf("getting controller info: %w", err)
}
ctx = garmUtil.WithSlogContext(
ctx,
slog.Any("worker", consumerID))
ctx = auth.GetAdminContext(ctx)
providers, err := providers.LoadProvidersFromConfig(ctx, cfg, ctrlID.ControllerID.String())
if err != nil {
return nil, fmt.Errorf("loading providers: %w", err)
}
return &Controller{
consumerID: consumerID,
ctx: ctx,

View file

@@ -0,0 +1,73 @@
package provider
import (
"context"
"fmt"
"sync"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/runner/common"
)
func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]common.Provider) (*provider, error) {
consumerID := "provider-worker"
return &provider{
ctx: ctx,
store: store,
consumerID: consumerID,
providers: providers,
}, nil
}
type provider struct {
ctx context.Context
consumerID string
consumer dbCommon.Consumer
// TODO: not all workers should have access to the store.
// We need to implement a way to RPC from workers to controllers
// and abstract that into something we can use to eventually
// scale out.
store dbCommon.Store
providers map[string]common.Provider
mux sync.Mutex
running bool
quit chan struct{}
}
func (p *provider) Start() error {
p.mux.Lock()
defer p.mux.Unlock()
if p.running {
return nil
}
consumer, err := watcher.RegisterConsumer(
p.ctx, p.consumerID, composeProviderWatcher())
if err != nil {
return fmt.Errorf("registering consumer: %w", err)
}
p.consumer = consumer
p.quit = make(chan struct{})
p.running = true
return nil
}
func (p *provider) Stop() error {
p.mux.Lock()
defer p.mux.Unlock()
if !p.running {
return nil
}
p.consumer.Close()
close(p.quit)
p.running = false
return nil
}
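Start() only registers the consumer; the event loop itself is not part of this commit. A hypothetical sketch of what it might look like, assuming the consumer exposes a Watch() channel of dbCommon.ChangePayload as other garm workers rely on:

// loop is a hypothetical consume loop; the name and dispatch logic are assumptions.
func (p *provider) loop() {
	for {
		select {
		case <-p.quit:
			return
		case <-p.ctx.Done():
			return
		case event, ok := <-p.consumer.Watch():
			if !ok {
				return
			}
			// Dispatch on the event payload (a params.Instance) and its Status here.
			_ = event
		}
	}
}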

workers/provider/util.go Normal file
View file

@@ -0,0 +1,18 @@
package provider
import (
commonParams "github.com/cloudbase/garm-provider-common/params"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
)
func composeProviderWatcher() dbCommon.PayloadFilterFunc {
return watcher.WithAny(
watcher.WithInstanceStatusFilter(
commonParams.InstancePendingCreate,
commonParams.InstancePendingDelete,
commonParams.InstancePendingForceDelete,
),
)
}
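The filter above selects instances in the pending create/delete states. A hypothetical handler sketch showing how the worker could drive such instances toward the states the scale set worker expects; the function name and flow are assumptions, while store.UpdateInstance() and the status constants come from this codebase:

func (p *provider) handleInstance(instance params.Instance) error {
	switch instance.Status {
	case commonParams.InstancePendingCreate:
		// Create the machine via the matching provider from p.providers,
		// then mark the instance as running via p.store.UpdateInstance().
	case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete:
		// Remove the machine via the provider, then flag the record as deleted
		// so the scale set worker's handleInstanceCleanup() purges it.
		if _, err := p.store.UpdateInstance(p.ctx, instance.Name, params.UpdateInstanceParams{
			Status: commonParams.InstanceDeleted,
		}); err != nil {
			return fmt.Errorf("updating instance %s: %w", instance.Name, err)
		}
	}
	return nil
}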

View file

@@ -2,13 +2,19 @@ package scaleset
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"time"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonParams "github.com/cloudbase/garm-provider-common/params"
"github.com/cloudbase/garm-provider-common/util"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/locking"
"github.com/cloudbase/garm/params"
"github.com/cloudbase/garm/runner/common"
"github.com/cloudbase/garm/util/github/scalesets"
@@ -188,6 +194,17 @@ func (w *Worker) handleScaleSetEvent(event dbCommon.ChangePayload) {
}
}
func (w *Worker) handleInstanceCleanup(instance params.Instance) error {
if instance.Status == commonParams.InstanceDeleted {
if err := w.store.DeleteInstanceByName(w.ctx, instance.Name); err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
return fmt.Errorf("deleting instance %s: %w", instance.ID, err)
}
}
}
return nil
}
func (w *Worker) handleInstanceEntityEvent(event dbCommon.ChangePayload) {
instance, ok := event.Payload.(params.Instance)
if !ok {
@@ -319,6 +336,138 @@ func (w *Worker) keepListenerAlive() {
}
}
func (w *Worker) handleScaleUp(target, current uint) {
if !w.scaleSet.Enabled {
slog.DebugContext(w.ctx, "scale set is disabled; not scaling up")
return
}
if target <= current {
slog.DebugContext(w.ctx, "target is less than or equal to current; not scaling up")
return
}
controllerConfig, err := w.store.ControllerInfo()
if err != nil {
slog.ErrorContext(w.ctx, "error getting controller config", "error", err)
return
}
for i := current; i < target; i++ {
newRunnerName := fmt.Sprintf("%s-%s", w.scaleSet.GetRunnerPrefix(), util.NewID())
jitConfig, err := w.scaleSetCli.GenerateJitRunnerConfig(w.ctx, newRunnerName, w.scaleSet.ScaleSetID)
if err != nil {
slog.ErrorContext(w.ctx, "error generating jit config", "error", err)
continue
}
slog.DebugContext(w.ctx, "creating new runner", "runner_name", newRunnerName)
decodedJit, err := jitConfig.DecodedJITConfig()
if err != nil {
slog.ErrorContext(w.ctx, "error decoding jit config", "error", err)
continue
}
runnerParams := params.CreateInstanceParams{
Name: newRunnerName,
Status: commonParams.InstancePendingCreate,
RunnerStatus: params.RunnerPending,
OSArch: w.scaleSet.OSArch,
OSType: w.scaleSet.OSType,
CallbackURL: controllerConfig.CallbackURL,
MetadataURL: controllerConfig.MetadataURL,
CreateAttempt: 1,
GitHubRunnerGroup: w.scaleSet.GitHubRunnerGroup,
JitConfiguration: decodedJit,
AgentID: int64(jitConfig.Runner.ID),
}
if _, err := w.store.CreateScaleSetInstance(w.ctx, w.scaleSet.ID, runnerParams); err != nil {
slog.ErrorContext(w.ctx, "error creating instance", "error", err)
if err := w.scaleSetCli.RemoveRunner(w.ctx, jitConfig.Runner.ID); err != nil {
slog.ErrorContext(w.ctx, "error deleting runner", "error", err)
}
continue
}
runnerDetails, err := w.scaleSetCli.GetRunner(w.ctx, jitConfig.Runner.ID)
if err != nil {
slog.ErrorContext(w.ctx, "error getting runner details", "error", err)
continue
}
slog.DebugContext(w.ctx, "runner details", "runner_details", runnerDetails)
}
}
func (w *Worker) handleScaleDown(target, current uint) {
// target and current are unsigned; guard before subtracting to avoid underflow.
if current <= target {
return
}
delta := current - target
w.mux.Lock()
defer w.mux.Unlock()
removed := 0
for _, runner := range w.runners {
if removed >= int(delta) {
break
}
locked, err := locking.TryLock(runner.Name)
if err != nil || !locked {
slog.DebugContext(w.ctx, "runner is locked; skipping", "runner_name", runner.Name)
continue
}
switch runner.Status {
case commonParams.InstancePendingCreate, commonParams.InstanceRunning:
case commonParams.InstancePendingDelete, commonParams.InstancePendingForceDelete:
removed++
locking.Unlock(runner.Name, true)
continue
default:
slog.DebugContext(w.ctx, "runner is not in a valid state; skipping", "runner_name", runner.Name, "runner_status", runner.Status)
locking.Unlock(runner.Name, false)
continue
}
switch runner.RunnerStatus {
case params.RunnerTerminated, params.RunnerActive:
slog.DebugContext(w.ctx, "runner is not in a valid state; skipping", "runner_name", runner.Name, "runner_status", runner.RunnerStatus)
locking.Unlock(runner.Name, false)
continue
}
slog.DebugContext(w.ctx, "removing runner", "runner_name", runner.Name)
if err := w.scaleSetCli.RemoveRunner(w.ctx, runner.AgentID); err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) {
slog.ErrorContext(w.ctx, "error removing runner", "runner_name", runner.Name, "error", err)
locking.Unlock(runner.Name, false)
continue
}
}
runnerUpdateParams := params.UpdateInstanceParams{
Status: commonParams.InstancePendingDelete,
}
if _, err := w.store.UpdateInstance(w.ctx, runner.Name, runnerUpdateParams); err != nil {
if errors.Is(err, runnerErrors.ErrNotFound) {
// The instance no longer exists in the database, even though we still track it in our
// local state; either the delete event never reached us through the watcher or something
// else went wrong. Drop it from the local cache.
delete(w.runners, runner.ID)
removed++
locking.Unlock(runner.Name, true)
continue
}
// TODO: This should not happen unless there is some issue with the database.
// UpdateInstance() should retry internally, but even if it does, a persistent
// error here still needs to be handled somehow.
slog.ErrorContext(w.ctx, "error updating runner", "runner_name", runner.Name, "error", err)
locking.Unlock(runner.Name, false)
continue
}
removed++
locking.Unlock(runner.Name, false)
}
}
func (w *Worker) handleAutoScale() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@@ -337,6 +486,14 @@ func (w *Worker) handleAutoScale() {
case <-w.ctx.Done():
return
case <-ticker.C:
w.mux.Lock()
for _, instance := range w.runners {
if err := w.handleInstanceCleanup(instance); err != nil {
slog.ErrorContext(w.ctx, "error cleaning up instance", "instance_id", instance.ID, "error", err)
}
}
w.mux.Unlock()
var desiredRunners uint
if w.scaleSet.DesiredRunnerCount > 0 {
desiredRunners = uint(w.scaleSet.DesiredRunnerCount)
@ -351,8 +508,10 @@ func (w *Worker) handleAutoScale() {
if currentRunners < targetRunners {
lastMsgDebugLog("scaling up", targetRunners, currentRunners)
w.handleScaleUp(targetRunners, currentRunners)
} else {
lastMsgDebugLog("attempting to scale down", targetRunners, currentRunners)
w.handleScaleDown(targetRunners, currentRunners)
}
}
}