Cache jobs in pool manager

This change caches jobs meant for an entity in the pool manager. This
allows us to avoid querying the DB as often and to better determine
when we should scale down.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2025-10-04 19:20:14 +00:00 committed by Gabriel
parent 0093393bc3
commit a36d01afd5
13 changed files with 331 additions and 73 deletions

View file

@ -568,7 +568,7 @@ func (s *SQLite) Validate() error {
}
func (s *SQLite) ConnectionString() (string, error) {
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON", s.DBFile)
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate", s.DBFile)
if s.BusyTimeoutSeconds > 0 {
timeout := s.BusyTimeoutSeconds * 1000
connectionString = fmt.Sprintf("%s&_busy_timeout=%d", connectionString, timeout)

View file

@ -387,13 +387,13 @@ func TestGormParams(t *testing.T) {
dbType, uri, err := cfg.GormParams()
require.Nil(t, err)
require.Equal(t, SQLiteBackend, dbType)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON"), uri)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate"), uri)
cfg.SQLite.BusyTimeoutSeconds = 5
dbType, uri, err = cfg.GormParams()
require.Nil(t, err)
require.Equal(t, SQLiteBackend, dbType)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000"), uri)
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_busy_timeout=5000"), uri)
cfg.DbBackend = MySQLBackend
cfg.MySQL = getMySQLDefaultConfig()

View file

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"github.com/google/uuid"
"gorm.io/datatypes"
@ -31,54 +32,74 @@ import (
"github.com/cloudbase/garm/params"
)
func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) {
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
return params.Instance{}, fmt.Errorf("error fetching pool: %w", err)
}
func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) {
defer func() {
if err == nil {
s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance)
}
}()
var labels datatypes.JSON
if len(param.AditionalLabels) > 0 {
labels, err = json.Marshal(param.AditionalLabels)
err = s.conn.Transaction(func(tx *gorm.DB) error {
pool, err := s.getPoolByID(tx, poolID)
if err != nil {
return params.Instance{}, fmt.Errorf("error marshalling labels: %w", err)
return fmt.Errorf("error fetching pool: %w", err)
}
}
var secret []byte
if len(param.JitConfiguration) > 0 {
secret, err = s.marshalAndSeal(param.JitConfiguration)
if err != nil {
return params.Instance{}, fmt.Errorf("error marshalling jit config: %w", err)
var cnt int64
q := s.conn.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt)
if q.Error != nil {
return fmt.Errorf("error fetching instance count: %w", q.Error)
}
var maxRunners int64
if pool.MaxRunners > math.MaxInt64 {
maxRunners = math.MaxInt64
} else {
maxRunners = int64(pool.MaxRunners)
}
if cnt >= maxRunners {
return runnerErrors.NewConflictError("max runners reached for pool %s", pool.ID)
}
var labels datatypes.JSON
if len(param.AditionalLabels) > 0 {
labels, err = json.Marshal(param.AditionalLabels)
if err != nil {
return fmt.Errorf("error marshalling labels: %w", err)
}
}
var secret []byte
if len(param.JitConfiguration) > 0 {
secret, err = s.marshalAndSeal(param.JitConfiguration)
if err != nil {
return fmt.Errorf("error marshalling jit config: %w", err)
}
}
newInstance := Instance{
Pool: pool,
Name: param.Name,
Status: param.Status,
RunnerStatus: param.RunnerStatus,
OSType: param.OSType,
OSArch: param.OSArch,
CallbackURL: param.CallbackURL,
MetadataURL: param.MetadataURL,
GitHubRunnerGroup: param.GitHubRunnerGroup,
JitConfiguration: secret,
AditionalLabels: labels,
AgentID: param.AgentID,
}
q = tx.Create(&newInstance)
if q.Error != nil {
return fmt.Errorf("error creating instance: %w", q.Error)
}
return nil
})
if err != nil {
return params.Instance{}, fmt.Errorf("error creating instance: %w", err)
}
newInstance := Instance{
Pool: pool,
Name: param.Name,
Status: param.Status,
RunnerStatus: param.RunnerStatus,
OSType: param.OSType,
OSArch: param.OSArch,
CallbackURL: param.CallbackURL,
MetadataURL: param.MetadataURL,
GitHubRunnerGroup: param.GitHubRunnerGroup,
JitConfiguration: secret,
AditionalLabels: labels,
AgentID: param.AgentID,
}
q := s.conn.Create(&newInstance)
if q.Error != nil {
return params.Instance{}, fmt.Errorf("error creating instance: %w", q.Error)
}
return s.sqlToParamsInstance(newInstance)
return s.GetInstance(ctx, param.Name)
}
func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) {

View file

@ -20,6 +20,7 @@ import (
"fmt"
"regexp"
"sort"
"sync"
"testing"
"github.com/stretchr/testify/suite"
@ -210,17 +211,182 @@ func (s *InstancesTestSuite) TestCreateInstance() {
func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() {
_, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{})
s.Require().Equal("error fetching pool: error parsing id: invalid request", err.Error())
s.Require().Equal("error creating instance: error fetching pool: error parsing id: invalid request", err.Error())
}
// TestCreateInstanceMaxRunnersReached verifies that instance creation is
// rejected with a conflict error once the fixture pool's max runner limit
// is reached.
func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReached() {
	// The fixture pool allows at most 4 runners and already holds 3, so one
	// more creation should still succeed.
	_, err := s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, params.CreateInstanceParams{
		Name:        "test-instance-4",
		OSType:      "linux",
		OSArch:      "amd64",
		CallbackURL: "https://garm.example.com/",
	})
	s.Require().Nil(err)

	// A fifth instance exceeds the limit and must be rejected.
	_, err = s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, params.CreateInstanceParams{
		Name:        "test-instance-5",
		OSType:      "linux",
		OSArch:      "amd64",
		CallbackURL: "https://garm.example.com/",
	})
	s.Require().NotNil(err)
	s.Require().Contains(err.Error(), "max runners reached for pool")
}
// TestCreateInstanceMaxRunnersReachedSpecificPool verifies that the max
// runner limit is enforced per pool, using a dedicated pool capped at 3
// runners, and that a rejected creation leaves the instance count intact.
func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReachedSpecificPool() {
	entity, err := s.Fixtures.Org.GetEntity()
	s.Require().Nil(err)

	// A dedicated pool with a limit of 3 runners.
	testPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, params.CreatePoolParams{
		ProviderName:   "test-provider",
		MaxRunners:     3,
		MinIdleRunners: 1,
		Image:          "test-image",
		Flavor:         "test-flavor",
		OSType:         "linux",
		Tags:           []string{"amd64", "linux"},
	})
	s.Require().Nil(err)

	// Fill the pool up to its limit.
	for i := 1; i <= 3; i++ {
		_, err := s.Store.CreateInstance(s.adminCtx, testPool.ID, params.CreateInstanceParams{
			Name:        fmt.Sprintf("max-test-instance-%d", i),
			OSType:      "linux",
			OSArch:      "amd64",
			CallbackURL: "https://garm.example.com/",
		})
		s.Require().Nil(err)
	}

	// The next creation exceeds the limit and must fail.
	_, err = s.Store.CreateInstance(s.adminCtx, testPool.ID, params.CreateInstanceParams{
		Name:        "max-test-instance-4",
		OSType:      "linux",
		OSArch:      "amd64",
		CallbackURL: "https://garm.example.com/",
	})
	s.Require().NotNil(err)
	s.Require().Contains(err.Error(), "max runners reached for pool")

	// The rejected attempt must not have altered the instance count.
	count, err := s.Store.PoolInstanceCount(s.adminCtx, testPool.ID)
	s.Require().Nil(err)
	s.Require().Equal(int64(3), count)
}
// TestCreateInstanceConcurrentMaxRunnersRaceCondition hammers a pool capped
// at 15 runners with 150 concurrent CreateInstance calls and asserts that
// the limit is never exceeded. It also categorizes the failures: conflict
// ("max runners reached") errors are expected; "database is locked" errors
// indicate SQLite concurrency pressure but not a limit violation.
func (s *InstancesTestSuite) TestCreateInstanceConcurrentMaxRunnersRaceCondition() {
	// Create a new pool with max runners set to 15, starting from 0
	createPoolParams := params.CreatePoolParams{
		ProviderName:   "test-provider",
		MaxRunners:     15,
		MinIdleRunners: 0,
		Image:          "test-image",
		Flavor:         "test-flavor",
		OSType:         "linux",
		Tags:           []string{"amd64", "linux"},
	}
	entity, err := s.Fixtures.Org.GetEntity()
	s.Require().Nil(err)
	raceTestPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams)
	s.Require().Nil(err)
	// Verify pool starts with 0 instances
	initialCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID)
	s.Require().Nil(err)
	s.Require().Equal(int64(0), initialCount)
	// Concurrently try to create 150 instances (should only allow 15)
	var wg sync.WaitGroup
	// One result slot per goroutine; each goroutine writes only its own
	// index, so no extra synchronization is needed on the slice.
	results := make([]error, 150)
	for i := 0; i < 150; i++ {
		wg.Add(1)
		// The loop variable is passed as an argument so each goroutine sees
		// its own index (required for Go versions before 1.22).
		go func(index int) {
			defer wg.Done()
			instanceParams := params.CreateInstanceParams{
				Name:        fmt.Sprintf("race-test-instance-%d", index),
				OSType:      "linux",
				OSArch:      "amd64",
				CallbackURL: "https://garm.example.com/",
			}
			_, err := s.Store.CreateInstance(s.adminCtx, raceTestPool.ID, instanceParams)
			results[index] = err
		}(i)
	}
	wg.Wait()
	// Count successful and failed creations
	successCount := 0
	conflictErrorCount := 0
	databaseLockedCount := 0
	otherErrorCount := 0
	for i, err := range results {
		if err == nil {
			successCount++
			continue
		}
		errStr := fmt.Sprintf("%v", err)
		// The conflict error may or may not be wrapped with the outer
		// "error creating instance" prefix depending on the code path.
		expectedConflictErr1 := "error creating instance: max runners reached for pool " + raceTestPool.ID
		expectedConflictErr2 := "max runners reached for pool " + raceTestPool.ID
		databaseLockedErr := "error creating instance: error creating instance: database is locked"
		switch errStr {
		case expectedConflictErr1, expectedConflictErr2:
			conflictErrorCount++
		case databaseLockedErr:
			databaseLockedCount++
			s.T().Logf("Got database locked error for goroutine %d: %v", i, err)
		default:
			otherErrorCount++
			s.T().Logf("Got unexpected error for goroutine %d: %v", i, err)
		}
	}
	s.T().Logf("Results: success=%d, conflict=%d, databaseLocked=%d, other=%d",
		successCount, conflictErrorCount, databaseLockedCount, otherErrorCount)
	// Verify final instance count is <= 15 (the main test - no more than max runners)
	finalCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID)
	s.Require().Nil(err)
	s.Require().LessOrEqual(int64(successCount), int64(15), "Should not create more than max runners")
	s.Require().Equal(int64(successCount), finalCount, "Final count should match successful creations")
	// The key test: verify we never exceeded max runners despite concurrent attempts
	s.Require().True(finalCount <= 15, "Pool should never exceed max runners limit of 15, got %d", finalCount)
	// If there were database lock errors, that's a concurrency issue but not a max runners violation
	if databaseLockedCount > 0 {
		s.T().Logf("WARNING: Got %d database lock errors during concurrent testing - this indicates SQLite concurrency limitations", databaseLockedCount)
	}
	// The critical assertion: total successful attempts + database locked + conflicts should equal 150
	s.Require().Equal(150, successCount+conflictErrorCount+databaseLockedCount+otherErrorCount,
		"All 150 goroutines should have completed with some result")
}
func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
pool := s.Fixtures.Pool
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT ?")).
WithArgs(pool.ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pool.ID))
s.Fixtures.SQLMock.ExpectBegin()
WillReturnRows(sqlmock.NewRows([]string{"id", "max_runners"}).AddRow(pool.ID, 4))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT count(*) FROM `instances` WHERE pool_id = ? AND `instances`.`deleted_at` IS NULL")).
WithArgs(pool.ID).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
s.Fixtures.SQLMock.
ExpectExec("INSERT INTO `pools`").
WillReturnResult(sqlmock.NewResult(1, 1))
@ -233,7 +399,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
s.assertSQLMockExpectations()
s.Require().NotNil(err)
s.Require().Equal("error creating instance: mocked insert instance error", err.Error())
s.Require().Equal("error creating instance: error creating instance: mocked insert instance error", err.Error())
}
func (s *InstancesTestSuite) TestGetInstanceByName() {

View file

@ -221,6 +221,7 @@ func (s *PoolsTestSuite) TestEntityPoolOperations() {
createPoolParams := params.CreatePoolParams{
ProviderName: "test-provider",
MaxRunners: 5,
Image: "test-image",
Flavor: "test-flavor",
OSType: commonParams.Linux,
@ -301,6 +302,7 @@ func (s *PoolsTestSuite) TestListEntityInstances() {
createPoolParams := params.CreatePoolParams{
ProviderName: "test-provider",
MaxRunners: 5,
Image: "test-image",
Flavor: "test-flavor",
OSType: commonParams.Linux,

View file

@ -71,6 +71,7 @@ func newDBConn(dbCfg config.Database) (conn *gorm.DB, err error) {
if dbCfg.Debug {
conn = conn.Debug()
}
return conn, nil
}

View file

@ -183,25 +183,22 @@ func WithEntityJobFilter(ghEntity params.ForgeEntity) dbCommon.PayloadFilterFunc
switch ghEntity.EntityType {
case params.ForgeEntityTypeRepository:
if job.RepoID != nil && job.RepoID.String() != ghEntity.ID {
return false
if job.RepoID != nil && job.RepoID.String() == ghEntity.ID {
return true
}
case params.ForgeEntityTypeOrganization:
if job.OrgID != nil && job.OrgID.String() != ghEntity.ID {
return false
if job.OrgID != nil && job.OrgID.String() == ghEntity.ID {
return true
}
case params.ForgeEntityTypeEnterprise:
if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID {
return false
if job.EnterpriseID != nil && job.EnterpriseID.String() == ghEntity.ID {
return true
}
default:
return false
}
return true
default:
return false
}
return false
}
}

View file

@ -179,6 +179,7 @@ func (s *WatcherStoreTestSuite) TestInstanceWatcher() {
createPoolParams := params.CreatePoolParams{
ProviderName: "test-provider",
MaxRunners: 5,
Image: "test-image",
Flavor: "test-flavor",
OSType: commonParams.Linux,
@ -393,6 +394,7 @@ func (s *WatcherStoreTestSuite) TestPoolWatcher() {
createPoolParams := params.CreatePoolParams{
ProviderName: "test-provider",
MaxRunners: 5,
Image: "test-image",
Flavor: "test-flavor",
OSType: commonParams.Linux,

View file

@ -205,7 +205,8 @@ func GetTestSqliteDBConfig(t *testing.T) config.Database {
DbBackend: config.SQLiteBackend,
Passphrase: encryptionPassphrase,
SQLite: config.SQLite{
DBFile: filepath.Join(dir, "garm.db"),
DBFile: filepath.Join(dir, "garm.db"),
BusyTimeoutSeconds: 30, // 30 second timeout for concurrent transactions
},
}
}

View file

@ -1121,6 +1121,26 @@ type Job struct {
UpdatedAt time.Time `json:"updated_at,omitempty"`
}
// BelongsTo reports whether the job is associated with the given forge
// entity. The comparison is made against the job ID field that corresponds
// to the entity's type; a nil ID or an unknown entity type yields false.
func (j Job) BelongsTo(entity ForgeEntity) bool {
	switch entity.EntityType {
	case ForgeEntityTypeRepository:
		return j.RepoID != nil && j.RepoID.String() == entity.ID
	case ForgeEntityTypeEnterprise:
		return j.EnterpriseID != nil && j.EnterpriseID.String() == entity.ID
	case ForgeEntityTypeOrganization:
		return j.OrgID != nil && j.OrgID.String() == entity.ID
	default:
		return false
	}
}
// swagger:model Jobs
// used by swagger client generated code
type Jobs []Job
@ -1144,13 +1164,13 @@ type CertificateBundle struct {
RootCertificates map[string][]byte `json:"root_certificates,omitempty"`
}
// swagger:model ForgeEntity
type UpdateSystemInfoParams struct {
OSName string `json:"os_name,omitempty"`
OSVersion string `json:"os_version,omitempty"`
AgentID *int64 `json:"agent_id,omitempty"`
}
// swagger:model ForgeEntity
type ForgeEntity struct {
Owner string `json:"owner,omitempty"`
Name string `json:"name,omitempty"`

View file

@ -124,6 +124,7 @@ func NewEntityPoolManager(ctx context.Context, entity params.ForgeEntity, instan
store: store,
providers: providers,
quit: make(chan struct{}),
jobs: make(map[int64]params.Job),
wg: wg,
backoff: backoff,
consumer: consumer,
@ -142,6 +143,7 @@ type basePoolManager struct {
consumer dbCommon.Consumer
store dbCommon.Store
jobs map[int64]params.Job
providers map[string]common.Provider
tools []commonParams.RunnerApplicationDownload
@ -1059,7 +1061,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool
// consideration for scale-down. The 5 minute grace period prevents a situation where a
// "queued" workflow triggers the creation of a new idle runner, and this routine reaps
// an idle runner before they have a chance to pick up a job.
if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning && time.Since(inst.UpdatedAt).Minutes() > 2 {
if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning {
idleWorkers = append(idleWorkers, inst)
}
}
@ -1068,7 +1070,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool
return nil
}
surplus := float64(len(idleWorkers) - pool.MinIdleRunnersAsInt())
surplus := float64(len(idleWorkers) - (pool.MinIdleRunnersAsInt() + len(r.getQueuedJobs())))
if surplus <= 0 {
return nil
@ -1143,17 +1145,8 @@ func (r *basePoolManager) addRunnerToPool(pool params.Pool, aditionalLabels []st
return fmt.Errorf("pool %s is disabled", pool.ID)
}
poolInstanceCount, err := r.store.PoolInstanceCount(r.ctx, pool.ID)
if err != nil {
return fmt.Errorf("failed to list pool instances: %w", err)
}
if poolInstanceCount >= int64(pool.MaxRunnersAsInt()) {
return fmt.Errorf("max workers (%d) reached for pool %s", pool.MaxRunners, pool.ID)
}
if err := r.AddRunner(r.ctx, pool.ID, aditionalLabels); err != nil {
return fmt.Errorf("failed to add new instance for pool %s: %s", pool.ID, err)
return fmt.Errorf("failed to add new instance for pool %s: %w", pool.ID, err)
}
return nil
}
@ -1760,10 +1753,7 @@ func (r *basePoolManager) DeleteRunner(runner params.Instance, forceRemove, bypa
// so those will trigger the creation of a runner. The jobs we don't know about will be dealt with by the idle runners.
// Once jobs are consumed, you can set min-idle-runners to 0 again.
func (r *basePoolManager) consumeQueuedJobs() error {
queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued)
if err != nil {
return fmt.Errorf("error listing queued jobs: %w", err)
}
queued := r.getQueuedJobs()
poolsCache := poolsForTags{
poolCacheType: r.entity.GetPoolBalancerType(),

View file

@ -84,9 +84,32 @@ func composeWatcherFilters(entity params.ForgeEntity) dbCommon.PayloadFilterFunc
watcher.WithEntityFilter(entity),
// Watch for changes to the github credentials
watcher.WithForgeCredentialsFilter(entity.Credentials),
watcher.WithAll(
watcher.WithEntityJobFilter(entity),
watcher.WithAny(
watcher.WithOperationTypeFilter(dbCommon.UpdateOperation),
watcher.WithOperationTypeFilter(dbCommon.CreateOperation),
watcher.WithOperationTypeFilter(dbCommon.DeleteOperation),
),
),
)
}
// getQueuedJobs returns a snapshot of all cached jobs currently in the
// queued state. The internal job map is guarded by r.mux; callers receive
// a fresh slice they may use without holding the lock.
func (r *basePoolManager) getQueuedJobs() []params.Job {
	r.mux.Lock()
	defer r.mux.Unlock()

	queued := make([]params.Job, 0, len(r.jobs))
	for _, j := range r.jobs {
		slog.DebugContext(r.ctx, "considering job for processing", "job_id", j.ID, "job_status", j.Status)
		if params.JobStatus(j.Status) != params.JobStatusQueued {
			continue
		}
		queued = append(queued, j)
	}
	return queued
}
func (r *basePoolManager) waitForToolsOrCancel() (hasTools, stopped bool) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

View file

@ -162,11 +162,46 @@ func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) {
return
}
r.handleEntityUpdate(entityInfo, event.Operation)
case common.JobEntityType:
slog.DebugContext(r.ctx, "new job via watcher")
job, ok := event.Payload.(params.Job)
if !ok {
slog.ErrorContext(r.ctx, "failed to cast payload to job")
return
}
if !job.BelongsTo(r.entity) {
slog.InfoContext(r.ctx, "job does not belong to entity", "worklof_job_id", job.WorkflowJobID, "scaleset_job_id", job.ScaleSetJobID, "job_id", job.ID)
return
}
slog.DebugContext(r.ctx, "recording job", "job_id", job.ID, "job_status", job.Status)
r.mux.Lock()
switch event.Operation {
case common.CreateOperation, common.UpdateOperation:
if params.JobStatus(job.Status) != params.JobStatusCompleted {
slog.DebugContext(r.ctx, "adding job to map", "job_id", job.ID, "job_status", job.Status)
r.jobs[job.ID] = job
break
}
fallthrough
case common.DeleteOperation:
delete(r.jobs, job.ID)
}
r.mux.Unlock()
}
}
func (r *basePoolManager) runWatcher() {
defer r.consumer.Close()
queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued)
if err != nil {
slog.ErrorContext(r.ctx, "failed to list jobs", "error", err)
}
r.mux.Lock()
for _, job := range queued {
r.jobs[job.ID] = job
}
r.mux.Unlock()
for {
select {
case <-r.quit:
@ -177,7 +212,7 @@ func (r *basePoolManager) runWatcher() {
if !ok {
return
}
go r.handleWatcherEvent(event)
r.handleWatcherEvent(event)
}
}
}