Add agent mode

This change adds a new "agent mode" to GARM. The agent enables GARM to
set up a persistent websocket connection between the garm server and the
runners it spawns. The goal is to be able to easier keep track of state,
even without subsequent webhooks from the forge.

The Agent will report via websockets when the runner is actually online,
when it started a job and when it finished a job.

Additionally, the agent allows us to enable optional remote shell between
the user and any runner that is spun up using agent mode. The remote shell
is multiplexed over the same persistent websocket connection the agent
sets up with the server (the agent never listens on a port).

Enablement has also been done in the web UI for this functionality.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2025-09-16 07:42:59 +00:00 committed by Gabriel
parent 3b132e4233
commit 42cfd1b3c6
246 changed files with 11042 additions and 672 deletions

View file

@ -34,15 +34,23 @@ func dbControllerToCommonController(dbInfo ControllerInfo) (params.ControllerInf
return params.ControllerInfo{}, fmt.Errorf("error joining webhook URL: %w", err)
}
return params.ControllerInfo{
if dbInfo.GARMAgentReleasesURL == "" {
dbInfo.GARMAgentReleasesURL = appdefaults.GARMAgentDefaultReleasesURL
}
ret := params.ControllerInfo{
ControllerID: dbInfo.ControllerID,
MetadataURL: dbInfo.MetadataURL,
WebhookURL: dbInfo.WebhookBaseURL,
ControllerWebhookURL: url,
CallbackURL: dbInfo.CallbackURL,
AgentURL: dbInfo.AgentURL,
MinimumJobAgeBackoff: dbInfo.MinimumJobAgeBackoff,
Version: appdefaults.GetVersion(),
}, nil
GARMAgentReleasesURL: dbInfo.GARMAgentReleasesURL,
SyncGARMAgentTools: dbInfo.SyncGARMAgentTools,
}
return ret, nil
}
func (s *sqlDatabase) ControllerInfo() (params.ControllerInfo, error) {
@ -63,6 +71,24 @@ func (s *sqlDatabase) ControllerInfo() (params.ControllerInfo, error) {
return paramInfo, nil
}
func (s *sqlDatabase) HasEntitiesWithAgentModeEnabled() (bool, error) {
var reposCnt int64
if err := s.conn.Model(&Repository{}).Where("agent_mode = ?", true).Count(&reposCnt).Error; err != nil {
return false, fmt.Errorf("error fetching repo count: %w", err)
}
var orgCount int64
if err := s.conn.Model(&Organization{}).Where("agent_mode = ?", true).Count(&orgCount).Error; err != nil {
return false, fmt.Errorf("error fetching repo count: %w", err)
}
var enterpriseCount int64
if err := s.conn.Model(&Enterprise{}).Where("agent_mode = ?", true).Count(&enterpriseCount).Error; err != nil {
return false, fmt.Errorf("error fetching repo count: %w", err)
}
return reposCnt+orgCount+enterpriseCount > 0, nil
}
func (s *sqlDatabase) InitController() (params.ControllerInfo, error) {
if _, err := s.ControllerInfo(); err == nil {
return params.ControllerInfo{}, runnerErrors.NewConflictError("controller already initialized")
@ -76,6 +102,7 @@ func (s *sqlDatabase) InitController() (params.ControllerInfo, error) {
newInfo := ControllerInfo{
ControllerID: newID,
MinimumJobAgeBackoff: 30,
GARMAgentReleasesURL: appdefaults.GARMAgentDefaultReleasesURL,
}
q := s.conn.Save(&newInfo)
@ -120,6 +147,22 @@ func (s *sqlDatabase) UpdateController(info params.UpdateControllerParams) (para
dbInfo.WebhookBaseURL = *info.WebhookURL
}
if info.AgentURL != nil {
dbInfo.AgentURL = *info.AgentURL
}
if info.GARMAgentReleasesURL != nil {
agentToolsURL := *info.GARMAgentReleasesURL
if agentToolsURL == "" {
agentToolsURL = appdefaults.GARMAgentDefaultReleasesURL
}
dbInfo.GARMAgentReleasesURL = agentToolsURL
}
if info.SyncGARMAgentTools != nil {
dbInfo.SyncGARMAgentTools = *info.SyncGARMAgentTools
}
if info.MinimumJobAgeBackoff != nil {
dbInfo.MinimumJobAgeBackoff = *info.MinimumJobAgeBackoff
}

View file

@ -29,7 +29,7 @@ import (
"github.com/cloudbase/garm/params"
)
func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType) (paramEnt params.Enterprise, err error) {
func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType, agentMode bool) (paramEnt params.Enterprise, err error) {
if webhookSecret == "" {
return params.Enterprise{}, errors.New("creating enterprise: missing secret")
}
@ -51,6 +51,7 @@ func (s *sqlDatabase) CreateEnterprise(ctx context.Context, name string, credent
Name: name,
WebhookSecret: secret,
PoolBalancerType: poolBalancerType,
AgentMode: agentMode,
}
err = s.conn.Transaction(func(tx *gorm.DB) error {
newEnterprise.CredentialsID = &credentials.ID
@ -211,6 +212,10 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
enterprise.PoolBalancerType = param.PoolBalancerType
}
if param.AgentMode != nil {
enterprise.AgentMode = *param.AgentMode
}
q := tx.Save(&enterprise)
if q.Error != nil {
return fmt.Errorf("error saving enterprise: %w", q.Error)

View file

@ -113,6 +113,7 @@ func (s *EnterpriseTestSuite) SetupTest() {
s.testCreds,
fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin,
false,
)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-enterprise-%d): %q", i, err))
@ -191,7 +192,9 @@ func (s *EnterpriseTestSuite) TestCreateEnterprise() {
s.Fixtures.CreateEnterpriseParams.Name,
s.testCreds,
s.Fixtures.CreateEnterpriseParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
// assertions
s.Require().Nil(err)
@ -222,7 +225,9 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseInvalidDBPassphrase() {
s.Fixtures.CreateEnterpriseParams.Name,
s.testCreds,
s.Fixtures.CreateEnterpriseParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
s.Require().Equal("error encoding secret: invalid passphrase length (expected length 32 characters)", err.Error())
@ -240,7 +245,9 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBCreateErr() {
s.Fixtures.CreateEnterpriseParams.Name,
s.testCreds,
s.Fixtures.CreateEnterpriseParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
s.Require().Equal("error creating enterprise: error creating enterprise: creating enterprise mock error", err.Error())
@ -296,6 +303,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisesWithFilter() {
s.ghesCreds,
"test-secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -305,6 +313,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisesWithFilter() {
s.testCreds,
"test-secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -314,6 +323,7 @@ func (s *EnterpriseTestSuite) TestListEnterprisesWithFilter() {
s.testCreds,
"test-secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
enterprises, err := s.Store.ListEnterprises(s.adminCtx, params.EnterpriseFilter{
@ -844,7 +854,9 @@ func (s *EnterpriseTestSuite) TestAddRepoEntityEvent() {
s.Fixtures.CreateEnterpriseParams.Name,
s.testCreds,
s.Fixtures.CreateEnterpriseParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().Nil(err)
entity, err := enterprise.GetEntity()

View file

@ -253,6 +253,63 @@ func (s *sqlDatabase) DeleteFileObject(_ context.Context, objID uint) (err error
return nil
}
func (s *sqlDatabase) DeleteFileObjectsByTags(_ context.Context, tags []string) (int64, error) {
if len(tags) == 0 {
return 0, fmt.Errorf("no tags provided")
}
var deletedCount int64
err := s.objectsConn.Transaction(func(tx *gorm.DB) error {
// Build query to find all file objects matching ALL tags
query := tx.Model(&FileObject{}).Preload("TagsList").Omit("content")
for _, tag := range tags {
query = query.Where("EXISTS (SELECT 1 FROM file_object_tags WHERE file_object_tags.file_object_id = file_objects.id AND file_object_tags.tag = ?)", tag)
}
// Get matching objects with their full details (except content blob)
var fileObjects []FileObject
if err := query.Find(&fileObjects).Error; err != nil {
return fmt.Errorf("failed to find matching objects: %w", err)
}
if len(fileObjects) == 0 {
// No objects match - not an error, just nothing to delete
return nil
}
// Extract IDs for deletion
fileObjIDs := make([]uint, len(fileObjects))
for i, obj := range fileObjects {
fileObjIDs[i] = obj.ID
}
// Delete all matching objects (hard delete with Unscoped)
result := tx.Unscoped().Where("id IN ?", fileObjIDs).Delete(&FileObject{})
if result.Error != nil {
return fmt.Errorf("failed to delete objects: %w", result.Error)
}
deletedCount = result.RowsAffected
// Send notifications with full object details for each deleted object
for _, obj := range fileObjects {
s.sendNotify(common.FileObjectEntityType, common.DeleteOperation, s.sqlFileObjectToCommonParams(obj))
}
return nil
})
if err != nil {
return 0, err
}
// NOTE: Same as DeleteFileObject - deleted file objects leave empty space
// in the database. Users should run VACUUM manually to reclaim space.
// See DeleteFileObject for performance details.
return deletedCount, nil
}
func (s *sqlDatabase) GetFileObject(_ context.Context, objID uint) (params.FileObject, error) {
var fileObj FileObject
if err := s.objectsConn.Preload("TagsList").Where("id = ?", objID).Omit("content").First(&fileObj).Error; err != nil {
@ -304,7 +361,7 @@ func (s *sqlDatabase) SearchFileObjectByTags(_ context.Context, tags []string, p
if err := query.
Limit(queryPageSize).
Offset(queryOffset).
Order("created_at DESC").
Order("id DESC").
Omit("content").
Find(&fileObjectRes).Error; err != nil {
return params.FileObjectPaginatedResponse{}, fmt.Errorf("failed to query database: %w", err)
@ -426,7 +483,7 @@ func (s *sqlDatabase) ListFileObjects(_ context.Context, page, pageSize uint64)
if err := s.objectsConn.Preload("TagsList").Omit("content").
Limit(queryPageSize).
Offset(queryOffset).
Order("created_at DESC").
Order("id DESC").
Find(&fileObjs).Error; err != nil {
return params.FileObjectPaginatedResponse{}, fmt.Errorf("failed to list file objects: %w", err)
}

View file

@ -529,7 +529,7 @@ func (s *GiteaTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
s.Require().NoError(err)
s.Require().NotNil(creds)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(repo)
@ -540,7 +540,7 @@ func (s *GiteaTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
err = s.db.DeleteRepository(ctx, repo.ID)
s.Require().NoError(err)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(org)
@ -551,7 +551,7 @@ func (s *GiteaTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
err = s.db.DeleteOrganization(ctx, org.ID)
s.Require().NoError(err)
enterprise, err := s.db.CreateEnterprise(ctx, "test-enterprise", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
enterprise, err := s.db.CreateEnterprise(ctx, "test-enterprise", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().ErrorIs(err, runnerErrors.ErrBadRequest)
s.Require().Equal(params.Enterprise{}, enterprise)
@ -685,7 +685,7 @@ func (s *GiteaTestSuite) TestDeleteCredentialsWithOrgsOrReposFails() {
s.Require().NoError(err)
s.Require().NotNil(creds)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(repo)
@ -696,7 +696,7 @@ func (s *GiteaTestSuite) TestDeleteCredentialsWithOrgsOrReposFails() {
err = s.db.DeleteRepository(ctx, repo.ID)
s.Require().NoError(err)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(org)
@ -743,7 +743,7 @@ func (s *GiteaTestSuite) TestDeleteGiteaEndpointFailsWithOrgsReposOrCredentials(
s.Require().NoError(err)
s.Require().NotNil(creds)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(repo)
@ -755,7 +755,7 @@ func (s *GiteaTestSuite) TestDeleteGiteaEndpointFailsWithOrgsReposOrCredentials(
err = s.db.DeleteRepository(ctx, repo.ID)
s.Require().NoError(err)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(org)

View file

@ -640,7 +640,7 @@ func (s *GithubTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
s.Require().NoError(err)
s.Require().NotNil(creds)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(repo)
@ -651,7 +651,7 @@ func (s *GithubTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
err = s.db.DeleteRepository(ctx, repo.ID)
s.Require().NoError(err)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(org)
@ -662,7 +662,7 @@ func (s *GithubTestSuite) TestDeleteCredentialsFailsIfReposOrgsOrEntitiesUseIt()
err = s.db.DeleteOrganization(ctx, org.ID)
s.Require().NoError(err)
enterprise, err := s.db.CreateEnterprise(ctx, "test-enterprise", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
enterprise, err := s.db.CreateEnterprise(ctx, "test-enterprise", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(enterprise)
@ -872,7 +872,7 @@ func (s *GithubTestSuite) TestDeleteGithubEndpointFailsWithOrgsReposOrCredential
s.Require().NoError(err)
s.Require().NotNil(creds)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
repo, err := s.db.CreateRepository(ctx, "test-owner", "test-repo", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(repo)
@ -884,7 +884,7 @@ func (s *GithubTestSuite) TestDeleteGithubEndpointFailsWithOrgsReposOrCredential
err = s.db.DeleteRepository(ctx, repo.ID)
s.Require().NoError(err)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin)
org, err := s.db.CreateOrganization(ctx, "test-org", creds, "superSecret@123BlaBla", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotNil(org)

View file

@ -21,6 +21,7 @@ import (
"fmt"
"log/slog"
"math"
"slices"
"github.com/google/uuid"
"gorm.io/datatypes"
@ -28,6 +29,7 @@ import (
"gorm.io/gorm/clause"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonParams "github.com/cloudbase/garm-provider-common/params"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/params"
)
@ -124,7 +126,7 @@ func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string)
return instance, nil
}
func (s *sqlDatabase) getInstance(_ context.Context, instanceNameOrID string, preload ...string) (Instance, error) {
func (s *sqlDatabase) getInstance(_ context.Context, tx *gorm.DB, instanceNameOrID string, preload ...string) (Instance, error) {
var instance Instance
var whereArg any = instanceNameOrID
@ -134,7 +136,7 @@ func (s *sqlDatabase) getInstance(_ context.Context, instanceNameOrID string, pr
whereArg = id
whereClause = "id = ?"
}
q := s.conn
q := tx
if len(preload) > 0 {
for _, item := range preload {
@ -156,7 +158,7 @@ func (s *sqlDatabase) getInstance(_ context.Context, instanceNameOrID string, pr
}
func (s *sqlDatabase) GetInstance(ctx context.Context, instanceName string) (params.Instance, error) {
instance, err := s.getInstance(ctx, instanceName, "StatusMessages", "Pool", "ScaleSet")
instance, err := s.getInstance(ctx, s.conn, instanceName, "StatusMessages", "Pool", "ScaleSet")
if err != nil {
return params.Instance{}, fmt.Errorf("error fetching instance: %w", err)
}
@ -208,7 +210,7 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN
}
func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName string) error {
instance, err := s.getInstance(ctx, instanceName, "Pool", "ScaleSet")
instance, err := s.getInstance(ctx, s.conn, instanceName, "Pool", "ScaleSet")
if err != nil {
if errors.Is(err, runnerErrors.ErrNotFound) {
return nil
@ -250,7 +252,7 @@ func (s *sqlDatabase) DeleteInstanceByName(ctx context.Context, instanceName str
}
func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error {
instance, err := s.getInstance(ctx, instanceName)
instance, err := s.getInstance(ctx, s.conn, instanceName)
if err != nil {
return fmt.Errorf("error updating instance: %w", err)
}
@ -261,81 +263,199 @@ func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string,
EventLevel: eventLevel,
}
if err := s.conn.Model(&instance).Association("StatusMessages").Append(&msg); err != nil {
// Use Create instead of Association.Append to avoid loading all existing messages
msg.InstanceID = instance.ID
if err := s.conn.Create(&msg).Error; err != nil {
return fmt.Errorf("error adding status message: %w", err)
}
// Keep only the latest 30 status messages to prevent database bloat
const maxStatusMessages = 30
var count int64
if err := s.conn.Model(&InstanceStatusUpdate{}).Where("instance_id = ?", instance.ID).Count(&count).Error; err != nil {
return fmt.Errorf("error counting status messages: %w", err)
}
if count > maxStatusMessages {
// Get the ID of the 30th most recent message
var cutoffMsg InstanceStatusUpdate
if err := s.conn.Model(&InstanceStatusUpdate{}).
Select("id").
Where("instance_id = ?", instance.ID).
Order("id desc").
Offset(maxStatusMessages - 1).
Limit(1).
First(&cutoffMsg).Error; err != nil {
return fmt.Errorf("error finding cutoff message: %w", err)
}
// Delete all messages older than the cutoff
if err := s.conn.Where("instance_id = ? and id < ?", instance.ID, cutoffMsg.ID).Unscoped().Delete(&InstanceStatusUpdate{}).Error; err != nil {
return fmt.Errorf("error deleting old status messages: %w", err)
}
}
return nil
}
// validateAgentID checks agent ID consistency
func (s *sqlDatabase) validateAgentID(currentAgentID, newAgentID int64) error {
if currentAgentID != 0 && newAgentID != 0 && currentAgentID != newAgentID {
return runnerErrors.NewBadRequestError("agent ID mismatch")
}
return nil
}
func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) {
instance, err := s.getInstance(ctx, instanceName, "Pool", "ScaleSet")
if err != nil {
return params.Instance{}, fmt.Errorf("error updating instance: %w", err)
// validateRunnerStatusTransition validates runner status state transition
func (s *sqlDatabase) validateRunnerStatusTransition(current, newStatus params.RunnerStatus) error {
if newStatus == "" || newStatus == current {
return nil
}
allowedTransitions, ok := params.RunnerStatusTransitions[current]
if !ok {
return fmt.Errorf("Instance is in invalid state: %s", current)
}
if !slices.Contains(allowedTransitions, newStatus) {
return runnerErrors.NewBadRequestError("invalid runner status transition from %s to %s", current, newStatus)
}
return nil
}
// validateInstanceStatusTransition validates instance status state transition
func (s *sqlDatabase) validateInstanceStatusTransition(current, newStatus commonParams.InstanceStatus) error {
if newStatus == "" || newStatus == current {
return nil
}
allowedTransitions, ok := params.InstanceStatusTransitions[current]
if !ok {
// we need a better way to handle this. Because if we err out here, we cannot recover
// unless the user manually updates the instance.
return fmt.Errorf("Instance is in invalid state: %s", current)
}
if !slices.Contains(allowedTransitions, newStatus) {
return runnerErrors.NewBadRequestError("invalid instance status transition from %s to %s", current, newStatus)
}
return nil
}
// applyInstanceUpdates applies parameter updates to the instance
func (s *sqlDatabase) applyInstanceUpdates(instance *Instance, param params.UpdateInstanceParams) error {
// Simple field updates
if param.AgentID != 0 {
instance.AgentID = param.AgentID
}
if param.ProviderID != "" {
instance.ProviderID = &param.ProviderID
}
if param.OSName != "" {
instance.OSName = param.OSName
}
if param.OSVersion != "" {
instance.OSVersion = param.OSVersion
}
if string(param.RunnerStatus) != "" {
instance.RunnerStatus = param.RunnerStatus
}
if string(param.Status) != "" {
if param.Heartbeat != nil {
instance.Heartbeat = *param.Heartbeat
}
if param.Status != "" {
instance.Status = param.Status
}
if param.CreateAttempt != 0 {
instance.CreateAttempt = param.CreateAttempt
}
if param.TokenFetched != nil {
instance.TokenFetched = *param.TokenFetched
}
// Complex field updates
if param.Capabilities != nil {
asJs, err := json.Marshal(*param.Capabilities)
if err != nil {
return runnerErrors.NewBadRequestError("invalid capabilities: %s", err)
}
instance.Capabilities = asJs
}
if param.JitConfiguration != nil {
secret, err := s.marshalAndSeal(param.JitConfiguration)
if err != nil {
return params.Instance{}, fmt.Errorf("error marshalling jit config: %w", err)
return fmt.Errorf("error marshalling jit config: %w", err)
}
instance.JitConfiguration = secret
}
instance.ProviderFault = param.ProviderFault
return nil
}
q := s.conn.Save(&instance)
if q.Error != nil {
return params.Instance{}, fmt.Errorf("error updating instance: %w", q.Error)
func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) {
var rowsAffected int64
err := s.conn.Transaction(func(tx *gorm.DB) error {
instance, err := s.getInstance(ctx, tx, instanceName, "Pool", "ScaleSet")
if err != nil {
return fmt.Errorf("error updating instance: %w", err)
}
// Validate transitions
if err := s.validateAgentID(instance.AgentID, param.AgentID); err != nil {
return err
}
if err := s.validateRunnerStatusTransition(instance.RunnerStatus, param.RunnerStatus); err != nil {
return err
}
if err := s.validateInstanceStatusTransition(instance.Status, param.Status); err != nil {
return err
}
// Apply updates
if err := s.applyInstanceUpdates(&instance, param); err != nil {
return err
}
// Save instance
result := tx.Save(&instance)
if result.Error != nil {
return fmt.Errorf("error updating instance: %w", result.Error)
}
rowsAffected = result.RowsAffected
// Update addresses if provided
if len(param.Addresses) > 0 {
addrs := make([]Address, 0, len(param.Addresses))
for _, addr := range param.Addresses {
addrs = append(addrs, Address{
Address: addr.Address,
Type: string(addr.Type),
})
}
if err := tx.Model(&instance).Association("Addresses").Replace(addrs); err != nil {
return fmt.Errorf("error updating addresses: %w", err)
}
}
return nil
})
if err != nil {
return params.Instance{}, fmt.Errorf("error updating instance: %w", err)
}
if len(param.Addresses) > 0 {
addrs := []Address{}
for _, addr := range param.Addresses {
addrs = append(addrs, Address{
Address: addr.Address,
Type: string(addr.Type),
})
}
if err := s.conn.Model(&instance).Association("Addresses").Replace(addrs); err != nil {
return params.Instance{}, fmt.Errorf("error updating addresses: %w", err)
}
instance, err := s.getInstance(ctx, s.conn, instanceName, "Pool", "ScaleSet")
if err != nil {
return params.Instance{}, fmt.Errorf("error updating instance: %w", err)
}
inst, err := s.sqlToParamsInstance(instance)
if err != nil {
return params.Instance{}, fmt.Errorf("error converting instance: %w", err)
}
s.sendNotify(common.InstanceEntityType, common.UpdateOperation, inst)
if rowsAffected > 0 {
s.sendNotify(common.InstanceEntityType, common.UpdateOperation, inst)
}
return inst, nil
}

View file

@ -92,7 +92,7 @@ func (s *InstancesTestSuite) SetupTest() {
creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint)
// create an organization for testing purposes
org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin, false)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err))
}
@ -573,10 +573,6 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() {
WithArgs(instance.ID).
WillReturnRows(sqlmock.NewRows([]string{"message", "instance_id"}).AddRow("instance sample message", instance.ID))
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("UPDATE `instances` SET `updated_at`=? WHERE `instances`.`deleted_at` IS NULL AND `id` = ?")).
WithArgs(sqlmock.AnyArg(), instance.ID).
WillReturnResult(sqlmock.NewResult(1, 1))
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `instance_status_updates`")).
WillReturnError(fmt.Errorf("mocked add status message error"))
@ -605,10 +601,12 @@ func (s *InstancesTestSuite) TestUpdateInstance() {
func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
instance := s.Fixtures.Instances[0]
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE name = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT ?")).
WithArgs(instance.Name, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(instance.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at", "deleted_at", "provider_id", "name", "agent_id", "os_type", "os_arch", "os_name", "os_version", "status", "runner_status", "heartbeat", "callback_url", "metadata_url", "provider_fault", "create_attempt", "token_fetched", "jit_configuration", "git_hub_runner_group", "aditional_labels", "capabilities", "pool_id", "scale_set_fk_id"}).
AddRow(instance.ID, instance.CreatedAt, instance.UpdatedAt, nil, nil, instance.Name, 0, "linux", "amd64", "", "", "running", "idle", instance.Heartbeat, "", "", nil, 0, false, nil, "", nil, nil, nil, nil))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `addresses` WHERE `addresses`.`instance_id` = ? AND `addresses`.`deleted_at` IS NULL")).
WithArgs(instance.ID).
@ -621,7 +619,6 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instance_status_updates` WHERE `instance_status_updates`.`instance_id` = ? AND `instance_status_updates`.`deleted_at` IS NULL")).
WithArgs(instance.ID).
WillReturnRows(sqlmock.NewRows([]string{"message", "instance_id"}).AddRow("instance sample message", instance.ID))
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectExec(("UPDATE `instances`")).
WillReturnError(fmt.Errorf("mocked update instance error"))
@ -630,17 +627,19 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
_, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams)
s.Require().NotNil(err)
s.Require().Equal("error updating instance: mocked update instance error", err.Error())
s.Require().Equal("error updating instance: error updating instance: mocked update instance error", err.Error())
s.assertSQLMockExpectations()
}
func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
instance := s.Fixtures.Instances[0]
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE name = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT ?")).
WithArgs(instance.Name, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(instance.ID))
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at", "deleted_at", "provider_id", "name", "agent_id", "os_type", "os_arch", "os_name", "os_version", "status", "runner_status", "heartbeat", "callback_url", "metadata_url", "provider_fault", "create_attempt", "token_fetched", "jit_configuration", "git_hub_runner_group", "aditional_labels", "capabilities", "pool_id", "scale_set_fk_id"}).
AddRow(instance.ID, instance.CreatedAt, instance.UpdatedAt, nil, nil, instance.Name, 0, "linux", "amd64", "", "", "running", "idle", instance.Heartbeat, "", "", nil, 0, false, nil, "", nil, nil, nil, nil))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `addresses` WHERE `addresses`.`instance_id` = ? AND `addresses`.`deleted_at` IS NULL")).
WithArgs(instance.ID).
@ -653,18 +652,6 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instance_status_updates` WHERE `instance_status_updates`.`instance_id` = ? AND `instance_status_updates`.`deleted_at` IS NULL")).
WithArgs(instance.ID).
WillReturnRows(sqlmock.NewRows([]string{"message", "instance_id"}).AddRow("instance sample message", instance.ID))
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("UPDATE `instances` SET")).
WillReturnResult(sqlmock.NewResult(1, 1))
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `addresses`")).
WillReturnResult(sqlmock.NewResult(1, 1))
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("INSERT INTO `instance_status_updates`")).
WillReturnResult(sqlmock.NewResult(1, 1))
s.Fixtures.SQLMock.ExpectCommit()
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock.
ExpectExec(regexp.QuoteMeta("UPDATE `instances` SET")).
WillReturnResult(sqlmock.NewResult(1, 1))
@ -676,7 +663,7 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
_, err := s.StoreSQLMocked.UpdateInstance(s.adminCtx, instance.Name, s.Fixtures.UpdateInstanceParams)
s.Require().NotNil(err)
s.Require().Equal("error updating addresses: update addresses mock error", err.Error())
s.Require().Equal("error updating instance: error updating instance: update addresses mock error; update addresses mock error", err.Error())
s.assertSQLMockExpectations()
}

View file

@ -100,7 +100,7 @@ func (s *sqlDatabase) paramsJobToWorkflowJob(ctx context.Context, job params.Job
}
if job.RunnerName != "" {
instance, err := s.getInstance(s.ctx, job.RunnerName)
instance, err := s.getInstance(s.ctx, s.conn, job.RunnerName)
if err != nil {
// This usually is very normal as not all jobs run on our runners.
slog.DebugContext(ctx, "failed to get instance by name", "instance_name", job.RunnerName)
@ -282,7 +282,7 @@ func (s *sqlDatabase) CreateOrUpdateJob(ctx context.Context, job params.Job) (pa
}
if job.RunnerName != "" {
instance, err := s.getInstance(ctx, job.RunnerName)
instance, err := s.getInstance(ctx, s.conn, job.RunnerName)
if err == nil {
workflowJob.InstanceID = &instance.ID
} else {

View file

@ -51,9 +51,20 @@ type ControllerInfo struct {
ControllerID uuid.UUID
CallbackURL string
MetadataURL string
// CallbackURL is the URL where userdata scripts call back into, to send status updates
// and installation progress.
CallbackURL string
// MetadataURL is the base URL from which runners can get their installation metadata.
MetadataURL string
// WebhookBaseURL is the base URL used to construct the controller webhook URL.
WebhookBaseURL string
// AgentURL is the websocket enabled URL whenre garm agents connect to.
AgentURL string
// GARMAgentReleasesURL is the URL from which GARM can sync garm-agent binaries. Alternatively
// the user can manually upload binaries.
GARMAgentReleasesURL string
// SyncGARMAgentTools enables or disables automatic sync of garm-agent tools.
SyncGARMAgentTools bool
// MinimumJobAgeBackoff is the minimum time that a job must be in the queue
// before GARM will attempt to allocate a runner to service it. This backoff
// is useful if you have idle runners in various pools that could potentially
@ -104,6 +115,7 @@ type Pool struct {
// any kind of data needed by providers.
ExtraSpecs datatypes.JSON
GitHubRunnerGroup string
EnableShell bool
RepoID *uuid.UUID `gorm:"index"`
Repository Repository `gorm:"foreignKey:RepoID;"`
@ -159,7 +171,8 @@ type ScaleSet struct {
// ExtraSpecs is an opaque json that gets sent to the provider
// as part of the bootstrap params for instances. It can contain
// any kind of data needed by providers.
ExtraSpecs datatypes.JSON
ExtraSpecs datatypes.JSON
EnableShell bool
RepoID *uuid.UUID `gorm:"index"`
Repository Repository `gorm:"foreignKey:RepoID;"`
@ -203,6 +216,7 @@ type Repository struct {
ScaleSets []ScaleSet `gorm:"foreignKey:RepoID"`
Jobs []WorkflowJob `gorm:"foreignKey:RepoID;constraint:OnDelete:SET NULL"`
PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"`
AgentMode bool `gorm:"index:repo_agent_idx"`
EndpointName *string `gorm:"index:idx_owner_nocase,unique,collate:nocase"`
Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"`
@ -235,6 +249,7 @@ type Organization struct {
ScaleSet []ScaleSet `gorm:"foreignKey:OrgID"`
Jobs []WorkflowJob `gorm:"foreignKey:OrgID;constraint:OnDelete:SET NULL"`
PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"`
AgentMode bool `gorm:"index:org_agent_idx"`
EndpointName *string `gorm:"index:idx_org_name_nocase,collate:nocase"`
Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"`
@ -265,6 +280,7 @@ type Enterprise struct {
ScaleSet []ScaleSet `gorm:"foreignKey:EnterpriseID"`
Jobs []WorkflowJob `gorm:"foreignKey:EnterpriseID;constraint:OnDelete:SET NULL"`
PoolBalancerType params.PoolBalancerType `gorm:"type:varchar(64)"`
AgentMode bool `gorm:"index:enterprise_agent_idx"`
EndpointName *string `gorm:"index:idx_ent_name_nocase,collate:nocase"`
Endpoint GithubEndpoint `gorm:"foreignKey:EndpointName;constraint:OnDelete:SET NULL"`
@ -306,6 +322,7 @@ type Instance struct {
Addresses []Address `gorm:"foreignKey:InstanceID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"`
Status commonParams.InstanceStatus
RunnerStatus params.RunnerStatus
Heartbeat time.Time
CallbackURL string
MetadataURL string
ProviderFault []byte `gorm:"type:longblob"`
@ -314,6 +331,7 @@ type Instance struct {
JitConfiguration []byte `gorm:"type:longblob"`
GitHubRunnerGroup string
AditionalLabels datatypes.JSON
Capabilities datatypes.JSON
PoolID *uuid.UUID
Pool Pool `gorm:"foreignKey:PoolID"`

View file

@ -29,7 +29,7 @@ import (
"github.com/cloudbase/garm/params"
)
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Organization, err error) {
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType, agentMode bool) (param params.Organization, err error) {
if webhookSecret == "" {
return params.Organization{}, errors.New("creating org: missing secret")
}
@ -47,6 +47,7 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name string, crede
Name: name,
WebhookSecret: secret,
PoolBalancerType: poolBalancerType,
AgentMode: agentMode,
}
err = s.conn.Transaction(func(tx *gorm.DB) error {
@ -195,6 +196,10 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
org.PoolBalancerType = param.PoolBalancerType
}
if param.AgentMode != nil {
org.AgentMode = *param.AgentMode
}
q := tx.Save(&org)
if q.Error != nil {
return fmt.Errorf("error saving org: %w", q.Error)

View file

@ -114,6 +114,7 @@ func (s *OrgTestSuite) SetupTest() {
s.testCreds,
fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin,
false,
)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-org-%d): %q", i, err))
@ -192,7 +193,9 @@ func (s *OrgTestSuite) TestCreateOrganization() {
s.Fixtures.CreateOrgParams.Name,
s.testCreds,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
// assertions
s.Require().Nil(err)
@ -221,7 +224,9 @@ func (s *OrgTestSuite) TestCreateOrgForGitea() {
s.Fixtures.CreateOrgParams.Name,
s.testCredsGitea,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
// assertions
s.Require().Nil(err)
@ -256,7 +261,9 @@ func (s *OrgTestSuite) TestCreateOrganizationInvalidForgeType() {
s.Fixtures.CreateOrgParams.Name,
credentials,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
s.Require().Equal("error creating org: unsupported credentials type: invalid request", err.Error())
}
@ -279,7 +286,9 @@ func (s *OrgTestSuite) TestCreateOrganizationInvalidDBPassphrase() {
s.Fixtures.CreateOrgParams.Name,
s.testCreds,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
s.Require().Equal("error encoding secret: invalid passphrase length (expected length 32 characters)", err.Error())
@ -297,7 +306,8 @@ func (s *OrgTestSuite) TestCreateOrganizationDBCreateErr() {
s.Fixtures.CreateOrgParams.Name,
s.testCreds,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false)
s.Require().NotNil(err)
s.Require().Equal("error creating org: error creating org: creating org mock error", err.Error())
@ -353,6 +363,7 @@ func (s *OrgTestSuite) TestListOrganizationsWithFilters() {
s.testCreds,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -362,6 +373,7 @@ func (s *OrgTestSuite) TestListOrganizationsWithFilters() {
s.testCredsGitea,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -371,6 +383,7 @@ func (s *OrgTestSuite) TestListOrganizationsWithFilters() {
s.testCreds,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
orgs, err := s.Store.ListOrganizations(
@ -899,7 +912,8 @@ func (s *OrgTestSuite) TestAddOrgEntityEvent() {
s.Fixtures.CreateOrgParams.Name,
s.testCreds,
s.Fixtures.CreateOrgParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false)
s.Require().Nil(err)
entity, err := org.GetEntity()

View file

@ -293,6 +293,7 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.ForgeE
GitHubRunnerGroup: param.GitHubRunnerGroup,
Priority: param.Priority,
TemplateID: param.TemplateID,
EnableShell: param.EnableShell,
}
if len(param.ExtraSpecs) > 0 {
newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
@ -316,13 +317,13 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.ForgeE
return fmt.Errorf("error checking entity existence: %w", err)
}
tags := []Tag{}
var tags []*Tag
for _, val := range param.Tags {
t, err := s.getOrCreateTag(tx, val)
if err != nil {
return fmt.Errorf("error creating tag: %w", err)
}
tags = append(tags, t)
tags = append(tags, &t)
}
q := tx.Create(&newPool)
@ -330,8 +331,9 @@ func (s *sqlDatabase) CreateEntityPool(ctx context.Context, entity params.ForgeE
return fmt.Errorf("error creating pool: %w", q.Error)
}
for i := range tags {
if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil {
// Append all tags at once instead of one by one for better performance
if len(tags) > 0 {
if err := tx.Model(&newPool).Association("Tags").Append(tags); err != nil {
return fmt.Errorf("error associating tags: %w", err)
}
}

View file

@ -81,7 +81,7 @@ func (s *PoolsTestSuite) SetupTest() {
creds := garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint)
// create an organization for testing purposes
org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
org, err := s.Store.CreateOrganization(s.adminCtx, "test-org", creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin, false)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err))
}
@ -211,7 +211,7 @@ func (s *PoolsTestSuite) TestEntityPoolOperations() {
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.Store, s.T())
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.Store, s.T(), ep)
s.T().Cleanup(func() { s.Store.DeleteGithubCredentials(s.ctx, creds.ID) })
repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds, "test-secret", params.PoolBalancerTypeRoundRobin)
repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds, "test-secret", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotEmpty(repo.ID)
s.T().Cleanup(func() { s.Store.DeleteRepository(s.ctx, repo.ID) })
@ -261,7 +261,7 @@ func (s *PoolsTestSuite) TestEntityPoolOperations() {
s.Require().Equal(*updatePoolParams.Enabled, pool.Enabled)
s.Require().Equal(updatePoolParams.Flavor, pool.Flavor)
s.Require().Equal(updatePoolParams.Image, pool.Image)
s.Require().Equal(updatePoolParams.RunnerPrefix.Prefix, pool.RunnerPrefix.Prefix)
s.Require().Equal(updatePoolParams.Prefix, pool.Prefix)
s.Require().Equal(*updatePoolParams.MaxRunners, pool.MaxRunners)
s.Require().Equal(*updatePoolParams.MinIdleRunners, pool.MinIdleRunners)
s.Require().Equal(updatePoolParams.OSType, pool.OSType)
@ -292,7 +292,7 @@ func (s *PoolsTestSuite) TestListEntityInstances() {
ep := garmTesting.CreateDefaultGithubEndpoint(s.ctx, s.Store, s.T())
creds := garmTesting.CreateTestGithubCredentials(s.ctx, "test-creds", s.Store, s.T(), ep)
s.T().Cleanup(func() { s.Store.DeleteGithubCredentials(s.ctx, creds.ID) })
repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds, "test-secret", params.PoolBalancerTypeRoundRobin)
repo, err := s.Store.CreateRepository(s.ctx, "test-owner", "test-repo", creds, "test-secret", params.PoolBalancerTypeRoundRobin, false)
s.Require().NoError(err)
s.Require().NotEmpty(repo.ID)
s.T().Cleanup(func() { s.Store.DeleteRepository(s.ctx, repo.ID) })

View file

@ -29,7 +29,7 @@ import (
"github.com/cloudbase/garm/params"
)
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType) (param params.Repository, err error) {
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name string, credentials params.ForgeCredentials, webhookSecret string, poolBalancerType params.PoolBalancerType, agentMode bool) (param params.Repository, err error) {
defer func() {
if err == nil {
s.sendNotify(common.RepositoryEntityType, common.CreateOperation, param)
@ -49,6 +49,7 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name string,
Owner: owner,
WebhookSecret: secret,
PoolBalancerType: poolBalancerType,
AgentMode: agentMode,
}
err = s.conn.Transaction(func(tx *gorm.DB) error {
switch credentials.ForgeType {
@ -196,6 +197,9 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
if param.PoolBalancerType != "" {
repo.PoolBalancerType = param.PoolBalancerType
}
if param.AgentMode != nil {
repo.AgentMode = *param.AgentMode
}
q := tx.Save(&repo)
if q.Error != nil {

View file

@ -126,6 +126,7 @@ func (s *RepoTestSuite) SetupTest() {
s.testCreds,
fmt.Sprintf("test-webhook-secret-%d", i),
params.PoolBalancerTypeRoundRobin,
false,
)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create database object (test-repo-%d): %v", i, err))
@ -211,6 +212,7 @@ func (s *RepoTestSuite) TestCreateRepository() {
s.testCreds,
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin,
false,
)
// assertions
@ -243,6 +245,7 @@ func (s *RepoTestSuite) TestCreateRepositoryGitea() {
s.testCredsGitea,
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin,
false,
)
// assertions
@ -281,6 +284,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidForgeType() {
},
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
@ -307,6 +311,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBPassphrase() {
s.testCreds,
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
@ -327,6 +332,7 @@ func (s *RepoTestSuite) TestCreateRepositoryInvalidDBCreateErr() {
s.testCreds,
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NotNil(err)
@ -390,6 +396,7 @@ func (s *RepoTestSuite) TestListRepositoriesWithFilters() {
s.testCreds,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -400,6 +407,7 @@ func (s *RepoTestSuite) TestListRepositoriesWithFilters() {
s.testCredsGitea,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -410,6 +418,7 @@ func (s *RepoTestSuite) TestListRepositoriesWithFilters() {
s.testCreds,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -420,6 +429,7 @@ func (s *RepoTestSuite) TestListRepositoriesWithFilters() {
s.testCreds,
"super secret",
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().NoError(err)
@ -986,7 +996,9 @@ func (s *RepoTestSuite) TestAddRepoEntityEvent() {
s.Fixtures.CreateRepoParams.Name,
s.testCreds,
s.Fixtures.CreateRepoParams.WebhookSecret,
params.PoolBalancerTypeRoundRobin)
params.PoolBalancerTypeRoundRobin,
false,
)
s.Require().Nil(err)
entity, err := repo.GetEntity()

View file

@ -84,6 +84,7 @@ func (s *sqlDatabase) CreateEntityScaleSet(ctx context.Context, entity params.Fo
GitHubRunnerGroup: param.GitHubRunnerGroup,
State: params.ScaleSetPendingCreate,
TemplateID: param.TemplateID,
EnableShell: param.EnableShell,
}
if len(param.ExtraSpecs) > 0 {
@ -303,6 +304,10 @@ func (s *sqlDatabase) updateScaleSet(tx *gorm.DB, scaleSet ScaleSet, param param
scaleSet.TemplateID = param.TemplateID
}
if param.EnableShell != nil {
scaleSet.EnableShell = *param.EnableShell
}
if param.Name != "" {
scaleSet.Name = param.Name
}

View file

@ -62,17 +62,17 @@ func (s *ScaleSetsTestSuite) SetupTest() {
s.creds = garmTesting.CreateTestGithubCredentials(adminCtx, "new-creds", db, s.T(), githubEndpoint)
// create an organization for testing purposes
s.org, err = s.Store.CreateOrganization(s.adminCtx, "test-org", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
s.org, err = s.Store.CreateOrganization(s.adminCtx, "test-org", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin, false)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create org: %s", err))
}
s.repo, err = s.Store.CreateRepository(s.adminCtx, "test-org", "test-repo", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
s.repo, err = s.Store.CreateRepository(s.adminCtx, "test-org", "test-repo", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin, false)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create repo: %s", err))
}
s.enterprise, err = s.Store.CreateEnterprise(s.adminCtx, "test-enterprise", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin)
s.enterprise, err = s.Store.CreateEnterprise(s.adminCtx, "test-enterprise", s.creds, "test-webhookSecret", params.PoolBalancerTypeRoundRobin, false)
if err != nil {
s.FailNow(fmt.Sprintf("failed to create enterprise: %s", err))
}
@ -131,8 +131,8 @@ func (s *ScaleSetsTestSuite) callback(old, newSet params.ScaleSet) error {
s.Require().Equal(newSet.Flavor, "new-test-flavor")
s.Require().Equal(old.GitHubRunnerGroup, "test-group")
s.Require().Equal(newSet.GitHubRunnerGroup, "new-test-group")
s.Require().Equal(old.RunnerPrefix.Prefix, "garm")
s.Require().Equal(newSet.RunnerPrefix.Prefix, "test-prefix2")
s.Require().Equal(old.Prefix, "garm")
s.Require().Equal(newSet.Prefix, "test-prefix2")
s.Require().Equal(old.Enabled, false)
s.Require().Equal(newSet.Enabled, true)
return nil

View file

@ -21,6 +21,7 @@ import (
"fmt"
"log/slog"
"net/url"
"regexp"
"strings"
"gorm.io/driver/mysql"
@ -470,6 +471,11 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
return fmt.Errorf("failed to get linux template for gitea: %w", err)
}
giteaWindowsData, err := templates.GetTemplateContent(commonParams.Windows, params.GiteaEndpointType)
if err != nil {
return fmt.Errorf("failed to get windows template for gitea: %w", err)
}
adminCtx := auth.GetAdminContext(s.ctx)
githubWindowsParams := params.CreateTemplateParams{
@ -478,8 +484,9 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
OSType: commonParams.Windows,
ForgeType: params.GithubEndpointType,
Data: githubWindowsData,
IsSystem: true,
}
githubWindowsSystemTemplate, err := s.createSystemTemplate(adminCtx, githubWindowsParams)
githubWindowsSystemTemplate, err := s.CreateTemplate(adminCtx, githubWindowsParams)
if err != nil {
return fmt.Errorf("failed to create github windows template: %w", err)
}
@ -490,8 +497,9 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
OSType: commonParams.Linux,
ForgeType: params.GithubEndpointType,
Data: githubLinuxData,
IsSystem: true,
}
githubLinuxSystemTemplate, err := s.createSystemTemplate(adminCtx, githubLinuxParams)
githubLinuxSystemTemplate, err := s.CreateTemplate(adminCtx, githubLinuxParams)
if err != nil {
return fmt.Errorf("failed to create github linux template: %w", err)
}
@ -502,12 +510,26 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
OSType: commonParams.Linux,
ForgeType: params.GiteaEndpointType,
Data: giteaLinuxData,
IsSystem: true,
}
giteaLinuxSystemTemplate, err := s.createSystemTemplate(adminCtx, giteaLinuxParams)
giteaLinuxSystemTemplate, err := s.CreateTemplate(adminCtx, giteaLinuxParams)
if err != nil {
return fmt.Errorf("failed to create gitea linux template: %w", err)
}
giteaWindowsParams := params.CreateTemplateParams{
Name: "gitea_windows",
Description: "Default Windows runner install template for Gitea",
OSType: commonParams.Windows,
ForgeType: params.GiteaEndpointType,
Data: giteaWindowsData,
IsSystem: true,
}
giteaWindowsSystemTemplate, err := s.CreateTemplate(adminCtx, giteaWindowsParams)
if err != nil {
return fmt.Errorf("failed to create gitea windows template: %w", err)
}
getTplID := func(forgeType params.EndpointType, osType commonParams.OSType) uint {
var templateID uint
switch forgeType {
@ -515,6 +537,8 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
switch osType {
case commonParams.Linux:
templateID = giteaLinuxSystemTemplate.ID
case commonParams.Windows:
templateID = giteaWindowsSystemTemplate.ID
default:
return 0
}
@ -582,64 +606,142 @@ func (s *sqlDatabase) ensureTemplates(migrateTemplates bool) error {
return nil
}
// dropIndexIfExists drops an index if it exists
func (s *sqlDatabase) dropIndexIfExists(model interface{}, indexName string) {
if s.conn.Migrator().HasIndex(model, indexName) {
if err := s.conn.Migrator().DropIndex(model, indexName); err != nil {
slog.With(slog.Any("error", err)).
Error(fmt.Sprintf("failed to drop index %s", indexName))
}
}
}
// migratePoolNullIDs updates pools to set null IDs instead of zero UUIDs
func (s *sqlDatabase) migratePoolNullIDs() error {
if !s.conn.Migrator().HasTable(&Pool{}) {
return nil
}
zeroUUID := "00000000-0000-0000-0000-000000000000"
updates := []struct {
column string
query string
}{
{"repo_id", fmt.Sprintf("update pools set repo_id=NULL where repo_id='%s'", zeroUUID)},
{"org_id", fmt.Sprintf("update pools set org_id=NULL where org_id='%s'", zeroUUID)},
{"enterprise_id", fmt.Sprintf("update pools set enterprise_id=NULL where enterprise_id='%s'", zeroUUID)},
}
for _, update := range updates {
if err := s.conn.Exec(update.query).Error; err != nil {
return fmt.Errorf("error updating pools %s: %w", update.column, err)
}
}
return nil
}
// migrateGithubEndpointType adds and initializes endpoint_type column
func (s *sqlDatabase) migrateGithubEndpointType() error {
if !s.conn.Migrator().HasTable(&GithubEndpoint{}) {
return nil
}
if s.conn.Migrator().HasColumn(&GithubEndpoint{}, "endpoint_type") {
return nil
}
if err := s.conn.Migrator().AutoMigrate(&GithubEndpoint{}); err != nil {
return fmt.Errorf("error migrating github endpoints: %w", err)
}
if err := s.conn.Exec("update github_endpoints set endpoint_type = 'github' where endpoint_type is null").Error; err != nil {
return fmt.Errorf("error updating github endpoints: %w", err)
}
return nil
}
// migrateControllerInfo updates controller info with new fields
func (s *sqlDatabase) migrateControllerInfo(hasMinAgeField, hasAgentURL bool) error {
if hasMinAgeField && hasAgentURL {
return nil
}
var controller ControllerInfo
if err := s.conn.First(&controller).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("error fetching controller info: %w", err)
}
if !hasMinAgeField {
controller.MinimumJobAgeBackoff = 30
}
if controller.GARMAgentReleasesURL == "" {
controller.GARMAgentReleasesURL = appdefaults.GARMAgentDefaultReleasesURL
}
if !hasAgentURL && controller.WebhookBaseURL != "" {
matchWebhooksPath := regexp.MustCompile(`/webhooks(/)?$`)
controller.AgentURL = matchWebhooksPath.ReplaceAllLiteralString(controller.WebhookBaseURL, `/agent`)
}
if err := s.conn.Save(&controller).Error; err != nil {
return fmt.Errorf("error updating controller info: %w", err)
}
return nil
}
// preMigrationChecks performs checks before running migrations
func (s *sqlDatabase) preMigrationChecks() (needsCredentialMigration, migrateTemplates, hasMinAgeField, hasAgentURL bool) {
// Check if credentials need migration
needsCredentialMigration = !s.conn.Migrator().HasTable(&GithubCredentials{}) ||
!s.conn.Migrator().HasTable(&GithubEndpoint{})
// Check if templates need migration
migrateTemplates = !s.conn.Migrator().HasTable(&Template{})
// Check for controller info fields
if s.conn.Migrator().HasTable(&ControllerInfo{}) {
hasMinAgeField = s.conn.Migrator().HasColumn(&ControllerInfo{}, "minimum_job_age_backoff")
hasAgentURL = s.conn.Migrator().HasColumn(&ControllerInfo{}, "agent_url")
}
return
}
func (s *sqlDatabase) migrateDB() error {
if s.conn.Migrator().HasIndex(&Organization{}, "idx_organizations_name") {
if err := s.conn.Migrator().DropIndex(&Organization{}, "idx_organizations_name"); err != nil {
slog.With(slog.Any("error", err)).Error("failed to drop index idx_organizations_name")
}
}
if s.conn.Migrator().HasIndex(&Repository{}, "idx_owner") {
if err := s.conn.Migrator().DropIndex(&Repository{}, "idx_owner"); err != nil {
slog.With(slog.Any("error", err)).Error("failed to drop index idx_owner")
}
}
// Drop obsolete indexes
s.dropIndexIfExists(&Organization{}, "idx_organizations_name")
s.dropIndexIfExists(&Repository{}, "idx_owner")
// Run cascade migration
if err := s.cascadeMigration(); err != nil {
return fmt.Errorf("error running cascade migration: %w", err)
}
if s.conn.Migrator().HasTable(&Pool{}) {
if err := s.conn.Exec("update pools set repo_id=NULL where repo_id='00000000-0000-0000-0000-000000000000'").Error; err != nil {
return fmt.Errorf("error updating pools %w", err)
}
if err := s.conn.Exec("update pools set org_id=NULL where org_id='00000000-0000-0000-0000-000000000000'").Error; err != nil {
return fmt.Errorf("error updating pools: %w", err)
}
if err := s.conn.Exec("update pools set enterprise_id=NULL where enterprise_id='00000000-0000-0000-0000-000000000000'").Error; err != nil {
return fmt.Errorf("error updating pools: %w", err)
}
// Migrate pool null IDs
if err := s.migratePoolNullIDs(); err != nil {
return err
}
// Migrate workflows
if err := s.migrateWorkflow(); err != nil {
return fmt.Errorf("error migrating workflows: %w", err)
}
if s.conn.Migrator().HasTable(&GithubEndpoint{}) {
if !s.conn.Migrator().HasColumn(&GithubEndpoint{}, "endpoint_type") {
if err := s.conn.Migrator().AutoMigrate(&GithubEndpoint{}); err != nil {
return fmt.Errorf("error migrating github endpoints: %w", err)
}
if err := s.conn.Exec("update github_endpoints set endpoint_type = 'github' where endpoint_type is null").Error; err != nil {
return fmt.Errorf("error updating github endpoints: %w", err)
}
}
// Migrate GitHub endpoint type
if err := s.migrateGithubEndpointType(); err != nil {
return err
}
var needsCredentialMigration bool
if !s.conn.Migrator().HasTable(&GithubCredentials{}) || !s.conn.Migrator().HasTable(&GithubEndpoint{}) {
needsCredentialMigration = true
}
var hasMinAgeField bool
if s.conn.Migrator().HasTable(&ControllerInfo{}) && s.conn.Migrator().HasColumn(&ControllerInfo{}, "minimum_job_age_backoff") {
hasMinAgeField = true
}
migrateTemplates := !s.conn.Migrator().HasTable(&Template{})
// Check if we need to migrate credentials and templates
needsCredentialMigration, migrateTemplates, hasMinAgeField, hasAgentURL := s.preMigrationChecks()
// Run main schema migration
s.conn.Exec("PRAGMA foreign_keys = OFF")
if err := s.conn.AutoMigrate(
&User{},
@ -672,30 +774,24 @@ func (s *sqlDatabase) migrateDB() error {
s.conn.Exec("PRAGMA foreign_keys = ON")
if !hasMinAgeField {
var controller ControllerInfo
if err := s.conn.First(&controller).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("error updating controller info: %w", err)
}
} else {
controller.MinimumJobAgeBackoff = 30
if err := s.conn.Save(&controller).Error; err != nil {
return fmt.Errorf("error updating controller info: %w", err)
}
}
// Migrate controller info if needed
if err := s.migrateControllerInfo(hasMinAgeField, hasAgentURL); err != nil {
return err
}
// Ensure github endpoint exists
if err := s.ensureGithubEndpoint(); err != nil {
return fmt.Errorf("error ensuring github endpoint: %w", err)
}
// Migrate credentials if needed
if needsCredentialMigration {
if err := s.migrateCredentialsToDB(); err != nil {
return fmt.Errorf("error migrating credentials: %w", err)
}
}
// Ensure templates exist
if err := s.ensureTemplates(migrateTemplates); err != nil {
return fmt.Errorf("failed to create default templates: %w", err)
}

View file

@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"github.com/google/uuid"
"gorm.io/gorm"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
@ -142,70 +143,40 @@ func (s *sqlDatabase) GetTemplateByName(ctx context.Context, name string) (param
return ret, nil
}
func (s *sqlDatabase) createSystemTemplate(ctx context.Context, param params.CreateTemplateParams) (template params.Template, err error) {
if !auth.IsAdmin(ctx) {
func (s *sqlDatabase) CreateTemplate(ctx context.Context, param params.CreateTemplateParams) (template params.Template, err error) {
if param.IsSystem && !auth.IsAdmin(ctx) {
return params.Template{}, runnerErrors.ErrUnauthorized
}
defer func() {
if err == nil {
s.sendNotify(common.TemplateEntityType, common.CreateOperation, template)
var userID *uuid.UUID
if !param.IsSystem {
parsedID, err := getUIDFromContext(ctx)
if err != nil {
return params.Template{}, fmt.Errorf("error creating template: %w", err)
}
}()
sealed, err := s.marshalAndSeal(param.Data)
if err != nil {
return params.Template{}, fmt.Errorf("failed to seal data: %w", err)
}
tpl := Template{
UserID: nil,
Name: param.Name,
Description: param.Description,
OSType: param.OSType,
Data: sealed,
ForgeType: param.ForgeType,
}
if err := s.conn.Create(&tpl).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return params.Template{}, runnerErrors.NewConflictError("a template name already exists with the specified name")
}
return params.Template{}, fmt.Errorf("error creating template: %w", err)
}
template, err = s.sqlToParamTemplate(tpl)
if err != nil {
return params.Template{}, fmt.Errorf("failed to convert template: %w", err)
}
return template, nil
}
func (s *sqlDatabase) CreateTemplate(ctx context.Context, param params.CreateTemplateParams) (template params.Template, err error) {
userID, err := getUIDFromContext(ctx)
if err != nil {
return params.Template{}, fmt.Errorf("error creating template: %w", err)
userID = &parsedID
}
defer func() {
if err == nil {
s.sendNotify(common.TemplateEntityType, common.CreateOperation, template)
}
}()
sealed, err := s.marshalAndSeal(param.Data)
if err != nil {
return params.Template{}, fmt.Errorf("failed to seal data: %w", err)
}
tpl := Template{
UserID: &userID,
Name: param.Name,
Description: param.Description,
OSType: param.OSType,
Data: sealed,
ForgeType: param.ForgeType,
}
if err := param.Validate(); err != nil {
return params.Template{}, fmt.Errorf("failed to validate create params: %w", err)
}
sealed, err := s.marshalAndSeal(param.Data)
if err != nil {
return params.Template{}, fmt.Errorf("failed to seal data: %w", err)
}
tpl := Template{
UserID: userID,
Name: param.Name,
Description: param.Description,
OSType: param.OSType,
Data: sealed,
ForgeType: param.ForgeType,
}
if err := s.conn.Create(&tpl).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return params.Template{}, runnerErrors.NewConflictError("a template name already exists with the specified name")

View file

@ -192,7 +192,7 @@ func (s *TemplatesTestSuite) TestListTemplatesWithForgeTypeFilter() {
}
func (s *TemplatesTestSuite) TestListTemplatesWithNameFilter() {
partialName := "system"
partialName := params.SystemUser
templates, err := s.Store.ListTemplates(s.adminCtx, nil, nil, &partialName)
s.Require().Nil(err)
s.Require().Len(templates, 1)
@ -201,7 +201,7 @@ func (s *TemplatesTestSuite) TestListTemplatesWithNameFilter() {
func (s *TemplatesTestSuite) TestListTemplatesDBFetchErr() {
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT `templates`.`id`,`templates`.`created_at`,`templates`.`updated_at`,`templates`.`deleted_at`,`templates`.`name`,`templates`.`user_id`,`templates`.`description`,`templates`.`os_type`,`templates`.`forge_type` FROM `templates` WHERE `templates`.`deleted_at` IS NULL")).
ExpectQuery(regexp.QuoteMeta("SELECT `templates`.`id`,`templates`.`created_at`,`templates`.`updated_at`,`templates`.`deleted_at`,`templates`.`name`,`templates`.`user_id`,`templates`.`description`,`templates`.`os_type`,`templates`.`forge_type`,`templates`.`agent_mode` FROM `templates` WHERE `templates`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("mocked fetching templates error"))
_, err := s.StoreSQLMocked.ListTemplates(s.adminCtx, nil, nil, nil)
@ -320,12 +320,13 @@ func (s *TemplatesTestSuite) TestCreateTemplateSystemAndUserConflict() {
// Now try to create a system template with the same name using direct access to createSystemTemplate
// This should succeed since the unique constraint is on (name, user_id) and system templates have user_id = NULL
sqlDB := s.Store.(*sqlDatabase)
_, err = sqlDB.createSystemTemplate(s.adminCtx, params.CreateTemplateParams{
_, err = sqlDB.CreateTemplate(s.adminCtx, params.CreateTemplateParams{
Name: templateName,
Description: "System template with same name",
OSType: commonParams.Windows,
ForgeType: params.GithubEndpointType,
Data: []byte(`{"provider": "azure", "image": "windows-2022"}`),
IsSystem: true,
})
// This should succeed because system templates (user_id = NULL) and user templates

View file

@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"github.com/google/uuid"
"gorm.io/datatypes"
@ -73,6 +74,16 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) (params.Instance, e
JitConfiguration: jitConfig,
GitHubRunnerGroup: instance.GitHubRunnerGroup,
AditionalLabels: labels,
Heartbeat: instance.Heartbeat,
}
if len(instance.Capabilities) > 0 {
var caps params.AgentCapabilities
if err := json.Unmarshal(instance.Capabilities, &caps); err == nil {
ret.Capabilities = caps
} else {
slog.ErrorContext(s.ctx, "failed to unmarshal capabilities", "instance_name", instance.Name, "error", err)
}
}
if instance.ScaleSetFkID != nil {
@ -150,6 +161,7 @@ func (s *sqlDatabase) sqlToCommonOrganization(org Organization, detailed bool) (
Endpoint: endpoint,
CreatedAt: org.CreatedAt,
UpdatedAt: org.UpdatedAt,
AgentMode: org.AgentMode,
}
var forgeCreds params.ForgeCredentials
@ -222,6 +234,7 @@ func (s *sqlDatabase) sqlToCommonEnterprise(enterprise Enterprise, detailed bool
CreatedAt: enterprise.CreatedAt,
UpdatedAt: enterprise.UpdatedAt,
Endpoint: endpoint,
AgentMode: enterprise.AgentMode,
}
if enterprise.CredentialsID != nil {
@ -285,6 +298,7 @@ func (s *sqlDatabase) sqlToCommonPool(pool Pool) (params.Pool, error) {
Priority: pool.Priority,
CreatedAt: pool.CreatedAt,
UpdatedAt: pool.UpdatedAt,
EnableShell: pool.EnableShell,
}
if pool.TemplateID != nil && *pool.TemplateID != 0 {
@ -361,6 +375,7 @@ func (s *sqlDatabase) sqlToCommonScaleSet(scaleSet ScaleSet) (params.ScaleSet, e
ExtendedState: scaleSet.ExtendedState,
LastMessageID: scaleSet.LastMessageID,
DesiredRunnerCount: scaleSet.DesiredRunnerCount,
EnableShell: scaleSet.EnableShell,
}
if scaleSet.TemplateID != nil && *scaleSet.TemplateID != 0 {
@ -435,6 +450,7 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (par
CreatedAt: repo.CreatedAt,
UpdatedAt: repo.UpdatedAt,
Endpoint: endpoint,
AgentMode: repo.AgentMode,
}
if repo.CredentialsID != nil && repo.GiteaCredentialsID != nil {
@ -531,6 +547,10 @@ func (s *sqlDatabase) updatePool(tx *gorm.DB, pool Pool, param params.UpdatePool
pool.Flavor = param.Flavor
}
if param.EnableShell != nil {
pool.EnableShell = *param.EnableShell
}
if param.Image != "" {
pool.Image = param.Image
}
@ -737,23 +757,35 @@ func (s *sqlDatabase) addRepositoryEvent(ctx context.Context, repoID string, eve
Message: statusMessage,
EventType: event,
EventLevel: eventLevel,
RepoID: repo.ID,
}
if err := s.conn.Model(&repo).Association("Events").Append(&msg); err != nil {
// Use Create instead of Association.Append to avoid loading all existing events
if err := s.conn.Create(&msg).Error; err != nil {
return fmt.Errorf("error adding status message: %w", err)
}
if maxEvents > 0 {
var latestEvents []RepositoryEvent
q := s.conn.Model(&RepositoryEvent{}).
Limit(maxEvents).Order("id desc").
Where("repo_id = ?", repo.ID).Find(&latestEvents)
if q.Error != nil {
return fmt.Errorf("error fetching latest events: %w", q.Error)
var count int64
if err := s.conn.Model(&RepositoryEvent{}).Where("repo_id = ?", repo.ID).Count(&count).Error; err != nil {
return fmt.Errorf("error counting events: %w", err)
}
if len(latestEvents) == maxEvents {
lastInList := latestEvents[len(latestEvents)-1]
if err := s.conn.Where("repo_id = ? and id < ?", repo.ID, lastInList.ID).Unscoped().Delete(&RepositoryEvent{}).Error; err != nil {
if count > int64(maxEvents) {
// Get the ID of the Nth most recent event
var cutoffEvent RepositoryEvent
if err := s.conn.Model(&RepositoryEvent{}).
Select("id").
Where("repo_id = ?", repo.ID).
Order("id desc").
Offset(maxEvents - 1).
Limit(1).
First(&cutoffEvent).Error; err != nil {
return fmt.Errorf("error finding cutoff event: %w", err)
}
// Delete all events older than the cutoff
if err := s.conn.Where("repo_id = ? and id < ?", repo.ID, cutoffEvent.ID).Unscoped().Delete(&RepositoryEvent{}).Error; err != nil {
return fmt.Errorf("error deleting old events: %w", err)
}
}
@ -771,23 +803,35 @@ func (s *sqlDatabase) addOrgEvent(ctx context.Context, orgID string, event param
Message: statusMessage,
EventType: event,
EventLevel: eventLevel,
OrgID: org.ID,
}
if err := s.conn.Model(&org).Association("Events").Append(&msg); err != nil {
// Use Create instead of Association.Append to avoid loading all existing events
if err := s.conn.Create(&msg).Error; err != nil {
return fmt.Errorf("error adding status message: %w", err)
}
if maxEvents > 0 {
var latestEvents []OrganizationEvent
q := s.conn.Model(&OrganizationEvent{}).
Limit(maxEvents).Order("id desc").
Where("org_id = ?", org.ID).Find(&latestEvents)
if q.Error != nil {
return fmt.Errorf("error fetching latest events: %w", q.Error)
var count int64
if err := s.conn.Model(&OrganizationEvent{}).Where("org_id = ?", org.ID).Count(&count).Error; err != nil {
return fmt.Errorf("error counting events: %w", err)
}
if len(latestEvents) == maxEvents {
lastInList := latestEvents[len(latestEvents)-1]
if err := s.conn.Where("org_id = ? and id < ?", org.ID, lastInList.ID).Unscoped().Delete(&OrganizationEvent{}).Error; err != nil {
if count > int64(maxEvents) {
// Get the ID of the Nth most recent event
var cutoffEvent OrganizationEvent
if err := s.conn.Model(&OrganizationEvent{}).
Select("id").
Where("org_id = ?", org.ID).
Order("id desc").
Offset(maxEvents - 1).
Limit(1).
First(&cutoffEvent).Error; err != nil {
return fmt.Errorf("error finding cutoff event: %w", err)
}
// Delete all events older than the cutoff
if err := s.conn.Where("org_id = ? and id < ?", org.ID, cutoffEvent.ID).Unscoped().Delete(&OrganizationEvent{}).Error; err != nil {
return fmt.Errorf("error deleting old events: %w", err)
}
}
@ -802,26 +846,38 @@ func (s *sqlDatabase) addEnterpriseEvent(ctx context.Context, entID string, even
}
msg := EnterpriseEvent{
Message: statusMessage,
EventType: event,
EventLevel: eventLevel,
Message: statusMessage,
EventType: event,
EventLevel: eventLevel,
EnterpriseID: ent.ID,
}
if err := s.conn.Model(&ent).Association("Events").Append(&msg); err != nil {
// Use Create instead of Association.Append to avoid loading all existing events
if err := s.conn.Create(&msg).Error; err != nil {
return fmt.Errorf("error adding status message: %w", err)
}
if maxEvents > 0 {
var latestEvents []EnterpriseEvent
q := s.conn.Model(&EnterpriseEvent{}).
Limit(maxEvents).Order("id desc").
Where("enterprise_id = ?", ent.ID).Find(&latestEvents)
if q.Error != nil {
return fmt.Errorf("error fetching latest events: %w", q.Error)
var count int64
if err := s.conn.Model(&EnterpriseEvent{}).Where("enterprise_id = ?", ent.ID).Count(&count).Error; err != nil {
return fmt.Errorf("error counting events: %w", err)
}
if len(latestEvents) == maxEvents {
lastInList := latestEvents[len(latestEvents)-1]
if err := s.conn.Where("enterprise_id = ? and id < ?", ent.ID, lastInList.ID).Unscoped().Delete(&EnterpriseEvent{}).Error; err != nil {
if count > int64(maxEvents) {
// Get the ID of the Nth most recent event
var cutoffEvent EnterpriseEvent
if err := s.conn.Model(&EnterpriseEvent{}).
Select("id").
Where("enterprise_id = ?", ent.ID).
Order("id desc").
Offset(maxEvents - 1).
Limit(1).
First(&cutoffEvent).Error; err != nil {
return fmt.Errorf("error finding cutoff event: %w", err)
}
// Delete all events older than the cutoff
if err := s.conn.Where("enterprise_id = ? and id < ?", ent.ID, cutoffEvent.ID).Unscoped().Delete(&EnterpriseEvent{}).Error; err != nil {
return fmt.Errorf("error deleting old events: %w", err)
}
}
@ -996,7 +1052,7 @@ func (s *sqlDatabase) sqlToParamTemplate(template Template) (params.Template, er
}
}
owner := "system"
owner := params.SystemUser
if template.UserID != nil {
owner = template.User.Username
}