diff --git a/gen/integration/sqlx/main.go b/gen/integration/sqlx/main.go index 55a6c69..9e6afd3 100644 --- a/gen/integration/sqlx/main.go +++ b/gen/integration/sqlx/main.go @@ -15,8 +15,9 @@ import ( "strings" "time" - _ "github.com/mattn/go-sqlite3" defc "github.com/x5iu/defc/runtime" + + _ "github.com/mattn/go-sqlite3" ) var executor Executor @@ -158,6 +159,40 @@ func main() { } log.Println("All constbind tests passed!") + + // Test that panic in Scan propagates correctly + // When a struct field's Scan method panics, the panic should propagate + // to the caller without causing deadlock. + func() { + // Create a separate database connection for this test + panicDB := defc.MustOpen("sqlite3", ":memory:") + panicDB.Exec("CREATE TABLE user (id INTEGER PRIMARY KEY, name TEXT)") + panicDB.Exec("INSERT INTO user (id, name) VALUES (1, 'test')") + panicCore := &sqlc{panicDB} + panicExecutor := NewExecutorFromCore(panicCore) + + done := make(chan struct{}) + go func() { + defer close(done) + defer func() { + if r := recover(); r == nil { + log.Fatalln("Expected panic but got none") + } else { + log.Printf("Expected panic occurred: %v\n", r) + } + }() + // This should panic, not return an error + _, _ = panicExecutor.GetPanicUser(ctx, 1) + }() + select { + case <-done: + // Test completed successfully (panic was caught in goroutine) + case <-time.After(30 * time.Second): + log.Fatalln("Test timeout: panic test exceeded 30 seconds") + } + }() + + log.Println("All tests passed!") } type sqlc struct { @@ -272,6 +307,11 @@ type Executor interface { // /* {"name": "defc", "action": "test"} */ // update user set name = ${newName} where id = ${id}; UpdateUserName(ctx context.Context, id int64, newName string) (sql.Result, error) + + // GetPanicUser query constbind + // /* {"name": "defc", "action": "test"} */ + // SELECT id, name from user where id = ${id}; + GetPanicUser(ctx context.Context, id int64) (*PanicUser, error) } type UserID struct { @@ -350,6 +390,17 @@ func (project *Project) FromRow(row defc.Row) error { ) } +type PanicUser struct { + ID PanicID `db:"id"` + Name string `db:"name"` +} + +type PanicID int64 + +func (pid *PanicID) Scan(src any) error { + panic("PanicID.Scan: should panic") +} + func sqlComment(context.Context) string { return `/* {"name": "defc", "action": "test"} */` } diff --git a/gen/template/sqlx.tmpl b/gen/template/sqlx.tmpl index 6bbffca..8176148 100644 --- a/gen/template/sqlx.tmpl +++ b/gen/template/sqlx.tmpl @@ -249,9 +249,6 @@ var ( if {{ $tx }} == nil { panic("tx is nil") } - if !__imp.__withTx{ - defer {{ $tx }}.Rollback() - } {{ $offset := printf "offset%s" $method.Ident -}} {{ $args := printf "args%s" $method.Ident -}} @@ -268,6 +265,7 @@ var ( {{- if $.HasFeature "sqlx/in" }} {{ $query }}, {{ $args }}, {{ $err }} := {{ if $.HasFeature "sqlx/nort" }}{{ $inFunc }}{{ else }}__rt.In{{ end }}({{ $query }}, {{ $argList }}) if {{ $err }} != nil { + if !__imp.__withTx { {{ $tx }}.Rollback() } return {{ range $index, $type := $method.Out -}} {{- if lt $index (sub (len $method.Out) 1) -}} v{{- $index -}}{{- $method.Ident }}, @@ -296,6 +294,7 @@ var ( {{ $splitSql }}, {{ $argList }}, {{ $err }} = sqlx.Named({{ $splitSql }}, {{ $args }}) if {{ $err }} != nil { + if !__imp.__withTx { {{ $tx }}.Rollback() } return {{ range $index, $type := $method.Out -}} {{- if lt $index (sub (len $method.Out) 1) -}} v{{- $index -}}{{- $method.Ident }}, @@ -309,6 +308,7 @@ var ( {{ $splitSql }}, {{ $argList }}, {{ $err }} = sqlx.In({{ $splitSql }}, {{ $argList }}...) {{ end -}} if {{ $err }} != nil { + if !__imp.__withTx { {{ $tx }}.Rollback() } return {{ range $index, $type := $method.Out -}} {{- if lt $index (sub (len $method.Out) 1) -}} v{{- $index -}}{{- $method.Ident }}, @@ -368,6 +368,7 @@ var ( {{ end }} if {{ $err }} != nil { + if !__imp.__withTx { {{ $tx }}.Rollback() } return {{ range $index, $type := $method.Out -}} {{- if lt $index (sub (len $method.Out) 1) -}} v{{- $index -}}{{- $method.Ident }}, diff --git a/runtime/version.go b/runtime/version.go index 5ad8eaa..0eda52c 100644 --- a/runtime/version.go +++ b/runtime/version.go @@ -1,3 +1,3 @@ package defc -const Version = "v1.44.2" +const Version = "v1.44.3" diff --git a/sqlx/sqlx.refined.go b/sqlx/sqlx.refined.go index cda7087..5228221 100644 --- a/sqlx/sqlx.refined.go +++ b/sqlx/sqlx.refined.go @@ -572,27 +572,36 @@ func (r *Row) Scan(dest ...any) error { if r.err != nil { return r.err } - defer r.rows.Close() + // Note: We intentionally do NOT use defer r.rows.Close() here. + // If rows.Scan panics (e.g., due to a custom Scanner implementation), + // calling rows.Close() in a defer will cause a deadlock with certain + // database drivers (e.g., go-sqlite3). By not using defer, panic will + // propagate naturally without attempting to close the rows. for _, dp := range dest { if _, ok := dp.(*sql.RawBytes); ok { + _ = r.rows.Close() return errors.New("sql: RawBytes isn't allowed on Row.Scan") } } if !r.rows.Next() { if err := r.rows.Err(); err != nil { + _ = r.rows.Close() return err } + _ = r.rows.Close() return sql.ErrNoRows } err := r.rows.Scan(dest...) if err != nil { + _ = r.rows.Close() return err } // Make sure the query can be processed to completion with no errors. - if err := r.rows.Close(); err != nil { + if err = r.rows.Err(); err != nil { + _ = r.rows.Close() return err } - return nil + return r.rows.Close() } // Columns returns the underlying sql.Rows.Columns(), or the deferred error usually @@ -957,13 +966,16 @@ func (r *Row) scanAny(dest any, structOnly bool) error { r.err = sql.ErrNoRows return r.err } - defer r.rows.Close() + // Note: r.Scan() will close r.rows, so we don't defer close here. + // For early returns before Scan, we must close rows explicitly. v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { + _ = r.rows.Close() return errors.New("must pass a pointer, not a value, to StructScan destination") } if v.IsNil() { + _ = r.rows.Close() return errors.New("nil pointer passed to StructScan destination") } @@ -971,15 +983,18 @@ func (r *Row) scanAny(dest any, structOnly bool) error { scannable := isScannable(base) if structOnly && scannable { + _ = r.rows.Close() return structOnlyError(base) } columns, err := r.Columns() if err != nil { + _ = r.rows.Close() return err } if scannable && len(columns) > 1 { + _ = r.rows.Close() return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) } @@ -988,7 +1003,9 @@ func (r *Row) scanAny(dest any, structOnly bool) error { } if fr, ok := dest.(FromRow); ok { - return fr.FromRow(r) + err := fr.FromRow(r) + _ = r.rows.Close() + return err } m := r.Mapper @@ -996,12 +1013,14 @@ func (r *Row) scanAny(dest any, structOnly bool) error { fields := m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !r.unsafe { + _ = r.rows.Close() return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values := make([]any, len(columns)) err = fieldsByTraversal(v, fields, values, true) if err != nil { + _ = r.rows.Close() return err } // scan into the struct field pointers and append to our results