Skip to content

Commit 0d97c17

Browse files
committed
feat: add queriestest package
1 parent 5e57ce4 commit 0d97c17

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed

queriestest/driver.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Package queriestest implements utilities for testing SQL queries.
2+
package queriestest
3+
4+
import (
5+
"context"
6+
"database/sql"
7+
"database/sql/driver"
8+
"io"
9+
"slices"
10+
"testing"
11+
)
12+
13+
type Driver struct {
14+
ExecContext func(t *testing.T, query string, args []any) (driver.Result, error)
15+
QueryContext func(t *testing.T, query string, args []any) (driver.Rows, error)
16+
}
17+
18+
func NewDB(t *testing.T, d Driver) *sql.DB {
19+
name := t.Name()
20+
sql.Register(name, testDriver{t, d})
21+
db, _ := sql.Open(name, "")
22+
return db
23+
}
24+
25+
var (
26+
_ driver.Driver = testDriver{}
27+
_ driver.Conn = testDriver{}
28+
_ driver.ConnBeginTx = testDriver{}
29+
_ driver.Tx = testDriver{}
30+
_ driver.ExecerContext = testDriver{}
31+
_ driver.QueryerContext = testDriver{}
32+
)
33+
34+
type testDriver struct {
35+
t *testing.T
36+
driver Driver
37+
}
38+
39+
// Open implements [driver.Driver].
40+
func (d testDriver) Open(string) (driver.Conn, error) { return d, nil }
41+
42+
// Prepare implements [driver.Conn].
43+
func (testDriver) Prepare(string) (driver.Stmt, error) { panic("unimplemented") }
44+
45+
// Close implements [driver.Conn].
46+
func (testDriver) Close() error { return nil }
47+
48+
// Begin implements [driver.Conn].
49+
func (testDriver) Begin() (driver.Tx, error) {
50+
panic("unreachable") // BeginTx always takes precedence over Begin.
51+
}
52+
53+
// BeginTx implements [driver.ConnBeginTx].
54+
func (d testDriver) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { return d, nil }
55+
56+
// Commit implements [driver.Tx].
57+
func (testDriver) Commit() error { return nil }
58+
59+
// Rollback implements [driver.Tx].
60+
func (testDriver) Rollback() error { return nil }
61+
62+
// ExecContext implements [driver.ExecerContext].
63+
func (d testDriver) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
64+
if d.driver.ExecContext == nil {
65+
panic("queriestest: Driver.ExecContext is called but not set")
66+
}
67+
return d.driver.ExecContext(d.t, query, namedToAny(args))
68+
}
69+
70+
// QueryContext implements [driver.QueryerContext].
71+
func (d testDriver) QueryContext(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
72+
if d.driver.QueryContext == nil {
73+
panic("queriestest: Driver.QueryContext is called but not set")
74+
}
75+
return d.driver.QueryContext(d.t, query, namedToAny(args))
76+
}
77+
78+
func namedToAny(values []driver.NamedValue) []any {
79+
args := make([]any, len(values))
80+
for i, value := range values {
81+
args[i] = value.Value
82+
}
83+
return args
84+
}
85+
86+
var _ driver.Result = testResult{}
87+
88+
type testResult struct {
89+
lastInsertId int64
90+
rowsAffected int64
91+
}
92+
93+
func NewResult(lastInsertId, rowsAffected int64) driver.Result {
94+
return testResult{lastInsertId, rowsAffected}
95+
}
96+
97+
// LastInsertId implements [driver.Result].
98+
func (r testResult) LastInsertId() (int64, error) { return r.lastInsertId, nil }
99+
100+
// RowsAffected implements [driver.Result].
101+
func (r testResult) RowsAffected() (int64, error) { return r.rowsAffected, nil }
102+
103+
var _ driver.Rows = new(Rows)
104+
105+
type Rows struct {
106+
columns []string
107+
values [][]any
108+
}
109+
110+
func NewRows(columns ...string) *Rows {
111+
return &Rows{columns: columns}
112+
}
113+
114+
func (r *Rows) Add(values ...any) *Rows {
115+
r.values = append(r.values, values)
116+
return r
117+
}
118+
119+
// Columns implements [driver.Rows].
120+
func (r *Rows) Columns() []string { return r.columns }
121+
122+
// Close implements [driver.Rows].
123+
func (r *Rows) Close() error { return nil }
124+
125+
// Next implements [driver.Rows].
126+
func (r *Rows) Next(values []driver.Value) error {
127+
if len(r.values) == 0 {
128+
return io.EOF
129+
}
130+
for i := range values {
131+
values[i] = r.values[0][i]
132+
}
133+
r.values = slices.Delete(r.values, 0, 1)
134+
return nil
135+
}

0 commit comments

Comments
 (0)