From ef3402bf17e0b092f54ca0b0d99ebed1d1146cf9 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 26 May 2025 19:12:59 +0000 Subject: [PATCH] Add write lock for sqlite3 Signed-off-by: Gabriel Adrian Samfira --- database/sql/controller.go | 6 ++++++ database/sql/enterprise.go | 9 +++++++++ database/sql/github.go | 18 ++++++++++++++++++ database/sql/instances.go | 12 ++++++++++++ database/sql/jobs.go | 18 ++++++++++++++++++ database/sql/organizations.go | 9 +++++++++ database/sql/pools.go | 12 ++++++++++++ database/sql/repositories.go | 9 +++++++++ database/sql/sql.go | 6 ++++++ database/sql/users.go | 6 ++++++ 10 files changed, 105 insertions(+) diff --git a/database/sql/controller.go b/database/sql/controller.go index fb360e00..71890c88 100644 --- a/database/sql/controller.go +++ b/database/sql/controller.go @@ -63,6 +63,9 @@ func (s *sqlDatabase) ControllerInfo() (params.ControllerInfo, error) { } func (s *sqlDatabase) InitController() (params.ControllerInfo, error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + if _, err := s.ControllerInfo(); err == nil { return params.ControllerInfo{}, runnerErrors.NewConflictError("controller already initialized") } @@ -88,6 +91,9 @@ func (s *sqlDatabase) InitController() (params.ControllerInfo, error) { } func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (paramInfo params.ControllerInfo, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.ControllerEntityType, common.UpdateOperation, paramInfo) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index dfcb10a2..9b927bed 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -29,6 +29,9 @@ import ( ) func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (paramEnt params.Enterprise, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + if webhookSecret == "" { return params.Enterprise{}, errors.New("creating enterprise: missing secret") } @@ -132,6 +135,9 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e } func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { + s.writeMux.Lock() + defer s.writeMux.Unlock() + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching enterprise") @@ -157,6 +163,9 @@ func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) } func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (newParams params.Enterprise, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.EnterpriseEntityType, common.UpdateOperation, newParams) diff --git a/database/sql/github.go b/database/sql/github.go index a66c7331..d787653d 100644 --- a/database/sql/github.go +++ b/database/sql/github.go @@ -111,6 +111,9 @@ func getUIDFromContext(ctx context.Context) (uuid.UUID, error) { } func (s *sqlDatabase) CreateGithubEndpoint(_ context.Context, param params.CreateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.GithubEndpointEntityType, common.CreateOperation, ghEndpoint) @@ -164,6 +167,9 @@ func (s *sqlDatabase) ListGithubEndpoints(_ context.Context) ([]params.GithubEnd } func (s *sqlDatabase) UpdateGithubEndpoint(_ context.Context, name string, param params.UpdateGithubEndpointParams) (ghEndpoint params.GithubEndpoint, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.GithubEndpointEntityType, common.UpdateOperation, ghEndpoint) @@ -229,6 +235,9 @@ func (s *sqlDatabase) GetGithubEndpoint(_ context.Context, name string) (params. } func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.GithubEndpointEntityType, common.DeleteOperation, params.GithubEndpoint{Name: name}) @@ -287,6 +296,9 @@ func (s *sqlDatabase) DeleteGithubEndpoint(_ context.Context, name string) (err } func (s *sqlDatabase) CreateGithubCredentials(ctx context.Context, param params.CreateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + userID, err := getUIDFromContext(ctx) if err != nil { return params.GithubCredentials{}, errors.Wrap(err, "creating github credentials") @@ -450,6 +462,9 @@ func (s *sqlDatabase) ListGithubCredentials(ctx context.Context) ([]params.Githu } func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, param params.UpdateGithubCredentialsParams) (ghCreds params.GithubCredentials, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.GithubCredentialsEntityType, common.UpdateOperation, ghCreds) @@ -529,6 +544,9 @@ func (s *sqlDatabase) UpdateGithubCredentials(ctx context.Context, id uint, para } func (s *sqlDatabase) DeleteGithubCredentials(ctx context.Context, id uint) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + var name string defer func() { if err == nil { diff --git a/database/sql/instances.go b/database/sql/instances.go index 864e7ba2..c7fb02f6 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -31,6 +31,9 @@ import ( ) func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") @@ -143,6 +146,9 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string } func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") @@ -176,6 +182,9 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN } func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error { + s.writeMux.Lock() + defer s.writeMux.Unlock() + instance, err := s.getInstanceByName(ctx, instanceName) if err != nil { return errors.Wrap(err, "updating instance") @@ -194,6 +203,9 @@ func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, } func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + instance, err := s.getInstanceByName(ctx, instanceName) if err != nil { return params.Instance{}, errors.Wrap(err, "updating instance") diff --git a/database/sql/jobs.go b/database/sql/jobs.go index b7dda926..9cbf2ffe 100644 --- a/database/sql/jobs.go +++ b/database/sql/jobs.go @@ -95,6 +95,9 @@ func (s *sqlDatabase) paramsJobToWorkflowJob(ctx context.Context, job params.Job } func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { if notifyErr := s.sendNotify(common.JobEntityType, common.DeleteOperation, params.Job{ID: jobID}); notifyErr != nil { @@ -113,6 +116,9 @@ func (s *sqlDatabase) DeleteJob(_ context.Context, jobID int64) (err error) { } func (s *sqlDatabase) LockJob(_ context.Context, jobID int64, entityID string) error { + s.writeMux.Lock() + defer s.writeMux.Unlock() + entityUUID, err := uuid.Parse(entityID) if err != nil { return errors.Wrap(err, "parsing entity id") @@ -152,6 +158,9 @@ func (s *sqlDatabase) LockJob(_ context.Context, jobID int64, entityID string) e } func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + var workflowJob WorkflowJob q := s.conn.Clauses(clause.Locking{Strength: "UPDATE"}).Preload("Instance").Where("id = ? and status = ?", jobID, params.JobStatusQueued).First(&workflowJob) @@ -180,6 +189,9 @@ func (s *sqlDatabase) BreakLockJobIsQueued(_ context.Context, jobID int64) (err } func (s *sqlDatabase) UnlockJob(_ context.Context, jobID int64, entityID string) error { + s.writeMux.Lock() + defer s.writeMux.Unlock() + var workflowJob WorkflowJob q := s.conn.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", jobID).First(&workflowJob) @@ -213,6 +225,9 @@ func (s *sqlDatabase) UnlockJob(_ context.Context, jobID int64, entityID string) } func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (params.Job, error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + var workflowJob WorkflowJob var err error q := s.conn.Clauses(clause.Locking{Strength: "UPDATE"}).Preload("Instance").Where("id = ?", job.ID).First(&workflowJob) @@ -381,6 +396,9 @@ func (s *sqlDatabase) GetJobByID(_ context.Context, jobID int64) (params.Job, er // DeleteCompletedJobs deletes all completed jobs. func (s *sqlDatabase) DeleteCompletedJobs(_ context.Context) error { + s.writeMux.Lock() + defer s.writeMux.Unlock() + query := s.conn.Model(&WorkflowJob{}).Where("status = ?", params.JobStatusCompleted) if err := query.Unscoped().Delete(&WorkflowJob{}); err.Error != nil { diff --git a/database/sql/organizations.go b/database/sql/organizations.go index c41b9269..3c2cdbbf 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -30,6 +30,9 @@ import ( ) func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (org params.Organization, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + if webhookSecret == "" { return params.Organization{}, errors.New("creating org: missing secret") } @@ -123,6 +126,9 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio } func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching org") @@ -148,6 +154,9 @@ func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err } func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (paramOrg params.Organization, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.OrganizationEntityType, common.UpdateOperation, paramOrg) diff --git a/database/sql/pools.go b/database/sql/pools.go index fdcf3f5a..cd888505 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -68,6 +68,9 @@ func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool } func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") @@ -255,6 +258,9 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par } func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (pool params.Pool, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + if len(param.Tags) == 0 { return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") } @@ -343,6 +349,9 @@ func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntit } func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + entityID, err := uuid.Parse(entity.ID) if err != nil { return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") @@ -380,6 +389,9 @@ func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEn } func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (updatedPool params.Pool, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.PoolEntityType, common.UpdateOperation, updatedPool) diff --git a/database/sql/repositories.go b/database/sql/repositories.go index c1eaef3b..d6cefc64 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -30,6 +30,9 @@ import ( ) func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Repository, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.RepositoryEntityType, common.CreateOperation, param) @@ -122,6 +125,9 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, } func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching repo") @@ -147,6 +153,9 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err } func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (newParams params.Repository, err error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + defer func() { if err == nil { s.sendNotify(common.RepositoryEntityType, common.UpdateOperation, newParams) diff --git a/database/sql/sql.go b/database/sql/sql.go index d4e6895a..290cce3f 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -20,6 +20,7 @@ import ( "log/slog" "net/url" "strings" + "sync" "github.com/pkg/errors" "gorm.io/driver/mysql" @@ -91,6 +92,11 @@ type sqlDatabase struct { ctx context.Context cfg config.Database producer common.Producer + + // while busy_timeout helps, in situations of high contention, we can still + // end up with multiple threads trying to write to the database. SQLite does now + // support row level locking. + writeMux sync.Mutex } var renameTemplate = ` diff --git a/database/sql/users.go b/database/sql/users.go index 7d604a83..6bc0973f 100644 --- a/database/sql/users.go +++ b/database/sql/users.go @@ -57,6 +57,9 @@ func (s *sqlDatabase) getUserByID(tx *gorm.DB, userID string) (User, error) { } func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (params.User, error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + if user.Username == "" || user.Email == "" || user.Password == "" { return params.User{}, runnerErrors.NewBadRequestError("missing username, password or email") } @@ -119,6 +122,9 @@ func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User } func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.UpdateUserParams) (params.User, error) { + s.writeMux.Lock() + defer s.writeMux.Unlock() + var err error var dbUser User err = s.conn.Transaction(func(tx *gorm.DB) error {