Fix findEndpointForJob

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2025-05-14 21:09:02 +00:00
parent 56be5eb698
commit f66b651b59
13 changed files with 287 additions and 113 deletions

View file

@ -16,7 +16,7 @@ import (
"github.com/cloudbase/garm/params"
)
var systemdUnitTemplate = `[Unit]
var githubSystemdUnitTemplate = `[Unit]
Description=GitHub Actions Runner ({{.ServiceName}})
After=network.target
@ -32,11 +32,24 @@ TimeoutStopSec=5min
WantedBy=multi-user.target
`
func validateInstanceState(ctx context.Context) (params.Instance, error) {
if !auth.InstanceHasJITConfig(ctx) {
return params.Instance{}, fmt.Errorf("instance not configured for JIT: %w", runnerErrors.ErrNotFound)
}
var giteaSystemdUnitTemplate = `[Unit]
Description=Act Runner ({{.ServiceName}})
After=network.target
[Service]
ExecStart=/home/{{.RunAsUser}}/act-runner/act_runner daemon --once
User={{.RunAsUser}}
WorkingDirectory=/home/{{.RunAsUser}}/act-runner
KillMode=process
KillSignal=SIGTERM
TimeoutStopSec=5min
Restart=always
[Install]
WantedBy=multi-user.target
`
func validateInstanceState(ctx context.Context) (params.Instance, error) {
status := auth.InstanceRunnerStatus(ctx)
if status != params.RunnerPending && status != params.RunnerInstalling {
return params.Instance{}, runnerErrors.ErrUnauthorized
@ -49,6 +62,56 @@ func validateInstanceState(ctx context.Context) (params.Instance, error) {
return instance, nil
}
func (r *Runner) getForgeEntityFromInstance(ctx context.Context, instance params.Instance) (params.ForgeEntity, error) {
var entityGetter params.EntityGetter
var err error
switch {
case instance.PoolID != "":
entityGetter, err = r.store.GetPoolByID(r.ctx, instance.PoolID)
case instance.ScaleSetID != 0:
entityGetter, err = r.store.GetScaleSetByID(r.ctx, instance.ScaleSetID)
default:
return params.ForgeEntity{}, errors.New("instance not associated with a pool or scale set")
}
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get entity getter",
"instance", instance.Name)
return params.ForgeEntity{}, errors.Wrap(err, "fetching entity getter")
}
poolEntity, err := entityGetter.GetEntity()
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get entity",
"instance", instance.Name)
return params.ForgeEntity{}, errors.Wrap(err, "fetching entity")
}
entity, err := r.store.GetForgeEntity(r.ctx, poolEntity.EntityType, poolEntity.ID)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get entity",
"instance", instance.Name)
return params.ForgeEntity{}, errors.Wrap(err, "fetching entity")
}
return entity, nil
}
func (r *Runner) getServiceNameForEntity(entity params.ForgeEntity) (string, error) {
switch entity.EntityType {
case params.ForgeEntityTypeEnterprise:
return fmt.Sprintf("actions.runner.%s.%s", entity.Owner, entity.Name), nil
case params.ForgeEntityTypeOrganization:
return fmt.Sprintf("actions.runner.%s.%s", entity.Owner, entity.Name), nil
case params.ForgeEntityTypeRepository:
return fmt.Sprintf("actions.runner.%s-%s.%s", entity.Owner, entity.Name, entity.Name), nil
default:
return "", errors.New("unknown entity type")
}
}
func (r *Runner) GetRunnerServiceName(ctx context.Context) (string, error) {
instance, err := validateInstanceState(ctx)
if err != nil {
@ -56,64 +119,51 @@ func (r *Runner) GetRunnerServiceName(ctx context.Context) (string, error) {
ctx, "failed to get instance params")
return "", runnerErrors.ErrUnauthorized
}
var entity params.ForgeEntity
switch {
case instance.PoolID != "":
pool, err := r.store.GetPoolByID(r.ctx, instance.PoolID)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get pool",
"pool_id", instance.PoolID)
return "", errors.Wrap(err, "fetching pool")
}
entity, err = pool.GetEntity()
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get pool entity",
"pool_id", instance.PoolID)
return "", errors.Wrap(err, "fetching pool entity")
}
case instance.ScaleSetID != 0:
scaleSet, err := r.store.GetScaleSetByID(r.ctx, instance.ScaleSetID)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get scale set",
"scale_set_id", instance.ScaleSetID)
return "", errors.Wrap(err, "fetching scale set")
}
entity, err = scaleSet.GetEntity()
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get scale set entity",
"scale_set_id", instance.ScaleSetID)
return "", errors.Wrap(err, "fetching scale set entity")
}
default:
return "", errors.New("instance not associated with a pool or scale set")
entity, err := r.getForgeEntityFromInstance(ctx, instance)
if err != nil {
slog.ErrorContext(r.ctx, "failed to get entity", "error", err)
return "", errors.Wrap(err, "fetching entity")
}
tpl := "actions.runner.%s.%s"
var serviceName string
switch entity.EntityType {
case params.ForgeEntityTypeEnterprise:
serviceName = fmt.Sprintf(tpl, entity.Owner, instance.Name)
case params.ForgeEntityTypeOrganization:
serviceName = fmt.Sprintf(tpl, entity.Owner, instance.Name)
case params.ForgeEntityTypeRepository:
serviceName = fmt.Sprintf(tpl, fmt.Sprintf("%s-%s", entity.Owner, entity.Name), instance.Name)
serviceName, err := r.getServiceNameForEntity(entity)
if err != nil {
slog.ErrorContext(r.ctx, "failed to get service name", "error", err)
return "", errors.Wrap(err, "fetching service name")
}
return serviceName, nil
}
func (r *Runner) GenerateSystemdUnitFile(ctx context.Context, runAsUser string) ([]byte, error) {
serviceName, err := r.GetRunnerServiceName(ctx)
instance, err := validateInstanceState(ctx)
if err != nil {
return nil, errors.Wrap(err, "fetching runner service name")
slog.With(slog.Any("error", err)).ErrorContext(
ctx, "failed to get instance params")
return nil, runnerErrors.ErrUnauthorized
}
entity, err := r.getForgeEntityFromInstance(ctx, instance)
if err != nil {
slog.ErrorContext(r.ctx, "failed to get entity", "error", err)
return nil, errors.Wrap(err, "fetching entity")
}
unitTemplate, err := template.New("").Parse(systemdUnitTemplate)
serviceName, err := r.getServiceNameForEntity(entity)
if err != nil {
slog.ErrorContext(r.ctx, "failed to get service name", "error", err)
return nil, errors.Wrap(err, "fetching service name")
}
var unitTemplate *template.Template
switch entity.Credentials.ForgeType {
case params.GithubEndpointType:
unitTemplate, err = template.New("").Parse(githubSystemdUnitTemplate)
case params.GiteaEndpointType:
unitTemplate, err = template.New("").Parse(giteaSystemdUnitTemplate)
default:
slog.ErrorContext(r.ctx, "unknown forge type", "forge_type", entity.Credentials.ForgeType)
return nil, errors.New("unknown forge type")
}
if err != nil {
slog.ErrorContext(r.ctx, "failed to parse template", "error", err)
return nil, errors.Wrap(err, "parsing template")
}
@ -131,12 +181,17 @@ func (r *Runner) GenerateSystemdUnitFile(ctx context.Context, runAsUser string)
var unitFile bytes.Buffer
if err := unitTemplate.Execute(&unitFile, data); err != nil {
slog.ErrorContext(r.ctx, "failed to execute template", "error", err)
return nil, errors.Wrap(err, "executing template")
}
return unitFile.Bytes(), nil
}
func (r *Runner) GetJITConfigFile(ctx context.Context, file string) ([]byte, error) {
if !auth.InstanceHasJITConfig(ctx) {
return nil, fmt.Errorf("instance not configured for JIT: %w", runnerErrors.ErrNotFound)
}
instance, err := validateInstanceState(ctx)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(

View file

@ -47,15 +47,15 @@ import (
)
var (
poolIDLabelprefix = "runner-pool-id:"
controllerLabelPrefix = "runner-controller-id:"
poolIDLabelprefix = "runner-pool-id"
controllerLabelPrefix = "runner-controller-id"
// We tag runners that have been spawned as a result of a queued job with the job ID
// that spawned them. There is no way to guarantee that the runner spawned in response to a particular
// job, will be picked up by that job. We mark them so as in the very likely event that the runner
// has picked up a different job, we can clear the lock on the job that spaned it.
// The job it picked up would already be transitioned to in_progress so it will be ignored by the
// consume loop.
jobLabelPrefix = "in_response_to_job:"
jobLabelPrefix = "in_response_to_job"
)
const (
@ -296,7 +296,8 @@ func (r *basePoolManager) HandleWorkflowJob(job params.WorkflowJob) error {
func jobIDFromLabels(labels []string) int64 {
for _, lbl := range labels {
if strings.HasPrefix(lbl, jobLabelPrefix) {
jobID, err := strconv.ParseInt(lbl[len(jobLabelPrefix):], 10, 64)
trimLength := min(len(jobLabelPrefix)+1, len(lbl))
jobID, err := strconv.ParseInt(lbl[trimLength:], 10, 64)
if err != nil {
return 0
}
@ -361,21 +362,21 @@ func (r *basePoolManager) startLoopForFunction(f func() error, interval time.Dur
}
func (r *basePoolManager) updateTools() error {
// Update tools cache.
tools, err := r.FetchTools()
tools, err := cache.GetGithubToolsCache(r.entity.ID)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(
r.ctx, "failed to update tools for entity", "entity", r.entity.String())
r.SetPoolRunningState(false, err.Error())
return fmt.Errorf("failed to update tools for entity %s: %w", r.entity.String(), err)
}
r.mux.Lock()
r.tools = tools
r.mux.Unlock()
slog.DebugContext(r.ctx, "successfully updated tools")
r.SetPoolRunningState(true, "")
return err
return nil
}
// cleanupOrphanedProviderRunners compares runners in github with local runners and removes
@ -995,11 +996,11 @@ func (r *basePoolManager) paramsWorkflowJobToParamsJob(job params.WorkflowJob) (
}
func (r *basePoolManager) poolLabel(poolID string) string {
return fmt.Sprintf("%s%s", poolIDLabelprefix, poolID)
return fmt.Sprintf("%s=%s", poolIDLabelprefix, poolID)
}
func (r *basePoolManager) controllerLabel() string {
return fmt.Sprintf("%s%s", controllerLabelPrefix, r.controllerInfo.ControllerID.String())
return fmt.Sprintf("%s=%s", controllerLabelPrefix, r.controllerInfo.ControllerID.String())
}
func (r *basePoolManager) updateArgsFromProviderInstance(providerInstance commonParams.ProviderInstance) params.UpdateInstanceParams {
@ -1613,6 +1614,16 @@ func (r *basePoolManager) Start() error {
initialToolUpdate := make(chan struct{}, 1)
go func() {
slog.Info("running initial tool update")
for {
slog.DebugContext(r.ctx, "waiting for tools to be available")
hasTools, stopped := r.waitForToolsOrCancel()
if stopped {
return
}
if hasTools {
break
}
}
if err := r.updateTools(); err != nil {
slog.With(slog.Any("error", err)).Error("failed to update tools")
}
@ -1804,7 +1815,7 @@ func (r *basePoolManager) consumeQueuedJobs() error {
}
jobLabels := []string{
fmt.Sprintf("%s%d", jobLabelPrefix, job.ID),
fmt.Sprintf("%s=%d", jobLabelPrefix, job.ID),
}
for i := 0; i < poolRR.Len(); i++ {
pool, err := poolRR.Next()

View file

@ -5,11 +5,13 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/go-github/v71/github"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonParams "github.com/cloudbase/garm-provider-common/params"
"github.com/cloudbase/garm/cache"
dbCommon "github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
@ -91,7 +93,8 @@ func instanceInList(instanceName string, instances []commonParams.ProviderInstan
func controllerIDFromLabels(labels []string) string {
for _, lbl := range labels {
if strings.HasPrefix(lbl, controllerLabelPrefix) {
return lbl[len(controllerLabelPrefix):]
trimLength := min(len(controllerLabelPrefix)+1, len(lbl))
return lbl[trimLength:]
}
}
return ""
@ -134,3 +137,19 @@ func composeWatcherFilters(entity params.ForgeEntity) dbCommon.PayloadFilterFunc
watcher.WithForgeCredentialsFilter(entity.Credentials),
)
}
func (r *basePoolManager) waitForToolsOrCancel() (hasTools, stopped bool) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
select {
case <-ticker.C:
if _, err := cache.GetGithubToolsCache(r.entity.ID); err != nil {
return false, false
}
return true, false
case <-r.quit:
return false, true
case <-r.ctx.Done():
return false, true
}
}

View file

@ -602,7 +602,7 @@ func (r *Runner) validateHookBody(signature, secret string, body []byte) error {
return nil
}
func (r *Runner) findEndpointForJob(job params.WorkflowJob) (params.ForgeEndpoint, error) {
func (r *Runner) findEndpointForJob(job params.WorkflowJob, forgeType params.EndpointType) (params.ForgeEndpoint, error) {
uri, err := url.ParseRequestURI(job.WorkflowJob.HTMLURL)
if err != nil {
return params.ForgeEndpoint{}, errors.Wrap(err, "parsing job URL")
@ -614,12 +614,23 @@ func (r *Runner) findEndpointForJob(job params.WorkflowJob) (params.ForgeEndpoin
// a GHES involved, those users will have just one extra endpoint or 2 (if they also have a
// test env). But there should be a relatively small number, regardless. So we don't really care
// that much about the performance of this function.
endpoints, err := r.store.ListGithubEndpoints(r.ctx)
var endpoints []params.ForgeEndpoint
switch forgeType {
case params.GithubEndpointType:
endpoints, err = r.store.ListGithubEndpoints(r.ctx)
case params.GiteaEndpointType:
endpoints, err = r.store.ListGiteaEndpoints(r.ctx)
default:
return params.ForgeEndpoint{}, runnerErrors.NewBadRequestError("unknown forge type %s", forgeType)
}
if err != nil {
return params.ForgeEndpoint{}, errors.Wrap(err, "fetching github endpoints")
}
for _, ep := range endpoints {
if ep.BaseURL == baseURI {
slog.DebugContext(r.ctx, "checking endpoint", "base_uri", baseURI, "endpoint", ep.BaseURL)
epBaseURI := strings.TrimSuffix(ep.BaseURL, "/")
if epBaseURI == baseURI {
return ep, nil
}
}
@ -627,18 +638,21 @@ func (r *Runner) findEndpointForJob(job params.WorkflowJob) (params.ForgeEndpoin
return params.ForgeEndpoint{}, runnerErrors.NewNotFoundError("no endpoint found for job")
}
func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData []byte) error {
func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, forgeType params.EndpointType, jobData []byte) error {
if len(jobData) == 0 {
slog.ErrorContext(r.ctx, "missing job data")
return runnerErrors.NewBadRequestError("missing job data")
}
var job params.WorkflowJob
if err := json.Unmarshal(jobData, &job); err != nil {
slog.ErrorContext(r.ctx, "failed to unmarshal job data", "error", err)
return errors.Wrapf(runnerErrors.ErrBadRequest, "invalid job data: %s", err)
}
endpoint, err := r.findEndpointForJob(job)
endpoint, err := r.findEndpointForJob(job, forgeType)
if err != nil {
slog.ErrorContext(r.ctx, "failed to find endpoint for job", "error", err)
return errors.Wrap(err, "finding endpoint for job")
}
@ -867,15 +881,17 @@ func (r *Runner) DeleteRunner(ctx context.Context, instanceName string, forceDel
}
if err != nil {
if errors.Is(err, runnerErrors.ErrUnauthorized) && instance.PoolID != "" {
poolMgr, err := r.getPoolManagerFromInstance(ctx, instance)
if err != nil {
return errors.Wrap(err, "fetching pool manager for instance")
if !errors.Is(err, runnerErrors.ErrNotFound) {
if errors.Is(err, runnerErrors.ErrUnauthorized) && instance.PoolID != "" {
poolMgr, err := r.getPoolManagerFromInstance(ctx, instance)
if err != nil {
return errors.Wrap(err, "fetching pool manager for instance")
}
poolMgr.SetPoolRunningState(false, fmt.Sprintf("failed to remove runner: %q", err))
}
if !bypassGithubUnauthorized {
return errors.Wrap(err, "removing runner from github")
}
poolMgr.SetPoolRunningState(false, fmt.Sprintf("failed to remove runner: %q", err))
}
if !bypassGithubUnauthorized {
return errors.Wrap(err, "removing runner from github")
}
}
}