diff --git a/config/config.go b/config/config.go index 31a16ae2..c62f314e 100644 --- a/config/config.go +++ b/config/config.go @@ -568,7 +568,7 @@ func (s *SQLite) Validate() error { } func (s *SQLite) ConnectionString() (string, error) { - connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON", s.DBFile) + connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate", s.DBFile) if s.BusyTimeoutSeconds > 0 { timeout := s.BusyTimeoutSeconds * 1000 connectionString = fmt.Sprintf("%s&_busy_timeout=%d", connectionString, timeout) diff --git a/config/config_test.go b/config/config_test.go index bbf9e299..d6b7e88a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -387,13 +387,13 @@ func TestGormParams(t *testing.T) { dbType, uri, err := cfg.GormParams() require.Nil(t, err) require.Equal(t, SQLiteBackend, dbType) - require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON"), uri) + require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate"), uri) cfg.SQLite.BusyTimeoutSeconds = 5 dbType, uri, err = cfg.GormParams() require.Nil(t, err) require.Equal(t, SQLiteBackend, dbType) - require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000"), uri) + require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_busy_timeout=5000"), uri) cfg.DbBackend = MySQLBackend cfg.MySQL = getMySQLDefaultConfig() diff --git a/database/sql/instances.go b/database/sql/instances.go index fa36726a..6a3d18c0 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "log/slog" + "math" "github.com/google/uuid" "gorm.io/datatypes" @@ -31,54 +32,74 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { - pool, 
err := s.getPoolByID(s.conn, poolID) - if err != nil { - return params.Instance{}, fmt.Errorf("error fetching pool: %w", err) - } - +func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { defer func() { if err == nil { s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance) } }() - var labels datatypes.JSON - if len(param.AditionalLabels) > 0 { - labels, err = json.Marshal(param.AditionalLabels) + err = s.conn.Transaction(func(tx *gorm.DB) error { + pool, err := s.getPoolByID(tx, poolID) if err != nil { - return params.Instance{}, fmt.Errorf("error marshalling labels: %w", err) + return fmt.Errorf("error fetching pool: %w", err) } - } - - var secret []byte - if len(param.JitConfiguration) > 0 { - secret, err = s.marshalAndSeal(param.JitConfiguration) - if err != nil { - return params.Instance{}, fmt.Errorf("error marshalling jit config: %w", err) + var cnt int64 + q := tx.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt) + if q.Error != nil { + return fmt.Errorf("error fetching instance count: %w", q.Error) } + var maxRunners int64 + if pool.MaxRunners > math.MaxInt64 { + maxRunners = math.MaxInt64 + } else { + maxRunners = int64(pool.MaxRunners) + } + if cnt >= maxRunners { + return runnerErrors.NewConflictError("max runners reached for pool %s", pool.ID) + } + + var labels datatypes.JSON + if len(param.AditionalLabels) > 0 { + labels, err = json.Marshal(param.AditionalLabels) + if err != nil { + return fmt.Errorf("error marshalling labels: %w", err) + } + } + + var secret []byte + if len(param.JitConfiguration) > 0 { + secret, err = s.marshalAndSeal(param.JitConfiguration) + if err != nil { + return fmt.Errorf("error marshalling jit config: %w", err) + } + } + + newInstance := Instance{ + Pool: pool, + Name: param.Name, + Status: param.Status, + RunnerStatus: param.RunnerStatus, + OSType: param.OSType, + OSArch: param.OSArch, + 
CallbackURL: param.CallbackURL, + MetadataURL: param.MetadataURL, + GitHubRunnerGroup: param.GitHubRunnerGroup, + JitConfiguration: secret, + AditionalLabels: labels, + AgentID: param.AgentID, + } + q = tx.Create(&newInstance) + if q.Error != nil { + return fmt.Errorf("error creating instance: %w", q.Error) + } + return nil + }) + if err != nil { + return params.Instance{}, fmt.Errorf("error creating instance: %w", err) } - newInstance := Instance{ - Pool: pool, - Name: param.Name, - Status: param.Status, - RunnerStatus: param.RunnerStatus, - OSType: param.OSType, - OSArch: param.OSArch, - CallbackURL: param.CallbackURL, - MetadataURL: param.MetadataURL, - GitHubRunnerGroup: param.GitHubRunnerGroup, - JitConfiguration: secret, - AditionalLabels: labels, - AgentID: param.AgentID, - } - q := s.conn.Create(&newInstance) - if q.Error != nil { - return params.Instance{}, fmt.Errorf("error creating instance: %w", q.Error) - } - - return s.sqlToParamsInstance(newInstance) + return s.GetInstance(ctx, param.Name) } func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) { diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 92a18720..8891da72 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -20,6 +20,7 @@ import ( "fmt" "regexp" "sort" + "sync" "testing" "github.com/stretchr/testify/suite" @@ -210,17 +211,182 @@ func (s *InstancesTestSuite) TestCreateInstance() { func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() { _, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{}) - s.Require().Equal("error fetching pool: error parsing id: invalid request", err.Error()) + s.Require().Equal("error creating instance: error fetching pool: error parsing id: invalid request", err.Error()) +} + +func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReached() { + // Create a fourth instance (pool has max 4 runners, already has 3) + 
fourthInstanceParams := params.CreateInstanceParams{ + Name: "test-instance-4", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, fourthInstanceParams) + s.Require().Nil(err) + + // Try to create a fifth instance, which should fail due to max runners limit + fifthInstanceParams := params.CreateInstanceParams{ + Name: "test-instance-5", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err = s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, fifthInstanceParams) + s.Require().NotNil(err) + s.Require().Contains(err.Error(), "max runners reached for pool") +} + +func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReachedSpecificPool() { + // Create a new pool with max runners set to 3 + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + MaxRunners: 3, + MinIdleRunners: 1, + Image: "test-image", + Flavor: "test-flavor", + OSType: "linux", + Tags: []string{"amd64", "linux"}, + } + entity, err := s.Fixtures.Org.GetEntity() + s.Require().Nil(err) + testPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams) + s.Require().Nil(err) + + // Create exactly 3 instances (max limit) + for i := 1; i <= 3; i++ { + instanceParams := params.CreateInstanceParams{ + Name: fmt.Sprintf("max-test-instance-%d", i), + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := s.Store.CreateInstance(s.adminCtx, testPool.ID, instanceParams) + s.Require().Nil(err) + } + + // Try to create a fourth instance, which should fail + fourthInstanceParams := params.CreateInstanceParams{ + Name: "max-test-instance-4", + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err = s.Store.CreateInstance(s.adminCtx, testPool.ID, fourthInstanceParams) + s.Require().NotNil(err) + s.Require().Contains(err.Error(), "max runners reached for 
pool") + + // Verify instance count is still 3 + count, err := s.Store.PoolInstanceCount(s.adminCtx, testPool.ID) + s.Require().Nil(err) + s.Require().Equal(int64(3), count) +} + +func (s *InstancesTestSuite) TestCreateInstanceConcurrentMaxRunnersRaceCondition() { + // Create a new pool with max runners set to 15, starting from 0 + createPoolParams := params.CreatePoolParams{ + ProviderName: "test-provider", + MaxRunners: 15, + MinIdleRunners: 0, + Image: "test-image", + Flavor: "test-flavor", + OSType: "linux", + Tags: []string{"amd64", "linux"}, + } + entity, err := s.Fixtures.Org.GetEntity() + s.Require().Nil(err) + raceTestPool, err := s.Store.CreateEntityPool(s.adminCtx, entity, createPoolParams) + s.Require().Nil(err) + + // Verify pool starts with 0 instances + initialCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID) + s.Require().Nil(err) + s.Require().Equal(int64(0), initialCount) + + // Concurrently try to create 150 instances (should only allow 15) + var wg sync.WaitGroup + results := make([]error, 150) + + for i := 0; i < 150; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + instanceParams := params.CreateInstanceParams{ + Name: fmt.Sprintf("race-test-instance-%d", index), + OSType: "linux", + OSArch: "amd64", + CallbackURL: "https://garm.example.com/", + } + _, err := s.Store.CreateInstance(s.adminCtx, raceTestPool.ID, instanceParams) + results[index] = err + }(i) + } + + wg.Wait() + + // Count successful and failed creations + successCount := 0 + conflictErrorCount := 0 + databaseLockedCount := 0 + otherErrorCount := 0 + + for i, err := range results { + if err == nil { + successCount++ + continue + } + + errStr := fmt.Sprintf("%v", err) + expectedConflictErr1 := "error creating instance: max runners reached for pool " + raceTestPool.ID + expectedConflictErr2 := "max runners reached for pool " + raceTestPool.ID + databaseLockedErr := "error creating instance: error creating instance: database is locked" + + switch errStr 
{ + case expectedConflictErr1, expectedConflictErr2: + conflictErrorCount++ + case databaseLockedErr: + databaseLockedCount++ + s.T().Logf("Got database locked error for goroutine %d: %v", i, err) + default: + otherErrorCount++ + s.T().Logf("Got unexpected error for goroutine %d: %v", i, err) + } + } + + s.T().Logf("Results: success=%d, conflict=%d, databaseLocked=%d, other=%d", + successCount, conflictErrorCount, databaseLockedCount, otherErrorCount) + + // Verify final instance count is <= 15 (the main test - no more than max runners) + finalCount, err := s.Store.PoolInstanceCount(s.adminCtx, raceTestPool.ID) + s.Require().Nil(err) + s.Require().LessOrEqual(int64(successCount), int64(15), "Should not create more than max runners") + s.Require().Equal(int64(successCount), finalCount, "Final count should match successful creations") + + // The key test: verify we never exceeded max runners despite concurrent attempts + s.Require().True(finalCount <= 15, "Pool should never exceed max runners limit of 15, got %d", finalCount) + + // If there were database lock errors, that's a concurrency issue but not a max runners violation + if databaseLockedCount > 0 { + s.T().Logf("WARNING: Got %d database lock errors during concurrent testing - this indicates SQLite concurrency limitations", databaseLockedCount) + } + + // The critical assertion: total successful attempts + database locked + conflicts should equal 150 + s.Require().Equal(150, successCount+conflictErrorCount+databaseLockedCount+otherErrorCount, + "All 150 goroutines should have completed with some result") } func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { pool := s.Fixtures.Pool + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT ?")). WithArgs(pool.ID, 1). 
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pool.ID)) - s.Fixtures.SQLMock.ExpectBegin() + WillReturnRows(sqlmock.NewRows([]string{"id", "max_runners"}).AddRow(pool.ID, 4)) + s.Fixtures.SQLMock. + ExpectQuery(regexp.QuoteMeta("SELECT count(*) FROM `instances` WHERE pool_id = ? AND `instances`.`deleted_at` IS NULL")). + WithArgs(pool.ID). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) s.Fixtures.SQLMock. ExpectExec("INSERT INTO `pools`"). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -233,7 +399,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() { s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("error creating instance: mocked insert instance error", err.Error()) + s.Require().Equal("error creating instance: error creating instance: mocked insert instance error", err.Error()) } func (s *InstancesTestSuite) TestGetInstanceByName() { diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index fa5d178e..e3b83e8d 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -221,6 +221,7 @@ func (s *PoolsTestSuite) TestEntityPoolOperations() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, @@ -301,6 +302,7 @@ func (s *PoolsTestSuite) TestListEntityInstances() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, diff --git a/database/sql/sql.go b/database/sql/sql.go index 9f4c37ab..eb97a3c8 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -71,6 +71,7 @@ func newDBConn(dbCfg config.Database) (conn *gorm.DB, err error) { if dbCfg.Debug { conn = conn.Debug() } + return conn, nil } diff --git a/database/watcher/filters.go b/database/watcher/filters.go index acf79ba8..6b920ec3 100644 --- a/database/watcher/filters.go +++ 
b/database/watcher/filters.go @@ -183,25 +183,22 @@ func WithEntityJobFilter(ghEntity params.ForgeEntity) dbCommon.PayloadFilterFunc switch ghEntity.EntityType { case params.ForgeEntityTypeRepository: - if job.RepoID != nil && job.RepoID.String() != ghEntity.ID { - return false + if job.RepoID != nil && job.RepoID.String() == ghEntity.ID { + return true } case params.ForgeEntityTypeOrganization: - if job.OrgID != nil && job.OrgID.String() != ghEntity.ID { - return false + if job.OrgID != nil && job.OrgID.String() == ghEntity.ID { + return true } case params.ForgeEntityTypeEnterprise: - if job.EnterpriseID != nil && job.EnterpriseID.String() != ghEntity.ID { - return false + if job.EnterpriseID != nil && job.EnterpriseID.String() == ghEntity.ID { + return true } - default: - return false } - - return true default: return false } + return false } } diff --git a/database/watcher/watcher_store_test.go b/database/watcher/watcher_store_test.go index 97fc8a9d..2315078d 100644 --- a/database/watcher/watcher_store_test.go +++ b/database/watcher/watcher_store_test.go @@ -179,6 +179,7 @@ func (s *WatcherStoreTestSuite) TestInstanceWatcher() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, @@ -393,6 +394,7 @@ func (s *WatcherStoreTestSuite) TestPoolWatcher() { createPoolParams := params.CreatePoolParams{ ProviderName: "test-provider", + MaxRunners: 5, Image: "test-image", Flavor: "test-flavor", OSType: commonParams.Linux, diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 38725882..8d8b941a 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -205,7 +205,8 @@ func GetTestSqliteDBConfig(t *testing.T) config.Database { DbBackend: config.SQLiteBackend, Passphrase: encryptionPassphrase, SQLite: config.SQLite{ - DBFile: filepath.Join(dir, "garm.db"), + DBFile: filepath.Join(dir, "garm.db"), + BusyTimeoutSeconds: 30, 
// 30 second timeout for concurrent transactions }, } } diff --git a/params/params.go b/params/params.go index 8e51fa23..3f7b24b5 100644 --- a/params/params.go +++ b/params/params.go @@ -1121,6 +1121,26 @@ type Job struct { UpdatedAt time.Time `json:"updated_at,omitempty"` } +func (j Job) BelongsTo(entity ForgeEntity) bool { + switch entity.EntityType { + case ForgeEntityTypeRepository: + if j.RepoID != nil { + return entity.ID == j.RepoID.String() + } + case ForgeEntityTypeEnterprise: + if j.EnterpriseID != nil { + return entity.ID == j.EnterpriseID.String() + } + case ForgeEntityTypeOrganization: + if j.OrgID != nil { + return entity.ID == j.OrgID.String() + } + default: + return false + } + return false +} + // swagger:model Jobs // used by swagger client generated code type Jobs []Job @@ -1144,13 +1164,13 @@ type CertificateBundle struct { RootCertificates map[string][]byte `json:"root_certificates,omitempty"` } -// swagger:model ForgeEntity type UpdateSystemInfoParams struct { OSName string `json:"os_name,omitempty"` OSVersion string `json:"os_version,omitempty"` AgentID *int64 `json:"agent_id,omitempty"` } +// swagger:model ForgeEntity type ForgeEntity struct { Owner string `json:"owner,omitempty"` Name string `json:"name,omitempty"` diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 36a7de03..446f2afb 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -124,6 +124,7 @@ func NewEntityPoolManager(ctx context.Context, entity params.ForgeEntity, instan store: store, providers: providers, quit: make(chan struct{}), + jobs: make(map[int64]params.Job), wg: wg, backoff: backoff, consumer: consumer, @@ -142,6 +143,7 @@ type basePoolManager struct { consumer dbCommon.Consumer store dbCommon.Store + jobs map[int64]params.Job providers map[string]common.Provider tools []commonParams.RunnerApplicationDownload @@ -1059,7 +1061,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool // consideration for scale-down. 
The 5 minute grace period prevents a situation where a // "queued" workflow triggers the creation of a new idle runner, and this routine reaps // an idle runner before they have a chance to pick up a job. - if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning && time.Since(inst.UpdatedAt).Minutes() > 2 { + if inst.RunnerStatus == params.RunnerIdle && inst.Status == commonParams.InstanceRunning { idleWorkers = append(idleWorkers, inst) } } @@ -1068,7 +1070,7 @@ func (r *basePoolManager) scaleDownOnePool(ctx context.Context, pool params.Pool return nil } - surplus := float64(len(idleWorkers) - pool.MinIdleRunnersAsInt()) + surplus := float64(len(idleWorkers) - (pool.MinIdleRunnersAsInt() + len(r.getQueuedJobs()))) if surplus <= 0 { return nil @@ -1143,17 +1145,8 @@ func (r *basePoolManager) addRunnerToPool(pool params.Pool, aditionalLabels []st return fmt.Errorf("pool %s is disabled", pool.ID) } - poolInstanceCount, err := r.store.PoolInstanceCount(r.ctx, pool.ID) - if err != nil { - return fmt.Errorf("failed to list pool instances: %w", err) - } - - if poolInstanceCount >= int64(pool.MaxRunnersAsInt()) { - return fmt.Errorf("max workers (%d) reached for pool %s", pool.MaxRunners, pool.ID) - } - if err := r.AddRunner(r.ctx, pool.ID, aditionalLabels); err != nil { - return fmt.Errorf("failed to add new instance for pool %s: %s", pool.ID, err) + return fmt.Errorf("failed to add new instance for pool %s: %w", pool.ID, err) } return nil } @@ -1760,10 +1753,7 @@ func (r *basePoolManager) DeleteRunner(runner params.Instance, forceRemove, bypa // so those will trigger the creation of a runner. The jobs we don't know about will be dealt with by the idle runners. // Once jobs are consumed, you can set min-idle-runners to 0 again. 
func (r *basePoolManager) consumeQueuedJobs() error { - queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued) - if err != nil { - return fmt.Errorf("error listing queued jobs: %w", err) - } + queued := r.getQueuedJobs() poolsCache := poolsForTags{ poolCacheType: r.entity.GetPoolBalancerType(), diff --git a/runner/pool/util.go b/runner/pool/util.go index d58f90a3..3fed1478 100644 --- a/runner/pool/util.go +++ b/runner/pool/util.go @@ -84,9 +84,32 @@ func composeWatcherFilters(entity params.ForgeEntity) dbCommon.PayloadFilterFunc watcher.WithEntityFilter(entity), // Watch for changes to the github credentials watcher.WithForgeCredentialsFilter(entity.Credentials), + watcher.WithAll( + watcher.WithEntityJobFilter(entity), + watcher.WithAny( + watcher.WithOperationTypeFilter(dbCommon.UpdateOperation), + watcher.WithOperationTypeFilter(dbCommon.CreateOperation), + watcher.WithOperationTypeFilter(dbCommon.DeleteOperation), + ), + ), ) } +func (r *basePoolManager) getQueuedJobs() []params.Job { + r.mux.Lock() + defer r.mux.Unlock() + + ret := []params.Job{} + + for _, job := range r.jobs { + slog.DebugContext(r.ctx, "considering job for processing", "job_id", job.ID, "job_status", job.Status) + if params.JobStatus(job.Status) == params.JobStatusQueued { + ret = append(ret, job) + } + } + return ret +} + func (r *basePoolManager) waitForToolsOrCancel() (hasTools, stopped bool) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go index 999b52c6..16764a2c 100644 --- a/runner/pool/watcher.go +++ b/runner/pool/watcher.go @@ -162,11 +162,46 @@ func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) { return } r.handleEntityUpdate(entityInfo, event.Operation) + case common.JobEntityType: + slog.DebugContext(r.ctx, "new job via watcher") + job, ok := event.Payload.(params.Job) + if !ok { + slog.ErrorContext(r.ctx, "failed to cast 
payload to job") + return + } + if !job.BelongsTo(r.entity) { + slog.InfoContext(r.ctx, "job does not belong to entity", "workflow_job_id", job.WorkflowJobID, "scaleset_job_id", job.ScaleSetJobID, "job_id", job.ID) + return + } + slog.DebugContext(r.ctx, "recording job", "job_id", job.ID, "job_status", job.Status) + r.mux.Lock() + switch event.Operation { + case common.CreateOperation, common.UpdateOperation: + if params.JobStatus(job.Status) != params.JobStatusCompleted { + slog.DebugContext(r.ctx, "adding job to map", "job_id", job.ID, "job_status", job.Status) + r.jobs[job.ID] = job + break + } + fallthrough + case common.DeleteOperation: + delete(r.jobs, job.ID) + } + r.mux.Unlock() } } func (r *basePoolManager) runWatcher() { defer r.consumer.Close() + queued, err := r.store.ListEntityJobsByStatus(r.ctx, r.entity.EntityType, r.entity.ID, params.JobStatusQueued) + if err != nil { + slog.ErrorContext(r.ctx, "failed to list jobs", "error", err) + } + + r.mux.Lock() + for _, job := range queued { + r.jobs[job.ID] = job + } + r.mux.Unlock() for { select { case <-r.quit: @@ -177,7 +212,7 @@ if !ok { return } - go r.handleWatcherEvent(event) + r.handleWatcherEvent(event) } } }