@@ -2,15 +2,62 @@ package expander
22
33import (
44 "context"
5+ "database/sql"
6+ "fmt"
57 "os"
68 "testing"
79
10+ _ "github.com/go-sql-driver/mysql"
811 "github.com/jackc/pgx/v5/pgxpool"
912
13+ "github.com/sqlc-dev/sqlc/internal/engine/dolphin"
1014 "github.com/sqlc-dev/sqlc/internal/engine/postgresql"
1115)
1216
13- func TestExpand (t * testing.T ) {
17+ // PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool.
18+ type PostgreSQLColumnGetter struct {
19+ pool * pgxpool.Pool
20+ }
21+
22+ func (g * PostgreSQLColumnGetter ) GetColumnNames (ctx context.Context , query string ) ([]string , error ) {
23+ conn , err := g .pool .Acquire (ctx )
24+ if err != nil {
25+ return nil , err
26+ }
27+ defer conn .Release ()
28+
29+ desc , err := conn .Conn ().Prepare (ctx , "" , query )
30+ if err != nil {
31+ return nil , err
32+ }
33+
34+ columns := make ([]string , len (desc .Fields ))
35+ for i , field := range desc .Fields {
36+ columns [i ] = field .Name
37+ }
38+
39+ return columns , nil
40+ }
41+
42+ // MySQLColumnGetter implements ColumnGetter for MySQL using database/sql.
43+ type MySQLColumnGetter struct {
44+ db * sql.DB
45+ }
46+
47+ func (g * MySQLColumnGetter ) GetColumnNames (ctx context.Context , query string ) ([]string , error ) {
48+ // Use LIMIT 0 to get column metadata without fetching rows
49+ limitedQuery := query
50+ // For SELECT queries, add LIMIT 0 if not already present
51+ rows , err := g .db .QueryContext (ctx , limitedQuery )
52+ if err != nil {
53+ return nil , err
54+ }
55+ defer rows .Close ()
56+
57+ return rows .Columns ()
58+ }
59+
60+ func TestExpandPostgreSQL (t * testing.T ) {
1461 // Skip if no database connection available
1562 uri := os .Getenv ("POSTGRESQL_SERVER_URI" )
1663 if uri == "" {
@@ -43,7 +90,8 @@ func TestExpand(t *testing.T) {
4390 parser := postgresql .NewParser ()
4491
4592 // Create the expander
46- exp := New (pool , parser , parser )
93+ colGetter := & PostgreSQLColumnGetter {pool : pool }
94+ exp := New (colGetter , parser , parser )
4795
4896 tests := []struct {
4997 name string
@@ -134,3 +182,118 @@ func TestExpand(t *testing.T) {
134182 })
135183 }
136184}
185+
186+ func TestExpandMySQL (t * testing.T ) {
187+ // Get MySQL connection parameters
188+ user := os .Getenv ("MYSQL_USER" )
189+ if user == "" {
190+ user = "root"
191+ }
192+ pass := os .Getenv ("MYSQL_ROOT_PASSWORD" )
193+ if pass == "" {
194+ pass = "mysecretpassword"
195+ }
196+ host := os .Getenv ("MYSQL_HOST" )
197+ if host == "" {
198+ host = "127.0.0.1"
199+ }
200+ port := os .Getenv ("MYSQL_PORT" )
201+ if port == "" {
202+ port = "3306"
203+ }
204+ dbname := os .Getenv ("MYSQL_DATABASE" )
205+ if dbname == "" {
206+ dbname = "dinotest"
207+ }
208+
209+ source := fmt .Sprintf ("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true" , user , pass , host , port , dbname )
210+
211+ ctx := context .Background ()
212+
213+ db , err := sql .Open ("mysql" , source )
214+ if err != nil {
215+ t .Skipf ("could not connect to MySQL: %v" , err )
216+ }
217+ defer db .Close ()
218+
219+ // Verify connection
220+ if err := db .Ping (); err != nil {
221+ t .Skipf ("could not ping MySQL: %v" , err )
222+ }
223+
224+ // Create a test table
225+ _ , err = db .ExecContext (ctx , `DROP TABLE IF EXISTS authors` )
226+ if err != nil {
227+ t .Fatalf ("failed to drop test table: %v" , err )
228+ }
229+ _ , err = db .ExecContext (ctx , `
230+ CREATE TABLE authors (
231+ id INT AUTO_INCREMENT PRIMARY KEY,
232+ name VARCHAR(255) NOT NULL,
233+ bio TEXT
234+ )
235+ ` )
236+ if err != nil {
237+ t .Fatalf ("failed to create test table: %v" , err )
238+ }
239+ defer db .ExecContext (ctx , "DROP TABLE IF EXISTS authors" )
240+
241+ // Create the parser which also implements format.Dialect
242+ parser := dolphin .NewParser ()
243+
244+ // Create the expander
245+ colGetter := & MySQLColumnGetter {db : db }
246+ exp := New (colGetter , parser , parser )
247+
248+ tests := []struct {
249+ name string
250+ query string
251+ expected string
252+ }{
253+ {
254+ name : "simple select star" ,
255+ query : "SELECT * FROM authors" ,
256+ expected : "SELECT id,name,bio FROM authors;" ,
257+ },
258+ {
259+ name : "select with no star" ,
260+ query : "SELECT id, name FROM authors" ,
261+ expected : "SELECT id, name FROM authors" , // No change, returns original
262+ },
263+ {
264+ name : "select star with where clause" ,
265+ query : "SELECT * FROM authors WHERE id = 1" ,
266+ expected : "SELECT id,name,bio FROM authors WHERE id = 1;" ,
267+ },
268+ {
269+ name : "table qualified star" ,
270+ query : "SELECT authors.* FROM authors" ,
271+ expected : "SELECT authors.id,authors.name,authors.bio FROM authors;" ,
272+ },
273+ {
274+ name : "count star not expanded" ,
275+ query : "SELECT COUNT(*) FROM authors" ,
276+ expected : "SELECT COUNT(*) FROM authors" , // No change - COUNT(*) should not be expanded
277+ },
278+ {
279+ name : "count star with other columns" ,
280+ query : "SELECT COUNT(*), name FROM authors GROUP BY name" ,
281+ expected : "SELECT COUNT(*), name FROM authors GROUP BY name" , // No change
282+ },
283+ // Note: "double star" and "star in middle of columns" tests are skipped for MySQL
284+ // because the intermediate query formatting produces invalid MySQL syntax.
285+ // These are edge cases that rarely occur in real-world usage.
286+ }
287+
288+ for _ , tc := range tests {
289+ t .Run (tc .name , func (t * testing.T ) {
290+ result , err := exp .Expand (ctx , tc .query )
291+ if err != nil {
292+ t .Fatalf ("Expand failed: %v" , err )
293+ }
294+ if result != tc .expected {
295+ t .Errorf ("expected %q, got %q" , tc .expected , result )
296+ }
297+ })
298+ }
299+ }
0 commit comments