diff --git a/go.mod b/go.mod index 6806da50..9be81550 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( gorm.io/datatypes v1.2.6 gorm.io/driver/mysql v1.6.0 gorm.io/driver/sqlite v1.6.0 - gorm.io/gorm v1.30.5 + gorm.io/gorm v1.31.0 ) require ( diff --git a/go.sum b/go.sum index 8a17c245..08a45b03 100644 --- a/go.sum +++ b/go.sum @@ -234,5 +234,5 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/driver/sqlserver v1.6.0 h1:VZOBQVsVhkHU/NzNhRJKoANt5pZGQAS1Bwc6m6dgfnc= gorm.io/driver/sqlserver v1.6.0/go.mod h1:WQzt4IJo/WHKnckU9jXBLMJIVNMVeTu25dnOzehntWw= -gorm.io/gorm v1.30.5 h1:dvEfYwxL+i+xgCNSGGBT1lDjCzfELK8fHZxL3Ee9X0s= -gorm.io/gorm v1.30.5/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/vendor/gorm.io/gorm/association.go b/vendor/gorm.io/gorm/association.go index e3f51d17..f210ca0a 100644 --- a/vendor/gorm.io/gorm/association.go +++ b/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/vendor/gorm.io/gorm/clause/association.go b/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/vendor/gorm.io/gorm/clause/set.go b/vendor/gorm.io/gorm/clause/set.go index 2ffadb38..cb5f36a0 100644 --- a/vendor/gorm.io/gorm/clause/set.go +++ b/vendor/gorm.io/gorm/clause/set.go @@ -11,7 +11,7 @@ type Assignment struct { // Assigner assignments provider interface type Assigner interface { - Assignments() []Assignment + Assignments() []Assignment } func (set Set) Name() string { diff --git a/vendor/gorm.io/gorm/generics.go b/vendor/gorm.io/gorm/generics.go index 8c79342b..79238d5f 100644 --- a/vendor/gorm.io/gorm/generics.go +++ b/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -57,6 +60,7 @@ type CreateInterface[T any] interface { Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error @@ -200,7 +204,7 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { } func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { - return setCreateOrUpdateG[T]{c: c.chainG, assigns: toAssignments(assignments...)} + return c.processSet(assignments...) } func (c createG[T]) Create(ctx context.Context, r *T) error { @@ -431,7 +435,7 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { } func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { - return setCreateOrUpdateG[T]{c: c, assigns: toAssignments(assignments...)} + return c.processSet(assignments...) } func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { @@ -601,36 +605,6 @@ func (c chainG[T]) Build(builder clause.Builder) { } } -type setCreateOrUpdateG[T any] struct { - c chainG[T] - assigns []clause.Assignment -} - -// toAssignments converts various supported types into []clause.Assignment. -// Supported inputs implement clause.Assigner. -func toAssignments(items ...clause.Assigner) []clause.Assignment { - out := make([]clause.Assignment, 0, len(items)) - for _, it := range items { - out = append(out, it.Assignments()...) - } - return out -} - -func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { - var r T - res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) - return int(res.RowsAffected), res.Error -} - -func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { - var r T - data := make(map[string]interface{}, len(s.assigns)) - for _, a := range s.assigns { - data[a.Column.Name] = a.Value - } - return s.c.g.apply(ctx).Model(r).Create(data).Error -} - type execG[T any] struct { g *g[T] } @@ -679,3 +653,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/vendor/gorm.io/gorm/schema/utils.go b/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..d4fe252e 100644 --- a/vendor/gorm.io/gorm/schema/utils.go +++ b/vendor/gorm.io/gorm/schema/utils.go @@ -121,6 +121,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/vendor/gorm.io/gorm/statement.go b/vendor/gorm.io/gorm/statement.go index cd7369e3..736087d7 100644 --- a/vendor/gorm.io/gorm/statement.go +++ b/vendor/gorm.io/gorm/statement.go @@ -336,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() diff --git a/vendor/modules.txt b/vendor/modules.txt index dc3ea70b..9d379013 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -385,7 +385,7 @@ gorm.io/driver/mysql # gorm.io/driver/sqlite v1.6.0 ## explicit; go 1.20 gorm.io/driver/sqlite -# gorm.io/gorm v1.30.5 +# gorm.io/gorm v1.31.0 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks