From a36d01afd53728cc6d54b8752bb121a6d582aaef Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sat, 4 Oct 2025 19:20:14 +0000 Subject: [PATCH] Cache jobs in pool manager This change caches jobs meant for an entity in the pool manager. This allows us to avoid querying the db as much and allows us to better determine when we should scale down. Signed-off-by: Gabriel Adrian Samfira --- config/config.go | 2 +- config/config_test.go | 4 +- database/sql/instances.go | 95 ++++++++------ database/sql/instances_test.go | 174 ++++++++++++++++++++++++- database/sql/pools_test.go | 2 + database/sql/sql.go | 1 + database/watcher/filters.go | 17 +-- database/watcher/watcher_store_test.go | 2 + internal/testing/testing.go | 3 +- params/params.go | 22 +++- runner/pool/pool.go | 22 +--- runner/pool/util.go | 23 ++++ runner/pool/watcher.go | 37 +++++- 13 files changed, 331 insertions(+), 73 deletions(-) diff --git a/config/config.go b/config/config.go index 31a16ae2..c62f314e 100644 --- a/config/config.go +++ b/config/config.go @@ -568,7 +568,7 @@ func (s *SQLite) Validate() error { } func (s *SQLite) ConnectionString() (string, error) { - connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON", s.DBFile) + connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate", s.DBFile) if s.BusyTimeoutSeconds > 0 { timeout := s.BusyTimeoutSeconds * 1000 connectionString = fmt.Sprintf("%s&_busy_timeout=%d", connectionString, timeout) diff --git a/config/config_test.go b/config/config_test.go index bbf9e299..d6b7e88a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -387,13 +387,13 @@ func TestGormParams(t *testing.T) { dbType, uri, err := cfg.GormParams() require.Nil(t, err) require.Equal(t, SQLiteBackend, dbType) - require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON"), uri) + require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate"), uri) cfg.SQLite.BusyTimeoutSeconds = 5 dbType, uri, err = cfg.GormParams() require.Nil(t, err) require.Equal(t, SQLiteBackend, dbType) - require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000"), uri) + require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_busy_timeout=5000"), uri) cfg.DbBackend = MySQLBackend cfg.MySQL = getMySQLDefaultConfig() diff --git a/database/sql/instances.go b/database/sql/instances.go index fa36726a..6a3d18c0 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "log/slog" + "math" "github.com/google/uuid" "gorm.io/datatypes" @@ -31,54 +32,74 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { - pool, err := s.getPoolByID(s.conn, poolID) - if err != nil { - return params.Instance{}, fmt.Errorf("error fetching pool: %w", err) - } - +func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { defer func() { if err == nil { s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance) } }() - var labels datatypes.JSON - if len(param.AditionalLabels) > 0 { - labels, err = json.Marshal(param.AditionalLabels) + err = s.conn.Transaction(func(tx *gorm.DB) error { + pool, err := s.getPoolByID(tx, poolID) if err != nil { - return 
params.Instance{}, fmt.Errorf("error marshalling labels: %w", err) + return fmt.Errorf("error fetching pool: %w", err) } - } - - var secret []byte - if len(param.JitConfiguration) > 0 { - secret, err = s.marshalAndSeal(param.JitConfiguration) - if err != nil { - return params.Instance{}, fmt.Errorf("error marshalling jit config: %w", err) + var cnt int64 + q := s.conn.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt) + if q.Error != nil { + return fmt.Errorf("error fetching instance count: %w", q.Error) } + var maxRunners int64 + if pool.MaxRunners > math.MaxInt64 { + maxRunners = math.MaxInt64 + } else { + maxRunners = int64(pool.MaxRunners) + } + if cnt >= maxRunners { + return runnerErrors.NewConflictError("max runners reached for pool %s", pool.ID) + } + + var labels datatypes.JSON + if len(param.AditionalLabels) > 0 { + labels, err = json.Marshal(param.AditionalLabels) + if err != nil { + return fmt.Errorf("error marshalling labels: %w", err) + } + } + + var secret []byte + if len(param.JitConfiguration) > 0 { + secret, err = s.marshalAndSeal(param.JitConfiguration) + if err != nil { + return fmt.Errorf("error marshalling jit config: %w", err) + } + } + + newInstance := Instance{ + Pool: pool, + Name: param.Name, + Status: param.Status, + RunnerStatus: param.RunnerStatus, + OSType: param.OSType, + OSArch: param.OSArch, + CallbackURL: param.CallbackURL, + MetadataURL: param.MetadataURL, + GitHubRunnerGroup: param.GitHubRunnerGroup, + JitConfiguration: secret, + AditionalLabels: labels, + AgentID: param.AgentID, + } + q = tx.Create(&newInstance) + if q.Error != nil { + return fmt.Errorf("error creating instance: %w", q.Error) + } + return nil + }) + if err != nil { + return params.Instance{}, fmt.Errorf("error creating instance: %w", err) } - newInstance := Instance{ - Pool: pool, - Name: param.Name, - Status: param.Status, - RunnerStatus: param.RunnerStatus, - OSType: param.OSType, - OSArch: param.OSArch, - CallbackURL: param.CallbackURL, - MetadataURL: param.MetadataURL, - GitHubRunnerGroup: param.GitHubRunnerGroup, - JitConfiguration: secret, - AditionalLabels: labels, - AgentID: param.AgentID, - } - q := s.conn.Create(&newInstance) - if q.Error != nil { - return params.Instance{}, fmt.Errorf("error creating instance: %w", q.Error) - } - - return s.sqlToParamsInstance(newInstance) + return s.GetInstance(ctx, param.Name) } func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) { diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 92a18720..8891da72 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -20,6 +20,7 @@ import ( "fmt" "regexp" "sort" + "sync" "testing" "github.com/stretchr/testify/suite" @@ -210,17 +211,182 @@ func (s *InstancesTestSuite) TestCreateInstance() { func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() { _, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{}) - s.Require().Equal("error fetching pool: error parsing id: invalid request", err.Error()) + s.Require().Equal("error creating instance: error fetching pool: error parsing id: invalid request", err.Error()) +} + +func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReached() { + // Create a fourth instance (pool has max 4 runners, already has 3) + fourthInstanceParams := params.CreateInstanceParams{ + Name: "test-instance-4", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := 
s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, fourthInstanceParams) + s.Require().Nil(err) + + // Try to create a fifth instance, which should fail due to max runners limit + fifthInstanceParams := params.CreateInstanceParams{ + Name: "test-instance-5", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err = s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, fifthInstanceParams) + s.Require().NotNil(err) + s.Require().Contains(err.Error(), "max runners reached for pool") +} + +func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReachedSpecificPool() { + // Create a new pool with max runners set to 3 + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + MaxRunners: 3, + MinIdleRunners: 1, + Image: "test-image", + Flavor: "test-flavor", + OSType: "linux", + Tags: []string{"amd64", "linux"}, + } + entity, err := s.Fixtures.Org.GetEntity() + s.Require().Nil(err) + testPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams) + s.Require().Nil(err) + + // Create exactly 3 instances (max limit) + for i := 1; i <= 3; i++ { + instanceParams := params.CreateInstanceParams{ + Name: fmt.Sprintf("max-test-instance-%d", i), + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := s.Store.CreateInstance(s.adminCtx, testPool.ID, instanceParams) + s.Require().Nil(err) + } + + // Try to create a fourth instance, which should fail + fourthInstanceParams := params.CreateInstanceParams{ + Name: "max-test-instance-4", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err = s.Store.CreateInstance(s.adminCtx, testPool.ID, fourthInstanceParams) + s.Require().NotNil(err) + s.Require().Contains(err.Error(), "max runners reached for pool") + + // Verify instance count is still 3 + count, err := s.Store.PoolInstanceCount(s.adminCtx, testPool.ID) + s.Require().Nil(err) + s.Require().Equal(int64(3), count) +} + +func (s *InstancesTestSuite) TestCreateInstanceConcurrentMaxRunnersRaceCondition() { + // Create a new pool with max runners set to 15, starting from 0 + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + MaxRunners: 15, + MinIdleRunners: 0, + Image: "test-image", + Flavor: "test-flavor", + OSType: "linux", + Tags: []string{"amd64", "linux"}, + } + entity, err := s.Fixtures.Org.GetEntity() + s.Require().Nil(err) + raceTestPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams) + s.Require().Nil(err) + + // Verify pool starts with 0 instances + initialCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID) + s.Require().Nil(err) + s.Require().Equal(int64(0), initialCount) + + // Concurrently try to create 150 instances (should only allow 15) + var wg sync.WaitGroup + results := make([]error, 150) + + for i := 0; i < 150; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + instanceParams := params.CreateInstanceParams{ + Name: fmt.Sprintf("race-test-instance-%d", index), + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := s.Store.CreateInstance(s.adminCtx, raceTestPool.ID, instanceParams) + results[index] = err + }(i) + } + + wg.Wait() + + // Count successful and failed creations + successCount := 0 + conflictErrorCount := 0 + databaseLockedCount := 0 + otherErrorCount := 0 + + for i, err := range results { + if err == nil { + successCount++ + continue + } + + errStr := fmt.Sprintf("%v", err) + expectedConflictErr1 
:= "error creating instance: max runners reached for pool " + raceTestPool.ID + expectedConflictErr2 := "max runners reached for pool " + raceTestPool.ID + databaseLockedErr := "error creating instance: error creating instance: database is locked" + + switch errStr { + case expectedConflictErr1, expectedConflictErr2: + conflictErrorCount++ + case databaseLockedErr: + databaseLockedCount++ + s.T().Logf("Got database locked error for goroutine %d: %v", i, err) + default: + otherErrorCount++ + s.T().Logf("Got unexpected error for goroutine %d: %v", i, err) + } + } + + s.T().Logf("Results: success=%d, conflict=%d, databaseLocked=%d, other=%d", + successCount, conflictErrorCount, databaseLockedCount, otherErrorCount) + + // Verify final instance count is <= 15 (the main test - no more than max runners) + finalCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID) + s.Require().Nil(err) + s.Require().LessOrEqual(int64(successCount), int64(15), "Should not create more than max runners") + s.Require().Equal(int64(successCount), finalCount, "Final count should match successful creations") + + // The key test: verify we never exceeded max runners despite concurrent attempts + s.Require().True(finalCount <= 15, "Pool should never exceed max runners limit of 15, got %d", finalCount) + + // If there were database lock errors, that's a concurrency issue but not a max runners violation + if databaseLockedCount > 0 { + s.T().Logf("WARNING: Got %d database lock errors during concurrent testing - this indicates SQLite concurrency limitations", databaseLockedCount) + } + + // The critical assertion: total successful attempts + database locked + conflicts should equal 150 + s.Require().Equal(150, successCount+conflictErrorCount+databaseLockedCount+otherErrorCount, + "All 150 goroutines should have completed with some result") } func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { pool := s.Fixtures.Pool + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT ?")). WithArgs(pool.ID, 1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pool.ID)) - s.Fixtures.SQLMock.ExpectBegin() + WillReturnRows(sqlmock.NewRows([]string{"id", "max_runners"}).AddRow(pool.ID, 4)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT count(*) FROM `instances` WHERE pool_id = ? AND `instances`.`deleted_at` IS NULL")). + WithArgs(pool.ID). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) s.Fixtures.SQLMock. ExpectExec("INSERT INTO `pools`"). 
WillReturnResult(sqlmock.NewResult(1, 1)) @@ -233,7 +399,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("error creating instance: mocked insert instance error", err.Error()) + s.Require().Equal("error creating instance: error creating instance: mocked insert instance error", err.Error()) } func (s *InstancesTestSuite) TestGetInstanceByName() { diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index fa5d178e..e3b83e8d 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -221,6 +221,7 @@ func (s *PoolsTestSuite) TestEntityPoolOperations() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, @@ -301,6 +302,7 @@ func (s *PoolsTestSuite) TestListEntityInstances() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, diff --git a/database/sql/sql.go b/database/sql/sql.go index 9f4c37ab..eb97a3c8 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -71,6 +71,7 @@ func newDBConn(dbCfg config.Database) (conn *gorm.DB, err error) { if dbCfg.Debug { conn = conn.Debug() } + return conn, nil } diff --git a/database/watcher/filters.go b/database/watcher/filters.go index acf79ba8..6b920ec3 100644 --- a/database/watcher/filters.go +++ b/database/watcher/filters.go @@ -183,25 +183,22 @@ func WithEntityJobFilter(ghEntity params.ForgeEntity) dbCommon.PayloadFilterFunc switch ghEntity.EntityType { case params.ForgeEntityTypeRepository: - if job.RepoID != nil && job.RepoID.String() != ghEntity.ID { - return false + if job.RepoID != nil && job.RepoID.String() == ghEntity.ID { + return true } case params.ForgeEntityTypeOrganization: - if job.OrgID != nil && job.OrgID.String() != ghEntity.ID { - return false + if job.OrgID != nil && job.OrgID.String() == ghEntity.ID { + return true } case params.ForgeEntityTypeEnterprise: - if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID { - return false + if job.EnterpriseID != nil && job.EnterpriseID.String() == ghEntity.ID { + return true } - default: - return false } - - return true default: return false } + return false } } diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index 97fc8a9d..2315078d 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -179,6 +179,7 @@ func (s *WatcherStoreTestSuite) TestInstanceWatcher() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, @@ -393,6 +394,7 @@ func (s *WatcherStoreTestSuite) TestPoolWatcher() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 38725882..8d8b941a 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -205,7 +205,8 @@ func GetTestSqliteDBConfig(t *testing.T) config.Database { DbBackend: config.SQLiteBackend, Passphrase: encryptionPassphrase, SQLite: config.SQLite{ - DBFile: filepath.Join(dir, "garm.db"), + DBFile: filepath.Join(dir, "garm.db"), + BusyTimeoutSeconds: 30, // 30 second timeout for concurrent 
transactions }, } } diff --git a/params/params.go b/params/params.go index 8e51fa23..3f7b24b5 100644 --- a/params/params.go +++ b/params/params.go @@ -1121,6 +1121,26 @@ type Job struct { UpdatedAt time.Time `json:"updated_at,omitempty"` } +func (j Job) BelongsTo(entity ForgeEntity) bool { + switch entity.EntityType { + case ForgeEntityTypeRepository: + if j.RepoID != nil { + return entity.ID == j.RepoID.String() + } + case ForgeEntityTypeEnterprise: + if j.EnterpriseID != nil { + return entity.ID == j.EnterpriseID.String() + } + case ForgeEntityTypeOrganization: + if j.OrgID != nil { + return entity.ID == j.OrgID.String() + } + default: + return false + } + return false +} + // swagger:model Jobs // used by swagger client generated code type Jobs []Job @@ -1144,13 +1164,13 @@ type CertificateBundle struct { RootCertificates map[string][]byte `json:"root_certificates,omitempty"` } -// swagger:model ForgeEntity type UpdateSystemInfoParams struct { OSName string `json:"os_name,omitempty"` OSVersion string `json:"os_version,omitempty"` AgentID *int64 `json:"agent_id,omitempty"` } +// swagger:model ForgeEntity type ForgeEntity struct { Owner string `json:"owner,omitempty"` Name string `json:"name,omitempty"` diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 36a7de03..446f2afb 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -124,6 +124,7 @@ func NewEntityPoolManager(ctx context.Context, entity params.ForgeEntity, instan store: store, providers: providers, quit: make(chan struct{}), + jobs: make(map[int64]params.Job), wg: wg, backoff: backoff, consumer: consumer, @@ -142,6 +143,7 @@ type basePoolManager struct { consumer dbCommon.Consumer store dbCommon.Store + jobs map[int64]params.Job providers map[string]common.Provider tools []commonParams.RunnerApplicationDownload @@ -1059,7 +1061,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool // consideration for scale-down. The 5 minute grace period prevents a situation where a // "queued" workflow triggers the creation of a new idle runner, and this routine reaps // an idle runner before they have a chance to pick up a job. 
- if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning && time.Since(inst.UpdatedAt).Minutes() > 2 { + if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning { idleWorkers = append(idleWorkers, inst) } } @@ -1068,7 +1070,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool return nil } - surplus := float64(len(idleWorkers) - pool.MinIdleRunnersAsInt()) + surplus := float64(len(idleWorkers) - (pool.MinIdleRunnersAsInt() + len(r.getQueuedJobs()))) if surplus <= 0 { return nil @@ -1143,17 +1145,8 @@ func (r *basePoolManager) addRunnerToPool(pool params.Pool, aditionalLabels []st return fmt.Errorf("pool %s is disabled", pool.ID) } - poolInstanceCount, err := r.store.PoolInstanceCount(r.ctx, pool.ID) - if err != nil { - return fmt.Errorf("failed to list pool instances: %w", err) - } - - if poolInstanceCount >= int64(pool.MaxRunnersAsInt()) { - return fmt.Errorf("max workers (%d) reached for pool %s", pool.MaxRunners, pool.ID) - } - if err := r.AddRunner(r.ctx, pool.ID, aditionalLabels); err != nil { - return fmt.Errorf("failed to add new instance for pool %s: %s", pool.ID, err) + return fmt.Errorf("failed to add new instance for pool %s: %w", pool.ID, err) } return nil } @@ -1760,10 +1753,7 @@ func (r *basePoolManager) DeleteRunner(runner params.Instance, forceRemove, bypa // so those will trigger the creation of a runner. The jobs we don't know about will be dealt with by the idle runners. // Once jobs are consumed, you can set min-idle-runners to 0 again. func (r *basePoolManager) consumeQueuedJobs() error { - queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued) - if err != nil { - return fmt.Errorf("error listing queued jobs: %w", err) - } + queued := r.getQueuedJobs() poolsCache := poolsForTags{ poolCacheType: r.entity.GetPoolBalancerType(), diff --git a/runner/pool/util.go b/runner/pool/util.go index d58f90a3..3fed1478 100644 --- a/runner/pool/util.go +++ b/runner/pool/util.go @@ -84,9 +84,32 @@ func composeWatcherFilters(entity params.ForgeEntity) dbCommon.PayloadFilterFunc watcher.WithEntityFilter(entity), // Watch for changes to the github credentials watcher.WithForgeCredentialsFilter(entity.Credentials), + watcher.WithAll( + watcher.WithEntityJobFilter(entity), + watcher.WithAny( + watcher.WithOperationTypeFilter(dbCommon.UpdateOperation), + watcher.WithOperationTypeFilter(dbCommon.CreateOperation), + watcher.WithOperationTypeFilter(dbCommon.DeleteOperation), + ), + ), ) } +func (r *basePoolManager) getQueuedJobs() []params.Job { + r.mux.Lock() + defer r.mux.Unlock() + + ret := []params.Job{} + + for _, job := range r.jobs { + slog.DebugContext(r.ctx, "considering job for processing", "job_id", job.ID, "job_status", job.Status) + if params.JobStatus(job.Status) == params.JobStatusQueued { + ret = append(ret, job) + } + } + return ret +} + func (r *basePoolManager) waitForToolsOrCancel() (hasTools, stopped bool) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go index 999b52c6..16764a2c 100644 --- a/runner/pool/watcher.go +++ b/runner/pool/watcher.go @@ -162,11 +162,46 @@ func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) { return } r.handleEntityUpdate(entityInfo, event.Operation) + case common.JobEntityType: + slog.DebugContext(r.ctx, "new job via watcher") + job, ok := event.Payload.(params.Job) + if !ok { + 
slog.ErrorContext(r.ctx, "failed to cast payload to job")
+			return
+		}
+		if !job.BelongsTo(r.entity) {
+			slog.InfoContext(r.ctx, "job does not belong to entity", "workflow_job_id", job.WorkflowJobID, "scaleset_job_id", job.ScaleSetJobID, "job_id", job.ID)
+			return
+		}
+		slog.DebugContext(r.ctx, "recording job", "job_id", job.ID, "job_status", job.Status)
+		r.mux.Lock()
+		switch event.Operation {
+		case common.CreateOperation, common.UpdateOperation:
+			if params.JobStatus(job.Status) != params.JobStatusCompleted {
+				slog.DebugContext(r.ctx, "adding job to map", "job_id", job.ID, "job_status", job.Status)
+				r.jobs[job.ID] = job
+				break
+			}
+			fallthrough
+		case common.DeleteOperation:
+			delete(r.jobs, job.ID)
+		}
+		r.mux.Unlock()
 	}
 }
 
 func (r *basePoolManager) runWatcher() {
 	defer r.consumer.Close()
+	queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued)
+	if err != nil {
+		slog.ErrorContext(r.ctx, "failed to list jobs", "error", err)
+	}
+
+	r.mux.Lock()
+	for _, job := range queued {
+		r.jobs[job.ID] = job
+	}
+	r.mux.Unlock()
 	for {
 		select {
 		case <-r.quit:
@@ -177,7 +212,7 @@ func (r *basePoolManager) runWatcher() {
 			if !ok {
 				return
 			}
-			go r.handleWatcherEvent(event)
+			r.handleWatcherEvent(event)
 		}
 	}
 }