Extra specs is an opaque valid JSON that can be set on a pool and which will be passed along to the provider as part of instance bootstrap params. This field is meant to allow operators to send extra configuration values to external or built-in providers. The extra specs is not interpreted or useful in any way to garm itself, but it may be useful to the provider which interacts with the IaaS. The extra specs are not meant to be used for secrets. Adding sensitive information to this field is highly discouraged. This field is meant as a means to add fine tuning knobs to the providers, on a per pool basis. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
492 lines
13 KiB
Go
492 lines
13 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"math"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/callbacks"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/migrator"
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
type Config struct {
|
|
DriverName string
|
|
ServerVersion string
|
|
DSN string
|
|
DSNConfig *mysql.Config
|
|
Conn gorm.ConnPool
|
|
SkipInitializeWithVersion bool
|
|
DefaultStringSize uint
|
|
DefaultDatetimePrecision *int
|
|
DisableWithReturning bool
|
|
DisableDatetimePrecision bool
|
|
DontSupportRenameIndex bool
|
|
DontSupportRenameColumn bool
|
|
DontSupportForShareClause bool
|
|
DontSupportNullAsDefaultValue bool
|
|
}
|
|
|
|
type Dialector struct {
|
|
*Config
|
|
}
|
|
|
|
var (
|
|
// CreateClauses create clauses
|
|
CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
|
// QueryClauses query clauses
|
|
QueryClauses = []string{}
|
|
// UpdateClauses update clauses
|
|
UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
|
|
// DeleteClauses delete clauses
|
|
DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
|
|
|
|
defaultDatetimePrecision = 3
|
|
)
|
|
|
|
func Open(dsn string) gorm.Dialector {
|
|
dsnConf, _ := mysql.ParseDSN(dsn)
|
|
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
|
|
}
|
|
|
|
func New(config Config) gorm.Dialector {
|
|
return &Dialector{Config: &config}
|
|
}
|
|
|
|
func (dialector Dialector) Name() string {
|
|
return "mysql"
|
|
}
|
|
|
|
// NowFunc return now func
|
|
func (dialector Dialector) NowFunc(n int) func() time.Time {
|
|
return func() time.Time {
|
|
round := time.Second / time.Duration(math.Pow10(n))
|
|
return time.Now().Round(round)
|
|
}
|
|
}
|
|
|
|
func (dialector Dialector) Apply(config *gorm.Config) error {
|
|
if config.NowFunc != nil {
|
|
return nil
|
|
}
|
|
|
|
if dialector.DefaultDatetimePrecision == nil {
|
|
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
|
}
|
|
// while maintaining the readability of the code, separate the business logic from
|
|
// the general part and leave it to the function to do it here.
|
|
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
|
|
return nil
|
|
}
|
|
|
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|
if dialector.DriverName == "" {
|
|
dialector.DriverName = "mysql"
|
|
}
|
|
|
|
if dialector.DefaultDatetimePrecision == nil {
|
|
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
|
}
|
|
|
|
if dialector.Conn != nil {
|
|
db.ConnPool = dialector.Conn
|
|
} else {
|
|
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
withReturning := false
|
|
if !dialector.Config.SkipInitializeWithVersion {
|
|
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if strings.Contains(dialector.ServerVersion, "MariaDB") {
|
|
dialector.Config.DontSupportRenameIndex = true
|
|
dialector.Config.DontSupportRenameColumn = true
|
|
dialector.Config.DontSupportForShareClause = true
|
|
dialector.Config.DontSupportNullAsDefaultValue = true
|
|
withReturning = checkVersion(dialector.ServerVersion, "10.5")
|
|
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
|
|
dialector.Config.DontSupportRenameIndex = true
|
|
dialector.Config.DontSupportRenameColumn = true
|
|
dialector.Config.DontSupportForShareClause = true
|
|
} else if strings.HasPrefix(dialector.ServerVersion, "5.7.") {
|
|
dialector.Config.DontSupportRenameColumn = true
|
|
dialector.Config.DontSupportForShareClause = true
|
|
} else if strings.HasPrefix(dialector.ServerVersion, "5.") {
|
|
dialector.Config.DisableDatetimePrecision = true
|
|
dialector.Config.DontSupportRenameIndex = true
|
|
dialector.Config.DontSupportRenameColumn = true
|
|
dialector.Config.DontSupportForShareClause = true
|
|
}
|
|
}
|
|
|
|
// register callbacks
|
|
callbackConfig := &callbacks.Config{
|
|
CreateClauses: CreateClauses,
|
|
QueryClauses: QueryClauses,
|
|
UpdateClauses: UpdateClauses,
|
|
DeleteClauses: DeleteClauses,
|
|
}
|
|
|
|
if !dialector.Config.DisableWithReturning && withReturning {
|
|
callbackConfig.LastInsertIDReversed = true
|
|
|
|
if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
|
|
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
|
|
}
|
|
|
|
if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
|
|
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
|
|
}
|
|
|
|
if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
|
|
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
|
|
}
|
|
}
|
|
|
|
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
|
|
|
|
for k, v := range dialector.ClauseBuilders() {
|
|
db.ClauseBuilders[k] = v
|
|
}
|
|
return
|
|
}
|
|
|
|
const (
|
|
// ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
|
|
ClauseOnConflict = "ON CONFLICT"
|
|
// ClauseValues for clause.ClauseBuilder VALUES key
|
|
ClauseValues = "VALUES"
|
|
// ClauseFor for clause.ClauseBuilder FOR key
|
|
ClauseFor = "FOR"
|
|
)
|
|
|
|
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
|
clauseBuilders := map[string]clause.ClauseBuilder{
|
|
ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
|
|
onConflict, ok := c.Expression.(clause.OnConflict)
|
|
if !ok {
|
|
c.Build(builder)
|
|
return
|
|
}
|
|
|
|
builder.WriteString("ON DUPLICATE KEY UPDATE ")
|
|
if len(onConflict.DoUpdates) == 0 {
|
|
if s := builder.(*gorm.Statement).Schema; s != nil {
|
|
var column clause.Column
|
|
onConflict.DoNothing = false
|
|
|
|
if s.PrioritizedPrimaryField != nil {
|
|
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
|
|
} else if len(s.DBNames) > 0 {
|
|
column = clause.Column{Name: s.DBNames[0]}
|
|
}
|
|
|
|
if column.Name != "" {
|
|
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
|
|
}
|
|
|
|
builder.(*gorm.Statement).AddClause(onConflict)
|
|
}
|
|
}
|
|
|
|
for idx, assignment := range onConflict.DoUpdates {
|
|
if idx > 0 {
|
|
builder.WriteByte(',')
|
|
}
|
|
|
|
builder.WriteQuoted(assignment.Column)
|
|
builder.WriteByte('=')
|
|
if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
|
|
column.Table = ""
|
|
builder.WriteString("VALUES(")
|
|
builder.WriteQuoted(column)
|
|
builder.WriteByte(')')
|
|
} else {
|
|
builder.AddVar(builder, assignment.Value)
|
|
}
|
|
}
|
|
},
|
|
ClauseValues: func(c clause.Clause, builder clause.Builder) {
|
|
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
|
|
builder.WriteString("VALUES()")
|
|
return
|
|
}
|
|
c.Build(builder)
|
|
},
|
|
}
|
|
|
|
if dialector.Config.DontSupportForShareClause {
|
|
clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
|
|
if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
|
|
builder.WriteString("LOCK IN SHARE MODE")
|
|
return
|
|
}
|
|
c.Build(builder)
|
|
}
|
|
}
|
|
|
|
return clauseBuilders
|
|
}
|
|
|
|
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
|
return clause.Expr{SQL: "DEFAULT"}
|
|
}
|
|
|
|
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
|
return Migrator{
|
|
Migrator: migrator.Migrator{
|
|
Config: migrator.Config{
|
|
DB: db,
|
|
Dialector: dialector,
|
|
},
|
|
},
|
|
Dialector: dialector,
|
|
}
|
|
}
|
|
|
|
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
|
writer.WriteByte('?')
|
|
}
|
|
|
|
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
|
var (
|
|
underQuoted, selfQuoted bool
|
|
continuousBacktick int8
|
|
shiftDelimiter int8
|
|
)
|
|
|
|
for _, v := range []byte(str) {
|
|
switch v {
|
|
case '`':
|
|
continuousBacktick++
|
|
if continuousBacktick == 2 {
|
|
writer.WriteString("``")
|
|
continuousBacktick = 0
|
|
}
|
|
case '.':
|
|
if continuousBacktick > 0 || !selfQuoted {
|
|
shiftDelimiter = 0
|
|
underQuoted = false
|
|
continuousBacktick = 0
|
|
writer.WriteByte('`')
|
|
}
|
|
writer.WriteByte(v)
|
|
continue
|
|
default:
|
|
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
|
writer.WriteByte('`')
|
|
underQuoted = true
|
|
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
|
continuousBacktick -= 1
|
|
}
|
|
}
|
|
|
|
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
|
writer.WriteString("``")
|
|
}
|
|
|
|
writer.WriteByte(v)
|
|
}
|
|
shiftDelimiter++
|
|
}
|
|
|
|
if continuousBacktick > 0 && !selfQuoted {
|
|
writer.WriteString("``")
|
|
}
|
|
writer.WriteByte('`')
|
|
}
|
|
|
|
type localTimeInterface interface {
|
|
In(loc *time.Location) time.Time
|
|
}
|
|
|
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
|
if dialector.DSNConfig != nil && dialector.DSNConfig.Loc == time.Local {
|
|
for i, v := range vars {
|
|
if p, ok := v.(localTimeInterface); ok {
|
|
func(i int, t localTimeInterface) {
|
|
defer func() {
|
|
recover()
|
|
}()
|
|
vars[i] = t.In(time.Local)
|
|
}(i, p)
|
|
}
|
|
}
|
|
}
|
|
return logger.ExplainSQL(sql, nil, `'`, vars...)
|
|
}
|
|
|
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
|
switch field.DataType {
|
|
case schema.Bool:
|
|
return "boolean"
|
|
case schema.Int, schema.Uint:
|
|
return dialector.getSchemaIntAndUnitType(field)
|
|
case schema.Float:
|
|
return dialector.getSchemaFloatType(field)
|
|
case schema.String:
|
|
return dialector.getSchemaStringType(field)
|
|
case schema.Time:
|
|
return dialector.getSchemaTimeType(field)
|
|
case schema.Bytes:
|
|
return dialector.getSchemaBytesType(field)
|
|
default:
|
|
return dialector.getSchemaCustomType(field)
|
|
}
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
|
|
if field.Precision > 0 {
|
|
return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
|
|
}
|
|
|
|
if field.Size <= 32 {
|
|
return "float"
|
|
}
|
|
|
|
return "double"
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
|
|
size := field.Size
|
|
if size == 0 {
|
|
if dialector.DefaultStringSize > 0 {
|
|
size = int(dialector.DefaultStringSize)
|
|
} else {
|
|
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
|
|
// TEXT, GEOMETRY or JSON column can't have a default value
|
|
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
|
|
size = 191 // utf8mb4
|
|
}
|
|
}
|
|
}
|
|
|
|
if size >= 65536 && size <= int(math.Pow(2, 24)) {
|
|
return "mediumtext"
|
|
}
|
|
|
|
if size > int(math.Pow(2, 24)) || size <= 0 {
|
|
return "longtext"
|
|
}
|
|
|
|
return fmt.Sprintf("varchar(%d)", size)
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
|
|
if !dialector.DisableDatetimePrecision && field.Precision == 0 {
|
|
field.Precision = *dialector.DefaultDatetimePrecision
|
|
}
|
|
|
|
var precision string
|
|
if field.Precision > 0 {
|
|
precision = fmt.Sprintf("(%d)", field.Precision)
|
|
}
|
|
|
|
if field.NotNull || field.PrimaryKey {
|
|
return "datetime" + precision
|
|
}
|
|
return "datetime" + precision + " NULL"
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
|
|
if field.Size > 0 && field.Size < 65536 {
|
|
return fmt.Sprintf("varbinary(%d)", field.Size)
|
|
}
|
|
|
|
if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
|
|
return "mediumblob"
|
|
}
|
|
|
|
return "longblob"
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
|
|
constraint := func(sqlType string) string {
|
|
if field.DataType == schema.Uint {
|
|
sqlType += " unsigned"
|
|
}
|
|
if field.NotNull {
|
|
sqlType += " NOT NULL"
|
|
}
|
|
if field.AutoIncrement {
|
|
sqlType += " AUTO_INCREMENT"
|
|
}
|
|
return sqlType
|
|
}
|
|
|
|
switch {
|
|
case field.Size <= 8:
|
|
return constraint("tinyint")
|
|
case field.Size <= 16:
|
|
return constraint("smallint")
|
|
case field.Size <= 24:
|
|
return constraint("mediumint")
|
|
case field.Size <= 32:
|
|
return constraint("int")
|
|
default:
|
|
return constraint("bigint")
|
|
}
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
|
|
sqlType := string(field.DataType)
|
|
|
|
if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), " auto_increment") {
|
|
sqlType += " AUTO_INCREMENT"
|
|
}
|
|
|
|
return sqlType
|
|
}
|
|
|
|
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
|
|
return tx.Exec("SAVEPOINT " + name).Error
|
|
}
|
|
|
|
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
|
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
|
|
}
|
|
|
|
// checkVersion newer or equal returns true, old returns false
|
|
func checkVersion(newVersion, oldVersion string) bool {
|
|
if newVersion == oldVersion {
|
|
return true
|
|
}
|
|
|
|
var (
|
|
versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)
|
|
|
|
newVersions = strings.Split(newVersion, ".")
|
|
oldVersions = strings.Split(oldVersion, ".")
|
|
)
|
|
for idx, nv := range newVersions {
|
|
if len(oldVersions) <= idx {
|
|
return true
|
|
}
|
|
|
|
nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1"))
|
|
ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
|
|
if nvi == ovi {
|
|
continue
|
|
}
|
|
return nvi > ovi
|
|
}
|
|
|
|
return false
|
|
}
|