Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent interface type array from causing runtime errors #7361

Closed
wants to merge 10 commits into from
14 changes: 7 additions & 7 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}

tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
return db.create(value)
}

// CreateInBatches inserts value in batches of batchSize
Expand Down Expand Up @@ -63,12 +60,15 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {

tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
db.create(value)
}
return
}
func (db *DB) create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}

// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
Expand Down
7 changes: 6 additions & 1 deletion scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
if reflectValueType.Kind() == reflect.Interface && reflectValue.Len() > 0 {
reflectValueType = reflect.Indirect(reflectValue.Index(0)).Elem().Type()
}
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
Expand Down Expand Up @@ -318,7 +321,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} else {
elem = reflect.New(reflectValueType)
}

if elem.Type().Kind() == reflect.Interface {
elem = elem.Elem()
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)

if !update {
Expand Down
3 changes: 3 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ func (field *Field) setupValuerAndSetter() {
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Kind() == reflect.Interface {
v = reflect.Indirect(v)
}
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
Expand Down
4 changes: 3 additions & 1 deletion schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam

for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
if modelType.Kind() == reflect.Interface && value.Len() > 0 {
modelType = reflect.Indirect(value.Index(0)).Elem().Type()
}
}

if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
Expand Down
62 changes: 62 additions & 0 deletions tests/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,3 +791,65 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Errorf("failed to create data from map with table, @id != id")
}
}

func TestCreateWithInterfaceType(t *testing.T) {
user := *GetUser("create", Config{})
type UserInterface interface{}
var userInterface UserInterface = &user

if results := DB.Create(userInterface); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}

if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}

if user.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}

if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}

var newUser User
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, user)
}
}

func TestCreateWithInterfaceArrayTypeWithTable(t *testing.T) {
user := *GetUser("create", Config{})
type UserInterface interface{}
var userInterface UserInterface = user

if results := DB.Table("users").Create([]UserInterface{userInterface}); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}

if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}

if user.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}

if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}

var newUser User
if err := DB.Table("users").Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, user)
}
}
Loading