diff --git a/database/sql/models.go b/database/sql/models.go index bf790941..051b5a3b 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -50,7 +50,7 @@ type Tag struct { Base Name string `gorm:"type:varchar(64);uniqueIndex"` - Pools []*Pool `gorm:"many2many:pool_tags;constraint:OnDelete:CASCADE"` + Pools []*Pool `gorm:"many2many:pool_tags;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"` } type Pool struct { @@ -65,7 +65,7 @@ type Pool struct { Flavor string `gorm:"index:idx_pool_type"` OSType params.OSType OSArch params.OSArch - Tags []*Tag `gorm:"many2many:pool_tags;constraint:OnDelete:CASCADE"` + Tags []*Tag `gorm:"many2many:pool_tags;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"` Enabled bool // ExtraSpecs is an opaque json that gets sent to the provider // as part of the bootstrap params for instances. It can contain @@ -130,7 +130,7 @@ type InstanceStatusUpdate struct { EventLevel params.EventLevel Message string `gorm:"type:text"` - InstanceID uuid.UUID `gorm:"index:instance_id"` + InstanceID uuid.UUID `gorm:"index:idx_instance_status_updates_instance_id"` Instance Instance `gorm:"foreignKey:InstanceID"` } @@ -144,7 +144,7 @@ type Instance struct { OSArch params.OSArch OSName string OSVersion string - Addresses []Address `gorm:"foreignKey:InstanceID;constraint:OnDelete:CASCADE"` + Addresses []Address `gorm:"foreignKey:InstanceID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"` Status common.InstanceStatus RunnerStatus common.RunnerStatus CallbackURL string @@ -158,7 +158,7 @@ type Instance struct { PoolID uuid.UUID Pool Pool `gorm:"foreignKey:PoolID"` - StatusMessages []InstanceStatusUpdate `gorm:"foreignKey:InstanceID;constraint:OnDelete:CASCADE"` + StatusMessages []InstanceStatusUpdate `gorm:"foreignKey:InstanceID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE;"` } type User struct { diff --git a/database/sql/sql.go b/database/sql/sql.go index 4cdb2619..dc850e6b 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -16,7 +16,9 @@ package sql import ( "context" + "fmt" "log" + "strings" "github.com/pkg/errors" "gorm.io/driver/mysql" @@ -79,6 +81,111 @@ type sqlDatabase struct { cfg config.Database } +var renameTemplate = ` +PRAGMA foreign_keys = OFF; +BEGIN TRANSACTION; + +ALTER TABLE %s RENAME TO %s_old; +COMMIT; +` + +var restoreNameTemplate = ` +PRAGMA foreign_keys = OFF; +BEGIN TRANSACTION; +DROP TABLE IF EXISTS %s; +ALTER TABLE %s_old RENAME TO %s; +COMMIT; +` + +var copyContentsTemplate = ` +PRAGMA foreign_keys = OFF; +BEGIN TRANSACTION; +INSERT INTO %s SELECT * FROM %s_old; +DROP TABLE %s_old; + +COMMIT; +` + +func (s *sqlDatabase) cascadeMigrationSQLite(model interface{}, name string, justDrop bool) error { + if !s.conn.Migrator().HasTable(name) { + return nil + } + defer s.conn.Exec("PRAGMA foreign_keys = ON;") + + var data string + var indexes []string + if err := s.conn.Raw(fmt.Sprintf("select sql from sqlite_master where tbl_name='%s' and name='%s'", name, name)).Scan(&data).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to get table %s: %w", name, err) + } + } + + if err := s.conn.Raw(fmt.Sprintf("SELECT name FROM sqlite_master WHERE type == 'index' AND tbl_name == '%s' and name not like 'sqlite_%%'", name)).Scan(&indexes).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to get table indexes %s: %w", name, err) + } + } + + if strings.Contains(data, "ON DELETE CASCADE") { + return nil + } + + if justDrop { + if err := s.conn.Migrator().DropTable(model); err != nil { + return fmt.Errorf("failed to drop table %s: %w", name, err) + } + return nil + } + + for _, index := range indexes { + if err := s.conn.Migrator().DropIndex(model, index); err != nil { + return fmt.Errorf("failed to drop index %s: %w", index, err) + } + } + + err := s.conn.Exec(fmt.Sprintf(renameTemplate, name, name)).Error + if err != nil { + return fmt.Errorf("failed to rename table %s: %w", name, err) + } + + if model != nil { + if err := s.conn.Migrator().AutoMigrate(model); err != nil { + if err := s.conn.Exec(fmt.Sprintf(restoreNameTemplate, name, name, name)).Error; err != nil { + log.Printf("failed to restore table %s: %s", name, err) + } + return fmt.Errorf("failed to create table %s: %w", name, err) + } + } + err = s.conn.Exec(fmt.Sprintf(copyContentsTemplate, name, name, name)).Error + if err != nil { + return fmt.Errorf("failed to copy contents to table %s: %w", name, err) + } + + return nil +} + +func (s *sqlDatabase) cascadeMigration() error { + switch s.cfg.DbBackend { + case config.SQLiteBackend: + if err := s.cascadeMigrationSQLite(&Address{}, "addresses", true); err != nil { + return fmt.Errorf("failed to drop table addresses: %w", err) + } + + if err := s.cascadeMigrationSQLite(&InstanceStatusUpdate{}, "instance_status_updates", true); err != nil { + return fmt.Errorf("failed to drop table instance_status_updates: %w", err) + } + + if err := s.cascadeMigrationSQLite(&Tag{}, "pool_tags", false); err != nil { + return fmt.Errorf("failed to migrate addresses: %w", err) + } + case config.MySQLBackend: + return nil + default: + return fmt.Errorf("invalid db backend: %s", s.cfg.DbBackend) + } + return nil +} + func (s *sqlDatabase) migrateDB() error { if s.conn.Migrator().HasIndex(&Organization{}, "idx_organizations_name") { if err := s.conn.Migrator().DropIndex(&Organization{}, "idx_organizations_name"); err != nil { @@ -91,6 +198,11 @@ func (s *sqlDatabase) migrateDB() error { log.Printf("failed to drop index idx_owner: %s", err) } } + + if err := s.cascadeMigration(); err != nil { + return errors.Wrap(err, "running cascade migration") + } + if err := s.conn.AutoMigrate( &Tag{}, &Pool{},