diff --git a/go.mod b/go.mod index 1ef71c9d..426d0fb9 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( gorm.io/datatypes v1.2.6 gorm.io/driver/mysql v1.6.0 gorm.io/driver/sqlite v1.6.0 - gorm.io/gorm v1.30.1 + gorm.io/gorm v1.30.2 ) require ( diff --git a/go.sum b/go.sum index ef3ada85..14a8097b 100644 --- a/go.sum +++ b/go.sum @@ -213,5 +213,5 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/driver/sqlserver v1.6.0 h1:VZOBQVsVhkHU/NzNhRJKoANt5pZGQAS1Bwc6m6dgfnc= gorm.io/driver/sqlserver v1.6.0/go.mod h1:WQzt4IJo/WHKnckU9jXBLMJIVNMVeTu25dnOzehntWw= -gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= -gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.30.2 h1:f7bevlVoVe4Byu3pmbWPVHnPsLoWaMjEb7/clyr9Ivs= +gorm.io/gorm v1.30.2/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/vendor/gorm.io/gorm/callbacks.go b/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/vendor/gorm.io/gorm/callbacks.go +++ b/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/vendor/gorm.io/gorm/callbacks/create.go b/vendor/gorm.io/gorm/callbacks/create.go index cb8429b3..e5929adb 100644 --- a/vendor/gorm.io/gorm/callbacks/create.go +++ b/vendor/gorm.io/gorm/callbacks/create.go @@ -80,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } diff --git a/vendor/gorm.io/gorm/finisher_api.go b/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e601fe66 100644 --- a/vendor/gorm.io/gorm/finisher_api.go +++ b/vendor/gorm.io/gorm/finisher_api.go @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/vendor/gorm.io/gorm/generics.go b/vendor/gorm.io/gorm/generics.go index f3c3e553..5f1fce8b 100644 --- a/vendor/gorm.io/gorm/generics.go +++ b/vendor/gorm.io/gorm/generics.go @@ -425,12 +425,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil diff --git a/vendor/gorm.io/gorm/gorm.go b/vendor/gorm.io/gorm/gorm.go index 6619f071..a209bb09 100644 --- a/vendor/gorm.io/gorm/gorm.go +++ b/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer diff --git a/vendor/gorm.io/gorm/logger/slog.go b/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..44f289e6 --- /dev/null +++ b/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,90 @@ +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Logger.InfoContext(ctx, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Logger.WarnContext(ctx, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Logger.ErrorContext(ctx, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.Logger.ErrorContext(ctx, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.Logger.WarnContext(ctx, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.Logger.InfoContext(ctx, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} diff --git a/vendor/gorm.io/gorm/schema/field.go b/vendor/gorm.io/gorm/schema/field.go index 67e60f70..de797402 100644 --- a/vendor/gorm.io/gorm/schema/field.go +++ b/vendor/gorm.io/gorm/schema/field.go @@ -458,20 +458,12 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { case len(field.StructField.Index) == 1 && fieldIndex >= 0: field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { v = reflect.Indirect(v) - if v.Type() != modelType { - fieldValue := v.FieldByName(field.Name) - return fieldValue.Interface(), fieldValue.IsZero() - } fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { v = reflect.Indirect(v) - if v.Type() != modelType { - fieldValue := v.FieldByName(field.Name) - return fieldValue.Interface(), fieldValue.IsZero() - } for _, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { v = v.Field(fieldIdx) @@ -516,17 +508,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { case len(field.StructField.Index) == 1 && fieldIndex >= 0: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { v = reflect.Indirect(v) - if v.Type() != modelType { - return v.FieldByName(field.Name) - } return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { v = reflect.Indirect(v) - if v.Type() != modelType { - return v.FieldByName(field.Name) - } for idx, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { v = v.Field(fieldIdx) diff --git a/vendor/gorm.io/gorm/schema/relationship.go b/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/vendor/gorm.io/gorm/schema/relationship.go +++ b/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/vendor/gorm.io/gorm/schema/schema.go b/vendor/gorm.io/gorm/schema/schema.go index 2a5c28e2..9419846b 100644 --- a/vendor/gorm.io/gorm/schema/schema.go +++ b/vendor/gorm.io/gorm/schema/schema.go @@ -60,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -75,7 +75,7 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } @@ -93,10 +93,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -114,6 +111,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -125,34 +130,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -163,28 +167,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { - tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { - tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { + if specialTableName != "" { tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -228,8 +233,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -284,10 +290,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -304,30 +337,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": - expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) - if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - } else { - logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) - // PASS - } - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -343,84 +352,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/vendor/gorm.io/gorm/statement.go b/vendor/gorm.io/gorm/statement.go index ba5d3f18..74feaedd 100644 --- a/vendor/gorm.io/gorm/statement.go +++ b/vendor/gorm.io/gorm/statement.go @@ -658,12 +658,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/vendor/modules.txt b/vendor/modules.txt index 5dd8f751..6acb85e6 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -351,7 +351,7 @@ gorm.io/driver/mysql # gorm.io/driver/sqlite v1.6.0 ## explicit; go 1.20 gorm.io/driver/sqlite -# gorm.io/gorm v1.30.1 +# gorm.io/gorm v1.30.2 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks