diff --git a/go.mod b/go.mod index 097353a1..94071dd7 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( gorm.io/datatypes v1.2.5 gorm.io/driver/mysql v1.5.7 gorm.io/driver/sqlite v1.5.7 - gorm.io/gorm v1.26.1 + gorm.io/gorm v1.30.0 ) require ( diff --git a/go.sum b/go.sum index 28e4606e..467ebbcf 100644 --- a/go.sum +++ b/go.sum @@ -231,5 +231,5 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g= gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw= -gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= +gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/vendor/gorm.io/gorm/callbacks/create.go b/vendor/gorm.io/gorm/callbacks/create.go index 8b7846b6..d8701f51 100644 --- a/vendor/gorm.io/gorm/callbacks/create.go +++ b/vendor/gorm.io/gorm/callbacks/create.go @@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } + if db.RowsAffected == 0 { return } diff --git a/vendor/gorm.io/gorm/callbacks/delete.go b/vendor/gorm.io/gorm/callbacks/delete.go index 84f446a3..07ed6fee 100644 --- a/vendor/gorm.io/gorm/callbacks/delete.go +++ b/vendor/gorm.io/gorm/callbacks/delete.go @@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } return @@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } db.AddError(rows.Close()) } } diff --git a/vendor/gorm.io/gorm/callbacks/preload.go b/vendor/gorm.io/gorm/callbacks/preload.go index fd8214bb..225cda28 100644 --- a/vendor/gorm.io/gorm/callbacks/preload.go +++ b/vendor/gorm.io/gorm/callbacks/preload.go @@ -103,11 +103,11 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati joined = true continue } - joinNames := strings.SplitN(join, ".", 2) - if len(joinNames) == 2 { - if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + join0, join1, cut := strings.Cut(join, ".") + if cut { + if _, ok := relationships.Relations[join0]; ok && name == join0 { joined = true - nestedJoins = append(nestedJoins, joinNames[1]) + nestedJoins = append(nestedJoins, join1) } } } @@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) if len(values) != 0 { + tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) + for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { tx = fc(tx) @@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + if len(inlineConds) > 0 { + tx = tx.Where(inlineConds[0], inlineConds[1:]...) + } + + if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { return err } } diff --git a/vendor/gorm.io/gorm/callbacks/query.go b/vendor/gorm.io/gorm/callbacks/query.go index bbf238a9..548bf709 100644 --- a/vendor/gorm.io/gorm/callbacks/query.go +++ b/vendor/gorm.io/gorm/callbacks/query.go @@ -25,6 +25,10 @@ func Query(db *gorm.DB) { db.AddError(rows.Close()) }() gorm.Scan(rows, db, 0) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } @@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) { } } - specifiedRelationsName := make(map[string]interface{}) + specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} for _, join := range db.Statement.Joins { if db.Statement.Schema != nil { var isRelations bool // is relations or raw sql @@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) { nestedJoinNames := strings.Split(join.Name, ".") if len(nestedJoinNames) > 1 { isNestedJoin := true - gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) currentRelations := db.Statement.Schema.Relationships.Relations for _, relname := range nestedJoinNames { // incomplete match, only treated as raw sql if relation, ok = currentRelations[relname]; ok { - gussNestedRelations = append(gussNestedRelations, relation) + guessNestedRelations = append(guessNestedRelations, relation) currentRelations = relation.FieldSchema.Relationships.Relations } else { isNestedJoin = false @@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) { if isNestedJoin { isRelations = true - relations = gussNestedRelations + relations = guessNestedRelations } } } if isRelations { - genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) - } - + genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { columnStmt := gorm.Statement{ Table: tableAliasName, DB: db, Schema: relation.FieldSchema, Selects: join.Selects, Omits: join.Omits, @@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.Expression != nil { + return clause.Join{ + Type: join.JoinType, + Expression: join.Expression, + } + } + exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { @@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) { } parentTableName := clause.CurrentTable - for _, rel := range relations { + for idx, rel := range relations { // joins table alias like "Manager, Company, Manager__Company" - nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) - if _, ok := specifiedRelationsName[nestedAlias]; !ok { - fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) - specifiedRelationsName[nestedAlias] = nil + curAliasName := rel.Name + if parentTableName != clause.CurrentTable { + curAliasName = utils.NestedRelationName(parentTableName, curAliasName) } - if parentTableName != clause.CurrentTable { - parentTableName = utils.NestedRelationName(parentTableName, rel.Name) - } else { - parentTableName = rel.Name + if _, ok := specifiedRelationsName[curAliasName]; !ok { + aliasName := curAliasName + if idx == len(relations)-1 && join.Alias != "" { + aliasName = join.Alias + } + + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) + specifiedRelationsName[curAliasName] = aliasName } + + parentTableName = curAliasName } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ diff --git a/vendor/gorm.io/gorm/callbacks/raw.go b/vendor/gorm.io/gorm/callbacks/raw.go index 013e638c..3bb647c4 100644 --- a/vendor/gorm.io/gorm/callbacks/raw.go +++ b/vendor/gorm.io/gorm/callbacks/raw.go @@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } diff --git a/vendor/gorm.io/gorm/callbacks/update.go b/vendor/gorm.io/gorm/callbacks/update.go index 7cde7f61..8e2782e1 100644 --- a/vendor/gorm.io/gorm/callbacks/update.go +++ b/vendor/gorm.io/gorm/callbacks/update.go @@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { gorm.Scan(rows, db, mode) db.Statement.Dest = dest db.AddError(rows.Close()) + + if db.Statement.Result != nil { + db.Statement.Result.RowsAffected = db.RowsAffected + } } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } + + if db.Statement.Result != nil { + db.Statement.Result.Result = result + db.Statement.Result.RowsAffected = db.RowsAffected + } } } } diff --git a/vendor/gorm.io/gorm/chainable_api.go b/vendor/gorm.io/gorm/chainable_api.go index 8953413d..8a6aea34 100644 --- a/vendor/gorm.io/gorm/chainable_api.go +++ b/vendor/gorm.io/gorm/chainable_api.go @@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: -// var users []User -// db.Unscoped().Find(&users) -// // Retrieves all users, including deleted ones. +// +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/vendor/gorm.io/gorm/clause/joins.go b/vendor/gorm.io/gorm/clause/joins.go index 879892be..a6f13e55 100644 --- a/vendor/gorm.io/gorm/clause/joins.go +++ b/vendor/gorm.io/gorm/clause/joins.go @@ -1,5 +1,7 @@ package clause +import "gorm.io/gorm/utils" + type JoinType string const ( @@ -9,6 +11,30 @@ const ( RightJoin JoinType = "RIGHT" ) +type JoinTarget struct { + Type JoinType + Association string + Subquery Expression + Table string +} + +func Has(name string) JoinTarget { + return JoinTarget{Type: InnerJoin, Association: name} +} + +func (jt JoinType) Association(name string) JoinTarget { + return JoinTarget{Type: jt, Association: name} +} + +func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { + return JoinTarget{Type: jt, Association: name, Subquery: subquery} +} + +func (jt JoinTarget) As(name string) JoinTarget { + jt.Table = name + return jt +} + // Join clause for from type Join struct { Type JoinType @@ -18,6 +44,12 @@ type Join struct { Expression Expression } +func JoinTable(names ...string) Table { + return Table{ + Name: utils.JoinNestedRelationNames(names), + } +} + func (join Join) Build(builder Builder) { if join.Expression != nil { join.Expression.Build(builder) diff --git a/vendor/gorm.io/gorm/finisher_api.go b/vendor/gorm.io/gorm/finisher_api.go index 6802945c..57809d17 100644 --- a/vendor/gorm.io/gorm/finisher_api.go +++ b/vendor/gorm.io/gorm/finisher_api.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } + ctx := tx.Statement.Context + if _, ok := ctx.Deadline(); !ok { + if db.Config.DefaultTransactionTimeout > 0 { + ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + } + } + switch beginner := tx.Statement.ConnPool.(type) { case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) default: err = ErrInvalidTransaction } diff --git a/vendor/gorm.io/gorm/generics.go b/vendor/gorm.io/gorm/generics.go new file mode 100644 index 00000000..ad2d063f --- /dev/null +++ b/vendor/gorm.io/gorm/generics.go @@ -0,0 +1,605 @@ +package gorm + +import ( + "context" + "database/sql" + "fmt" + "sort" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +type result struct { + Result sql.Result + RowsAffected int64 +} + +func (info *result) ModifyStatement(stmt *Statement) { + stmt.Result = info +} + +// Build implements clause.Expression interface +func (result) Build(clause.Builder) { +} + +func WithResult() *result { + return &result{} +} + +type Interface[T any] interface { + Raw(sql string, values ...interface{}) ExecInterface[T] + Exec(ctx context.Context, sql string, values ...interface{}) error + CreateInterface[T] +} + +type CreateInterface[T any] interface { + ChainInterface[T] + Table(name string, args ...interface{}) CreateInterface[T] + Create(ctx context.Context, r *T) error + CreateInBatches(ctx context.Context, r *[]T, batchSize int) error +} + +type ChainInterface[T any] interface { + ExecInterface[T] + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) ChainInterface[T] + Omit(columns ...string) ChainInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) +} + +type ExecInterface[T any] interface { + Scan(ctx context.Context, r interface{}) error + First(context.Context) (T, error) + Last(ctx context.Context) (T, error) + Take(context.Context) (T, error) + Find(ctx context.Context) ([]T, error) + FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error + Row(ctx context.Context) *sql.Row + Rows(ctx context.Context) (*sql.Rows, error) +} + +type JoinBuilder interface { + Select(...string) JoinBuilder + Omit(...string) JoinBuilder + Where(query interface{}, args ...interface{}) JoinBuilder + Not(query interface{}, args ...interface{}) JoinBuilder + Or(query interface{}, args ...interface{}) JoinBuilder +} + +type PreloadBuilder interface { + Select(...string) PreloadBuilder + Omit(...string) PreloadBuilder + Where(query interface{}, args ...interface{}) PreloadBuilder + Not(query interface{}, args ...interface{}) PreloadBuilder + Or(query interface{}, args ...interface{}) PreloadBuilder + Limit(offset int) PreloadBuilder + Offset(offset int) PreloadBuilder + Order(value interface{}) PreloadBuilder + LimitPerRecord(num int) PreloadBuilder +} + +type op func(*DB) *DB + +func G[T any](db *DB, opts ...clause.Expression) Interface[T] { + v := &g[T]{ + db: db, + ops: make([]op, 0, 5), + } + + if len(opts) > 0 { + v.ops = append(v.ops, func(db *DB) *DB { + return db.Clauses(opts...) + }) + } + + v.createG = &createG[T]{ + chainG: chainG[T]{ + execG: execG[T]{g: v}, + }, + } + return v +} + +type g[T any] struct { + *createG[T] + db *DB + ops []op +} + +func (g *g[T]) apply(ctx context.Context) *DB { + db := g.db + if !db.DryRun { + db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + } + + for _, op := range g.ops { + db = op(db) + } + return db +} + +func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { + return execG[T]{g: &g[T]{ + db: c.db, + ops: append(c.ops, func(db *DB) *DB { + return db.Raw(sql, values...) + }), + }} +} + +func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { + return c.apply(ctx).Exec(sql, values...).Error +} + +type createG[T any] struct { + chainG[T] +} + +func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Table(name, args...) + })} +} + +func (c createG[T]) Create(ctx context.Context, r *T) error { + return c.g.apply(ctx).Create(r).Error +} + +func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { + return c.g.apply(ctx).CreateInBatches(r, batchSize).Error +} + +type chainG[T any] struct { + execG[T] +} + +func (c chainG[T]) getInstance() *DB { + var r T + return c.g.apply(context.Background()).Model(r).getInstance() +} + +func (c chainG[T]) with(v op) chainG[T] { + return chainG[T]{ + execG: execG[T]{g: &g[T]{ + db: c.g.db, + ops: append(append([]op(nil), c.g.ops...), v), + }}, + } +} + +func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { + return c.with(func(db *DB) *DB { + for _, fc := range scopes { + fc(db.Statement) + } + return db + }) +} + +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + +func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Where(query, args...) + }) +} + +func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Not(query, args...) + }) +} + +func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Or(query, args...) + }) +} + +func (c chainG[T]) Limit(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Limit(offset) + }) +} + +func (c chainG[T]) Offset(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Offset(offset) + }) +} + +type joinBuilder struct { + db *DB +} + +func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { + q.db.Where(query, args...) + return q +} + +func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { + q.db.Where(query, args...) + return q +} + +func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { + q.db.Where(query, args...) + return q +} + +func (q *joinBuilder) Select(columns ...string) JoinBuilder { + q.db.Select(columns) + return q +} + +func (q *joinBuilder) Omit(columns ...string) JoinBuilder { + q.db.Omit(columns...) + return q +} + +type preloadBuilder struct { + limitPerRecord int + db *DB +} + +func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { + q.db.Select(columns) + return q +} + +func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { + q.db.Omit(columns...) + return q +} + +func (q *preloadBuilder) Limit(limit int) PreloadBuilder { + q.db.Limit(limit) + return q +} + +func (q *preloadBuilder) Offset(offset int) PreloadBuilder { + q.db.Offset(offset) + return q +} + +func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { + q.db.Order(value) + return q +} + +func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { + q.limitPerRecord = num + return q +} + +func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { + return c.with(func(db *DB) *DB { + if jt.Table == "" { + jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name + } + + q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} + if on != nil { + if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { + db.AddError(err) + } + } + + j := join{ + Name: jt.Association, + Alias: jt.Table, + Selects: q.db.Statement.Selects, + Omits: q.db.Statement.Omits, + JoinType: jt.Type, + } + + if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + + if jt.Subquery != nil { + joinType := j.JoinType + if joinType == "" { + joinType = clause.LeftJoin + } + + if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok { + stmt := db.getInstance().Statement + if len(j.Selects) == 0 { + j.Selects = stmt.Selects + } + if len(j.Omits) == 0 { + j.Omits = stmt.Omits + } + } + + expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} + + if j.On != nil { + expr.SQL += " ON ?" + expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) + } + + j.Expression = expr + } + + db.Statement.Joins = append(db.Statement.Joins, j) + sort.Slice(db.Statement.Joins, func(i, j int) bool { + return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name + }) + return db + }) +} + +func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Select(query, args...) + }) +} + +func (c chainG[T]) Omit(columns ...string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Omit(columns...) + }) +} + +func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.MapColumns(m) + }) +} + +func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Distinct(args...) + }) +} + +func (c chainG[T]) Group(name string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Group(name) + }) +} + +func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Having(query, args...) + }) +} + +func (c chainG[T]) Order(value interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Order(value) + }) +} + +func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Preload(association, func(tx *DB) *DB { + q := preloadBuilder{db: tx.getInstance()} + if query != nil { + if err := query(&q); err != nil { + db.AddError(err) + } + } + + relation, ok := db.Statement.Schema.Relationships.Relations[association] + if !ok { + if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { + relationships := db.Statement.Schema.Relationships + for _, field := range preloadFields { + var ok bool + relation, ok = relationships.Relations[field] + if ok { + relationships = relation.FieldSchema.Relationships + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } + } + } else { + db.AddError(fmt.Errorf("relation %s not found", association)) + return nil + } + } + + if q.limitPerRecord > 0 { + if relation.JoinTable != nil { + tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)) + return tx + } + + refColumns := []clause.Column{} + for _, rel := range relation.References { + if rel.OwnPrimaryKey { + refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) + } + } + + if len(refColumns) != 0 { + selectExpr := clause.CommaExpression{} + for _, column := range q.db.Statement.Selects { + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) + } + + if len(selectExpr.Exprs) == 0 { + selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} + } + + partitionBy := clause.CommaExpression{} + for _, column := range refColumns { + partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) + } + + rnnColumn := clause.Column{Name: "gorm_preload_rnn"} + sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" + vars := []interface{}{partitionBy} + if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { + vars = append(vars, orderBy) + } else { + vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }}) + } + vars = append(vars, rnnColumn) + + selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) + + q.db.Clauses(clause.Select{Expression: selectExpr}) + + return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) + } + } + + return q.db + }) + }) +} + +func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { + r := new(T) + res := c.g.apply(ctx).Delete(r) + return int(res.RowsAffected), res.Error +} + +func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { + var r T + res := c.g.apply(ctx).Model(r).Update(name, value) + return int(res.RowsAffected), res.Error +} + +func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { + res := c.g.apply(ctx).Updates(t) + return int(res.RowsAffected), res.Error +} + +func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { + var r T + err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error + return +} + +func (c chainG[T]) Build(builder clause.Builder) { + subdb := c.getInstance() + subdb.Logger = logger.Discard + subdb.DryRun = true + + if stmt, ok := builder.(*Statement); ok { + if subdb.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = subdb.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + subdb.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + + builder.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + } +} + +type execG[T any] struct { + g *g[T] +} + +func (g execG[T]) First(ctx context.Context) (T, error) { + var r T + err := g.g.apply(ctx).First(&r).Error + return r, err +} + +func (g execG[T]) Scan(ctx context.Context, result interface{}) error { + var r T + err := g.g.apply(ctx).Model(r).Find(&result).Error + return err +} + +func (g execG[T]) Last(ctx context.Context) (T, error) { + var r T + err := g.g.apply(ctx).Last(&r).Error + return r, err +} + +func (g execG[T]) Take(ctx context.Context) (T, error) { + var r T + err := g.g.apply(ctx).Take(&r).Error + return r, err +} + +func (g execG[T]) Find(ctx context.Context) ([]T, error) { + var r []T + err := g.g.apply(ctx).Find(&r).Error + return r, err +} + +func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { + var data []T + return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { + return fc(data, batch) + }).Error +} + +func (g execG[T]) Row(ctx context.Context) *sql.Row { + return g.g.apply(ctx).Row() +} + +func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { + return g.g.apply(ctx).Rows() +} diff --git a/vendor/gorm.io/gorm/gorm.go b/vendor/gorm.io/gorm/gorm.go index 63a28b37..67889262 100644 --- a/vendor/gorm.io/gorm/gorm.go +++ b/vendor/gorm.io/gorm/gorm.go @@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt" type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can disable it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool + SkipDefaultTransaction bool + DefaultTransactionTimeout time.Duration + // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer // FullSaveAssociations full save associations @@ -135,12 +137,16 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { if applyErr := opt.Apply(config); applyErr != nil { return nil, applyErr } defer func(opt Option) { + if skipAfterInitialize { + return + } if errr := opt.AfterInitialize(db); errr != nil { err = errr } @@ -192,6 +198,10 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if db, _ := db.DB(); db != nil { _ = db.Close() } + + // DB is not initialized, so we skip AfterInitialize + skipAfterInitialize = true + return } if config.TranslateError { @@ -519,7 +529,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) diff --git a/vendor/gorm.io/gorm/scan.go b/vendor/gorm.io/gorm/scan.go index 6dc55f62..9a99d024 100644 --- a/vendor/gorm.io/gorm/scan.go +++ b/vendor/gorm.io/gorm/scan.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "reflect" + "strings" "time" "gorm.io/gorm/schema" @@ -244,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) { matchedFieldCount[column] = 1 } } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) + for _, join := range db.Statement.Joins { + if join.Alias == aliasName { + names = append(strings.Split(join.Name, "."), names[len(names)-1]) + break + } + } + if rel, ok := sch.Relationships.Relations[names[0]]; ok { subNameCount := len(names) // nested relation fields diff --git a/vendor/gorm.io/gorm/schema/field.go b/vendor/gorm.io/gorm/schema/field.go index d1a633ce..a6ff1a72 100644 --- a/vendor/gorm.io/gorm/schema/field.go +++ b/vendor/gorm.io/gorm/schema/field.go @@ -318,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if val, ok := field.TagSettings["TYPE"]; ok { - switch DataType(strings.ToLower(val)) { + lowerVal := DataType(strings.ToLower(val)) + switch lowerVal { case Bool, Int, Uint, Float, String, Time, Bytes: - field.DataType = DataType(strings.ToLower(val)) + field.DataType = lowerVal default: field.DataType = DataType(val) } diff --git a/vendor/gorm.io/gorm/schema/index.go b/vendor/gorm.io/gorm/schema/index.go index a1cdc639..2690a0cb 100644 --- a/vendor/gorm.io/gorm/schema/index.go +++ b/vendor/gorm.io/gorm/schema/index.go @@ -105,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) { var ( name string tag = strings.Join(v[1:], ":") - idx = strings.Index(tag, ",") + idx = strings.IndexByte(tag, ',') tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") settings = ParseTagSetting(tagSetting, ",") length, _ = strconv.Atoi(settings["LENGTH"]) diff --git a/vendor/gorm.io/gorm/schema/relationship.go b/vendor/gorm.io/gorm/schema/relationship.go index def4a595..f1ace924 100644 --- a/vendor/gorm.io/gorm/schema/relationship.go +++ b/vendor/gorm.io/gorm/schema/relationship.go @@ -78,7 +78,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { cacheStore := schema.cacheStore if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { - schema.err = err + schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -663,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && rel.References[idx].PrimaryValue == ref.PrimaryValue) { matched = false + break } } @@ -675,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { var ( name string - idx = strings.Index(str, ",") + idx = strings.IndexByte(str, ',') settings = ParseTagSetting(str, ",") ) @@ -762,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref } func copyableDataType(str DataType) bool { + lowerStr := strings.ToLower(string(str)) for _, s := range []string{"auto_increment", "primary key"} { - if strings.Contains(strings.ToLower(string(str)), s) { + if strings.Contains(lowerStr, s) { return false } } diff --git a/vendor/gorm.io/gorm/statement.go b/vendor/gorm.io/gorm/statement.go index 39e05d09..c6183724 100644 --- a/vendor/gorm.io/gorm/statement.go +++ b/vendor/gorm.io/gorm/statement.go @@ -47,15 +47,18 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + Result *result } type join struct { - Name string - Conds []interface{} - On *clause.Where - Selects []string - Omits []string - JoinType clause.JoinType + Name string + Alias string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + Expression clause.Expression + JoinType clause.JoinType } // StatementModifier statement modifier interface @@ -205,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { writer.WriteString("(NULL)") } - case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - if v.Statement.SQL.Len() > 0 { + case interface{ getInstance() *DB }: + cv := v.getInstance() + + subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if cv.Statement.SQL.Len() > 0 { var ( vars = subdb.Statement.Vars - sql = v.Statement.SQL.String() + sql = cv.Statement.SQL.String() ) subdb.Statement.Vars = make([]interface{}, 0, len(vars)) for _, vv := range vars { subdb.Statement.Vars = append(subdb.Statement.Vars, vv) bindvar := strings.Builder{} - v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + cv.BindVarTo(&bindvar, subdb.Statement, vv) sql = strings.Replace(sql, bindvar.String(), "?", 1) } @@ -321,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] arg, _ = valuer.Value() } + curTable := stmt.Table + if curTable == "" { + curTable = clause.CurrentTable + } + switch v := arg.(type) { case clause.Expression: conds = append(conds, v) @@ -351,7 +361,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] sort.Strings(keys) for _, key := range keys { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + column := clause.Column{Name: key, Table: curTable} + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: keys := make([]string, 0, len(v)) @@ -362,12 +373,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + column := clause.Column{Name: key, Table: curTable} switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else if _, ok := v[key].(Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else { // optimize reflect value length valueLen := reflectValue.Len() @@ -376,10 +388,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: column, Values: values}) } default: - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } } default: @@ -406,9 +418,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -420,9 +432,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -447,14 +459,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) return []clause.Expression{clause.And(conds...)} } return nil } } - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) } } } @@ -521,6 +533,7 @@ func (stmt *Statement) clone() *Statement { Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, SkipHooks: stmt.SkipHooks, + Result: stmt.Result, } if stmt.SQL.Len() > 0 { diff --git a/vendor/modules.txt b/vendor/modules.txt index 46b3a5e6..b18b786a 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -365,7 +365,7 @@ gorm.io/driver/mysql # gorm.io/driver/sqlite v1.5.7 ## explicit; go 1.20 gorm.io/driver/sqlite -# gorm.io/gorm v1.26.1 +# gorm.io/gorm v1.30.0 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks