diff --git a/README.md b/README.md index 0a6d52177..f7cd3adac 100644 --- a/README.md +++ b/README.md @@ -80,23 +80,28 @@ Table of Contents ### Features - Full model generation -- High performance through generation - Extremely fast code generation +- High performance through generation & intelligent caching - Uses boil.Executor (simple interface, sql.DB, sqlx.DB etc. compatible) - Easy workflow (models can always be regenerated, full auto-complete) - Strongly typed querying (usually no converting or binding to pointers) - Hooks (Before/After Create/Select/Update/Delete/Upsert) - Automatic CreatedAt/UpdatedAt +- Table whitelist/blacklist - Relationships/Associations - Eager loading (recursive) +- Custom struct tags +- Schema support - Transactions - Raw SQL fallback - Compatibility tests (Run against your own DB schema) - Debug logging +- Postgres 1d arrays, json, hstore & more ### Supported Databases - PostgreSQL +- MySQL *Note: Seeking contributors for other database engines.* @@ -203,30 +208,32 @@ order: - `$XDG_CONFIG_HOME/sqlboiler/` - `$HOME/.config/sqlboiler/` -We require you pass in the `postgres` configuration via the configuration file rather than env vars. -There is no command line argument support for database configuration. Values given under the `postgres` -block are passed directly to the [pq](github.com/lib/pq) driver. Here is a rundown of all the different +We require you pass in your `postgres` and `mysql` database configuration via the configuration file rather than env vars. +There is no command line argument support for database configuration. Values given under the `postgres` and `mysql` +block are passed directly to the postgres and mysql drivers. Here is a rundown of all the different values that can go in that section: -| Name | Required | Default | -| --- | --- | --- | -| dbname | yes | none | -| host | yes | none | -| port | no | 5432 | -| user | yes | none | -| pass | no | none | -| sslmode | no | "require" | +| Name | Required | Postgres Default | MySQL Default | +| --- | --- | --- | --- | +| dbname | yes | none | none | +| host | yes | none | none | +| port | no | 5432 | 3306 | +| user | yes | none | none | +| pass | no | none | none | +| sslmode | no | "require" | "true" | You can also pass in these top level configuration values if you would prefer not to pass them through the command line or environment variables: -| Name | Default | -| --- | --- | +| Name | Defaults | +| ------------------ | --------- | | basedir | none | +| schema | "public" *(or dbname for mysql)* | | pkgname | "models" | | output | "models" | -| exclude | [ ] | -| tag | [ ] | +| whitelist | [] | +| blacklist | [] | +| tag | [] | | debug | false | | no-hooks | false | | no-tests | false | @@ -256,23 +263,26 @@ Usage: sqlboiler [flags] Examples: -sqlboiler postgres + sqlboiler postgres + sqlboiler mysql Flags: - -b, --basedir string The base directory has the templates and templates_test folders + -b, --blacklist stringSlice Do not include these tables in your generated package + -w, --whitelist stringSlice Only include these tables in your generated package + -s, --schema string The name of your database schema, for databases that support real schemas (default "public") + -p, --pkgname string The name you wish to assign to your generated package (default "models") + -o, --output string The name of the folder to output to (default "models") + -t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml -d, --debug Debug mode prints stack traces on error - 
-x, --exclude stringSlice Tables to be excluded from the generated package + --basedir string The base directory has the templates and templates_test folders --no-auto-timestamps Disable automatic timestamps for created_at/updated_at --no-hooks Disable hooks feature for your models --no-tests Disable generated go test files - -o, --output string The name of the folder to output to (default "models") - -p, --pkgname string The name you wish to assign to your generated package (default "models") - -t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml ``` -Follow the steps below to do some basic model generation. Once we've generated -our models, we can run the compatibility tests which will exercise the entirety -of the generated code. This way we can ensure that our database is compatible +Follow the steps below to do some basic model generation. Once you've generated +your models, you can run the compatibility tests which will exercise the entirety +of the generated code. This way you can ensure that your database is compatible with SQLBoiler. If you find there are some failing tests, please check the [Diagnosing Problems](#diagnosing-problems) section. @@ -281,8 +291,7 @@ with SQLBoiler. If you find there are some failing tests, please check the sqlboiler -x goose_migrations postgres # Run the generated tests -go test ./models # This requires an administrator postgres user because of some - # voodoo we do to disable triggers for the generated test db +go test ./models ``` ## Diagnosing Problems @@ -292,7 +301,7 @@ The most common causes of problems and panics are: - Forgetting to exclude tables you do not want included in your generation, like migration tables. - Tables without a primary key. All tables require one. - Forgetting to put foreign key constraints on your columns that reference other tables. -- The compatibility tests that run against your own DB schema require a superuser, ensure the user +- The compatibility tests require privileges to create a database for testing purposes, ensure the user supplied in your `sqlboiler.toml` config has adequate privileges. - A nil or closed database handle. Ensure your passed in `boil.Executor` is not nil. - If you decide to use the `G` variant of functions instead, make sure you've initialized your @@ -345,9 +354,8 @@ ALTER TABLE pilot_languages ADD CONSTRAINT pilots_fkey FOREIGN KEY (pilot_id) RE ALTER TABLE pilot_languages ADD CONSTRAINT languages_fkey FOREIGN KEY (language_id) REFERENCES languages(id); ``` -The generated model structs for this schema look like the following. Note that I've included the relationship -structs as well so you can see how it all pieces together, but these are unexported and not something you should -ever need to touch directly: +The generated model structs for this schema look like the following. 
Note that we've included the relationship +structs as well so you can see how it all pieces together: ```go type Pilot struct { @@ -355,6 +363,7 @@ type Pilot struct { Name string `boil:"name" json:"name" toml:"name" yaml:"name"` R *pilotR `boil:"-" json:"-" toml:"-" yaml:"-"` + L pilotR `boil:"-" json:"-" toml:"-" yaml:"-"` } type pilotR struct { @@ -371,6 +380,7 @@ type Jet struct { Color string `boil:"color" json:"color" toml:"color" yaml:"color"` R *jetR `boil:"-" json:"-" toml:"-" yaml:"-"` + L jetR `boil:"-" json:"-" toml:"-" yaml:"-"` } type jetR struct { @@ -382,6 +392,7 @@ type Language struct { Language string `boil:"language" json:"language" toml:"language" yaml:"language"` R *languageR `boil:"-" json:"-" toml:"-" yaml:"-"` + L languageR `boil:"-" json:"-" toml:"-" yaml:"-"` } type languageR struct { @@ -414,7 +425,7 @@ Note: You can set the timezone for this feature by calling `boil.SetLocation()` This is somewhat of a work around until we can devise a better solution in a later version. * **Update** * The `updated_at` column will always be set to `time.Now()`. If you need to override - this value you will need to fall back to another method in the meantime: `boil.SQL()`, + this value you will need to fall back to another method in the meantime: `queries.Raw()`, overriding `updated_at` in all of your objects using a hook, or create your own wrapper. * **Upsert** * `created_at` will be set automatically if it is a zero value, otherwise your supplied value @@ -452,36 +463,8 @@ err := models.NewQuery(db, From("pilots")).All() As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers) allow you to execute the final action. -If you plan on executing the same query with the same values using the query builder, -you should do so like the following to utilize caching: - -```go -// Instead of this: -for i := 0; i < 10; i++ { - pilots := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)).All() -} - -// You should do this -query := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)) -for i := 0; i < 10; i++ { - pilots := query.All() -} - -// Every execution of All() after the first will use a cached version of -// the built query that short circuits the query builder all together. -// This allows you to save on performance. - -// Just something to be aware of: query mods don't store pointers, so if -// your passed in variable's value changes, your generated query will not change. -``` - -Note: You will see exported `boil.SetX` methods in the boil package. These should not be used on query -objects because they will break caching. Unfortunately these had to be exported due to some circular -dependency issues, but they're not functionality we want exposed. If you want a different -query object, generate a new one. - -Take a look at our [Relationships Query Building](#relationships) section for some additional query -building information. +We also generate query building helper methods for your relationships as well. Take a look at our +[Relationships Query Building](#relationships) section for some additional query building information. ### Query Mod System @@ -575,26 +558,31 @@ UpdateAll(models.M{"name": "John", "age": 23}) // Update all rows matching the b DeleteAll() // Delete all rows matching the built query. Exists() // Returns a bool indicating whether the row(s) for the built query exists. Bind(&myObj) // Bind the results of a query to your own struct object. +Exec() // Execute an SQL query that does not require any rows returned. 
+QueryRow() // Execute an SQL query expected to return only a single row. +Query() // Execute an SQL query expected to return multiple rows. ``` ### Raw Query -We provide `boil.SQL()` for executing raw queries. Generally you will want to use `Bind()` with +We provide `queries.Raw()` for executing raw queries. Generally you will want to use `Bind()` with this, like the following: ```go -err := boil.SQL(db, "select * from pilots where id=$1", 5).Bind(&obj) +err := queries.Raw(db, "select * from pilots where id=$1", 5).Bind(&obj) ``` You can use your own structs or a generated struct as a parameter to Bind. Bind supports both a single object for single row queries and a slice of objects for multiple row queries. -You also have `models.NewQuery()` at your disposal if you would still like to use [Query Build](#query-building) -but would like to build against a non-generated model. +`queries.Raw()` also has a method that can execute a query without binding to an object, if required. + +You also have `models.NewQuery()` at your disposal if you would still like to use [Query Building](#query-building) +in combination with your own custom, non-generated model. ### Binding -For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/boil#Bind). +For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/queries#Bind). The `Bind()` [Finisher](#finisher) allows the results of a query built with the [Raw SQL](#raw-query) method or the [Query Builder](#query-building) methods to be bound @@ -613,7 +601,7 @@ type PilotAndJet struct { var paj PilotAndJet // Use a raw query -err := boil.SQL(` +err := queries.Raw(` select pilots.id as "pilots.id", pilots.name as "pilots.name", jets.id as "jets.id", jets.pilot_id as "jets.pilot_id", jets.age as "jets.age", jets.name as "jets.name", jets.color as "jets.color" @@ -641,7 +629,7 @@ var info JetInfo err := models.NewQuery(db, Select("sum(age) as age_sum", "count(*) as juicy_count", From("jets"))).Bind(&info) // Use a raw query -err := boil.SQL(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info) +err := queries.Raw(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info) ``` We support the following struct tag modes for `Bind()` control: @@ -905,8 +893,8 @@ err := p1.Insert(db) // Insert the first pilot with name "Larry" // p1 now has an ID field set to 1 var p2 models.Pilot -p2.Name "Borris" -err := p2.Insert(db) // Insert the second pilot with name "Borris" +p2.Name "Boris" +err := p2.Insert(db) // Insert the second pilot with name "Boris" // p2 now has an ID field set to 2 var p3 models.Pilot @@ -999,8 +987,13 @@ p1.Name = "Hogan" err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name") ``` +The `updateOnConflict` argument allows you to specify whether you would like Postgres +to perform a `DO NOTHING` on conflict, opposed to a `DO UPDATE`. For MySQL, this param will not be generated. + +The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres. +For MySQL, this param will not be generated. + Note: Passing a different set of column values to the update component is not currently supported. -If this feature is important to you let us know and we can consider adding something for this. ### Reload In the event that your objects get out of sync with the database for whatever reason, @@ -1010,7 +1003,7 @@ attached to the objects. 
```go pilot, _ := models.FindPilot(db, 1) -// > Object becomes out of sync for some reason +// > Object becomes out of sync for some reason, perhaps async processing // Refresh the object with the latest data from the db err := pilot.Reload(db) @@ -1051,10 +1044,26 @@ The generated models might import a couple of packages that are not on your syst `cd` into your generated models directory and type `go get -u -t` to fetch them. You will only need to run this command once, not per generation. +#### How should I handle multiple schemas? + +If your database uses multiple schemas you should generate a new package for each of your schemas. +Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not +fake schemas (like MySQL). + +#### How do I use types.BytesArray for Postgres bytea arrays? + +Only "escaped format" is supported for types.BytesArray. This means that your byte slice needs to have +a format of "\\x00" (4 bytes per byte) opposed to "\x00" (1 byte per byte). This is to maintain compatibility +with all Postgres drivers. Example: + +`x := types.BytesArray{0: []byte("\\x68\\x69")}` + +Please note that multi-dimensional Postgres ARRAY types are not supported at this time. + #### Where is the homepage? -The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) -[Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler +The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) +generator is located at: https://github.com/vattle/sqlboiler ## Benchmarks diff --git a/bdb/column.go b/bdb/column.go index 688bd95bc..b6cd3359d 100644 --- a/bdb/column.go +++ b/bdb/column.go @@ -5,6 +5,11 @@ import "github.com/vattle/sqlboiler/strmangle" // Column holds information about a database column. // Types are Go types, converted by TranslateColumnType. type Column struct { + // ArrType is the underlying data type of the Postgres + // ARRAY type. 
See here: + // https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html + ArrType *string + UDTName string Name string Type string DBType string diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 692062797..abe85a482 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -9,13 +9,16 @@ import ( type MockDriver struct{} // TableNames returns a list of mock table names -func (m *MockDriver) TableNames(exclude []string) ([]string, error) { +func (m *MockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { + if len(whitelist) > 0 { + return whitelist, nil + } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} - return strmangle.SetComplement(tables, exclude), nil + return strmangle.SetComplement(tables, blacklist), nil } // Columns returns a list of mock columns -func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) { +func (m *MockDriver) Columns(schema, tableName string) ([]bdb.Column, error) { return map[string][]bdb.Column{ "pilots": { {Name: "id", Type: "int", DBType: "integer"}, @@ -56,7 +59,7 @@ func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) { } // ForeignKeyInfo returns a list of mock foreignkeys -func (m *MockDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { +func (m *MockDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { return map[string][]bdb.ForeignKey{ "jets": { {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, @@ -79,7 +82,7 @@ func (m *MockDriver) TranslateColumnType(c bdb.Column) bdb.Column { } // PrimaryKeyInfo returns mock primary key info for the passed in table name -func (m *MockDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { +func (m *MockDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { return map[string]*bdb.PrimaryKey{ "pilots": { Name: "pilot_id_pkey", @@ -120,3 +123,18 @@ func (m *MockDriver) Open() error { return nil } // Close mimics a database close call func (m *MockDriver) Close() {} + +// RightQuote is the quoting character for the right side of the identifier +func (m *MockDriver) RightQuote() byte { + return '"' +} + +// LeftQuote is the quoting character for the left side of the identifier +func (m *MockDriver) LeftQuote() byte { + return '"' +} + +// IndexPlaceholders returns true to indicate fake support of indexed placeholders +func (m *MockDriver) IndexPlaceholders() bool { + return false +} diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go new file mode 100644 index 000000000..af94db599 --- /dev/null +++ b/bdb/drivers/mysql.go @@ -0,0 +1,321 @@ +package drivers + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/pkg/errors" + "github.com/vattle/sqlboiler/bdb" +) + +// MySQLDriver holds the database connection string and a handle +// to the database connection. +type MySQLDriver struct { + connStr string + dbConn *sql.DB +} + +// NewMySQLDriver takes the database connection details as parameters and +// returns a pointer to a MySQLDriver object. Note that it is required to +// call MySQLDriver.Open() and MySQLDriver.Close() to open and close +// the database connection once an object has been obtained. 
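+//
+// Illustrative usage only (the connection values below are placeholders, not defaults):
+//
+//	drv := NewMySQLDriver("user", "pass", "dbname", "localhost", 3306, "true")
+//	if err := drv.Open(); err != nil {
+//		// handle the connection error
+//	}
+//	defer drv.Close()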
+func NewMySQLDriver(user, pass, dbname, host string, port int, sslmode string) *MySQLDriver {
+	driver := MySQLDriver{
+		connStr: MySQLBuildQueryString(user, pass, dbname, host, port, sslmode),
+	}
+
+	return &driver
+}
+
+// MySQLBuildQueryString builds a query string for MySQL.
+func MySQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
+	var config mysql.Config
+
+	config.User = user
+	if len(pass) != 0 {
+		config.Passwd = pass
+	}
+	config.DBName = dbname
+	config.Net = "tcp"
+	config.Addr = host
+	if port == 0 {
+		port = 3306
+	}
+	config.Addr += ":" + strconv.Itoa(port)
+	config.TLSConfig = sslmode
+
+	return config.FormatDSN()
+}
+
+// Open opens the database connection using the connection string
+func (m *MySQLDriver) Open() error {
+	var err error
+	m.dbConn, err = sql.Open("mysql", m.connStr)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// Close closes the database connection
+func (m *MySQLDriver) Close() {
+	m.dbConn.Close()
+}
+
+// UseLastInsertID returns true for MySQL
+func (m *MySQLDriver) UseLastInsertID() bool {
+	return true
+}
+
+// TableNames connects to the MySQL database and
+// retrieves all table names from the information_schema where the
+// table schema matches the given schema, honoring the whitelist and blacklist.
+func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
+	var names []string
+
+	query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`)
+	args := []interface{}{schema}
+	if len(whitelist) > 0 {
+		query += fmt.Sprintf(" and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:])
+		for _, w := range whitelist {
+			args = append(args, w)
+		}
+	} else if len(blacklist) > 0 {
+		query += fmt.Sprintf(" and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:])
+		for _, b := range blacklist {
+			args = append(args, b)
+		}
+	}
+
+	rows, err := m.dbConn.Query(query, args...)
+
+	if err != nil {
+		return nil, err
+	}
+
+	defer rows.Close()
+	for rows.Next() {
+		var name string
+		if err := rows.Scan(&name); err != nil {
+			return nil, err
+		}
+		names = append(names, name)
+	}
+
+	return names, nil
+}
+
+// Columns takes a table name and attempts to retrieve the table information
+// from the database information_schema.columns. It retrieves the column names
+// and column types and returns those as a []Column after TranslateColumnType()
+// converts the SQL types to Go types, for example: "varchar" to "string"
+func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) {
+	var columns []bdb.Column
+
+	rows, err := m.dbConn.Query(`
+	select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable,
+	exists (
+		select c.column_name
+		from information_schema.table_constraints tc
+		inner join information_schema.key_column_usage kcu
+			on tc.constraint_name = kcu.constraint_name and tc.table_name = kcu.table_name and tc.table_schema = kcu.table_schema
+		where c.column_name = kcu.column_name and tc.table_name = c.table_name and
+			(tc.constraint_type = 'PRIMARY KEY' or tc.constraint_type = 'UNIQUE')
+	) as is_unique
+	from information_schema.columns as c
+	where table_name = ? 
and table_schema = ?; + `, tableName, schema) + + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var colName, colType, colDefault, nullable string + var unique bool + var defaultPtr *string + if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { + return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) + } + + if defaultPtr != nil && *defaultPtr != "NULL" { + colDefault = *defaultPtr + } + + column := bdb.Column{ + Name: colName, + DBType: colType, + Default: colDefault, + Nullable: nullable == "YES", + Unique: unique, + } + columns = append(columns, column) + } + + return columns, nil +} + +// PrimaryKeyInfo looks up the primary key for a table. +func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { + pkey := &bdb.PrimaryKey{} + var err error + + query := ` + select tc.constraint_name + from information_schema.table_constraints as tc + where tc.table_name = ? and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = ?;` + + row := m.dbConn.QueryRow(query, tableName, schema) + if err = row.Scan(&pkey.Name); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + queryColumns := ` + select kcu.column_name + from information_schema.key_column_usage as kcu + where table_name = ? and constraint_name = ? and table_schema = ?;` + + var rows *sql.Rows + if rows, err = m.dbConn.Query(queryColumns, tableName, pkey.Name, schema); err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + + err = rows.Scan(&column) + if err != nil { + return nil, err + } + + columns = append(columns, column) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + pkey.Columns = columns + + return pkey, nil +} + +// ForeignKeyInfo retrieves the foreign keys for a given table name. +func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { + var fkeys []bdb.ForeignKey + + query := ` + select constraint_name, table_name, column_name, referenced_table_name, referenced_column_name + from information_schema.key_column_usage + where table_schema = ? and referenced_table_schema = ? and table_name = ? + ` + + var rows *sql.Rows + var err error + if rows, err = m.dbConn.Query(query, schema, schema, tableName); err != nil { + return nil, err + } + + for rows.Next() { + var fkey bdb.ForeignKey + var sourceTable string + + fkey.Table = tableName + err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn) + if err != nil { + return nil, err + } + + fkeys = append(fkeys, fkey) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return fkeys, nil +} + +// TranslateColumnType converts postgres database types to Go types, for example +// "varchar" to "string" and "bigint" to "int64". It returns this parsed data +// as a Column object. 
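+// For this MySQL driver, for instance, a nullable "varchar" column maps to
+// null.String while a non-nullable "bigint" maps to int64 (see the switch below).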
+func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column {
+	if c.Nullable {
+		switch c.DBType {
+		case "tinyint":
+			c.Type = "null.Int8"
+		case "smallint":
+			c.Type = "null.Int16"
+		case "mediumint", "int", "integer":
+			c.Type = "null.Int"
+		case "bigint":
+			c.Type = "null.Int64"
+		case "float":
+			c.Type = "null.Float32"
+		case "double", "double precision", "real":
+			c.Type = "null.Float64"
+		case "boolean", "bool":
+			c.Type = "null.Bool"
+		case "date", "datetime", "timestamp", "time":
+			c.Type = "null.Time"
+		case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
+			c.Type = "null.Bytes"
+		case "json":
+			c.Type = "types.JSON"
+		default:
+			c.Type = "null.String"
+		}
+	} else {
+		switch c.DBType {
+		case "tinyint":
+			c.Type = "int8"
+		case "smallint":
+			c.Type = "int16"
+		case "mediumint", "int", "integer":
+			c.Type = "int"
+		case "bigint":
+			c.Type = "int64"
+		case "float":
+			c.Type = "float32"
+		case "double", "double precision", "real":
+			c.Type = "float64"
+		case "boolean", "bool":
+			c.Type = "bool"
+		case "date", "datetime", "timestamp", "time":
+			c.Type = "time.Time"
+		case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
+			c.Type = "[]byte"
+		case "json":
+			c.Type = "types.JSON"
+		default:
+			c.Type = "string"
+		}
+	}
+
+	return c
+}
+
+// RightQuote is the quoting character for the right side of the identifier
+func (m *MySQLDriver) RightQuote() byte {
+	return '`'
+}
+
+// LeftQuote is the quoting character for the left side of the identifier
+func (m *MySQLDriver) LeftQuote() byte {
+	return '`'
+}
+
+// IndexPlaceholders returns false to indicate MySQL doesn't support indexed placeholders
+func (m *MySQLDriver) IndexPlaceholders() bool {
+	return false
+}
diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go
index c4b9984ac..0422939bb 100644
--- a/bdb/drivers/postgres.go
+++ b/bdb/drivers/postgres.go
@@ -19,23 +19,20 @@ type PostgresDriver struct {
 	dbConn *sql.DB
 }
 
-// validatedTypes are types that cannot be zero values in the database.
-var validatedTypes = []string{"uuid"}
-
 // NewPostgresDriver takes the database connection details as parameters and
 // returns a pointer to a PostgresDriver object. Note that it is required to
 // call PostgresDriver.Open() and PostgresDriver.Close() to open and close
 // the database connection once an object has been obtained.
 func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
 	driver := PostgresDriver{
-		connStr: BuildQueryString(user, pass, dbname, host, port, sslmode),
+		connStr: PostgresBuildQueryString(user, pass, dbname, host, port, sslmode),
 	}
 
 	return &driver
 }
 
-// BuildQueryString for Postgres
-func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
+// PostgresBuildQueryString builds a query string.
+func PostgresBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
 	parts := []string{}
 	if len(user) != 0 {
 		parts = append(parts, fmt.Sprintf("user=%s", user))
@@ -82,21 +79,25 @@ func (p *PostgresDriver) UseLastInsertID() bool {
 
 // TableNames connects to the postgres database and
 // retrieves all table names from the information_schema where the
-// table schema is public. It excludes common migration tool tables
-// such as gorp_migrations
-func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) {
+// table schema matches the given schema. It honors the whitelist and blacklist.
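+// When the whitelist is non-empty it takes precedence over the blacklist.
+//
+// Example call (the table names here are only illustrative):
+//
+//	names, err := p.TableNames("public", nil, []string{"gorp_migrations"})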
+func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { var names []string - query := `select table_name from information_schema.tables where table_schema = 'public'` - if len(exclude) > 0 { - quoteStr := func(x string) string { - return `'` + x + `'` + query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`) + args := []interface{}{schema} + if len(whitelist) > 0 { + query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(true, len(whitelist), 2, 1)) + for _, w := range whitelist { + args = append(args, w) + } + } else if len(blacklist) > 0 { + query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(true, len(blacklist), 2, 1)) + for _, b := range blacklist { + args = append(args, b) } - exclude = strmangle.StringMap(quoteStr, exclude) - query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ",")) } - rows, err := p.dbConn.Query(query) + rows, err := p.dbConn.Query(query, args...) if err != nil { return nil, err @@ -118,11 +119,11 @@ func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) { // from the database information_schema.columns. It retrieves the column names // and column types and returns those as a []Column after TranslateColumnType() // converts the SQL types to Go types, for example: "varchar" to "string" -func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { +func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) { var columns []bdb.Column rows, err := p.dbConn.Query(` - select column_name, data_type, column_default, is_nullable, + select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable, (select exists( select 1 from information_schema.constraint_column_usage as ccu @@ -136,11 +137,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { inner join pg_index pgi on pgi.indexrelid = pgc.oid inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) where - pgix.schemaname = 'public' and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true + pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true )) as is_unique - from information_schema.columns as c - where table_name=$1 and table_schema = 'public'; - `, tableName) + from information_schema.columns as c LEFT JOIN information_schema.element_types e + ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) + = (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier)) + where c.table_name=$2 and c.table_schema = $1; + `, schema, tableName) if err != nil { return nil, err @@ -148,10 +151,11 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { defer rows.Close() for rows.Next() { - var colName, colType, colDefault, nullable string + var colName, udtName, colType, colDefault, nullable string + var elementType *string var unique bool var defaultPtr *string - if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { + if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil { return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) } @@ -162,12 +166,13 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { } column := bdb.Column{ - Name: colName, 
- DBType: colType, - Default: colDefault, - Nullable: nullable == "YES", - Unique: unique, - Validated: isValidated(colType), + Name: colName, + DBType: colType, + ArrType: elementType, + UDTName: udtName, + Default: colDefault, + Nullable: nullable == "YES", + Unique: unique, } columns = append(columns, column) } @@ -176,16 +181,16 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { } // PrimaryKeyInfo looks up the primary key for a table. -func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { +func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { pkey := &bdb.PrimaryKey{} var err error query := ` select tc.constraint_name from information_schema.table_constraints as tc - where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = 'public';` + where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;` - row := p.dbConn.QueryRow(query, tableName) + row := p.dbConn.QueryRow(query, tableName, schema) if err = row.Scan(&pkey.Name); err != nil { if err == sql.ErrNoRows { return nil, nil @@ -196,10 +201,10 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro queryColumns := ` select kcu.column_name from information_schema.key_column_usage as kcu - where constraint_name = $1 and table_schema = 'public';` + where constraint_name = $1 and table_schema = $2;` var rows *sql.Rows - if rows, err = p.dbConn.Query(queryColumns, pkey.Name); err != nil { + if rows, err = p.dbConn.Query(queryColumns, pkey.Name, schema); err != nil { return nil, err } defer rows.Close() @@ -226,7 +231,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro } // ForeignKeyInfo retrieves the foreign keys for a given table name. 
-func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { +func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { var fkeys []bdb.ForeignKey query := ` @@ -239,11 +244,11 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, err from information_schema.table_constraints as tc inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name - where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = 'public';` + where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = $2;` var rows *sql.Rows var err error - if rows, err = p.dbConn.Query(query, tableName); err != nil { + if rows, err = p.dbConn.Query(query, tableName, schema); err != nil { return nil, err } @@ -279,18 +284,35 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Int" case "smallint", "smallserial": c.Type = "null.Int16" - case "decimal", "numeric", "double precision", "money": + case "decimal", "numeric", "double precision": c.Type = "null.Float64" case "real": c.Type = "null.Float32" - case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "null.String" case "bytea": - c.Type = "[]byte" + c.Type = "null.Bytes" + case "json", "jsonb": + c.Type = "null.JSON" case "boolean": c.Type = "null.Bool" case "date", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "null.Time" + case "ARRAY": + if c.ArrType == nil { + panic("unable to get postgres ARRAY underlying type") + } + c.Type = getArrayType(c) + // Make DBType something like ARRAYinteger for parsing with randomize.Struct + c.DBType = c.DBType + *c.ArrType + case "USER-DEFINED": + if c.UDTName == "hstore" { + c.Type = "types.HStore" + c.DBType = "hstore" + } else { + c.Type = "string" + fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName) + } default: c.Type = "null.String" } @@ -302,18 +324,32 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "int" case "smallint", "smallserial": c.Type = "int16" - case "decimal", "numeric", "double precision", "money": + case "decimal", "numeric", "double precision": c.Type = "float64" case "real": c.Type = "float32" - case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "string" + case "json", "jsonb": + c.Type = "types.JSON" case "bytea": c.Type = "[]byte" case "boolean": c.Type = "bool" case "date", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "time.Time" + case "ARRAY": + c.Type = getArrayType(c) + // Make DBType something like ARRAYinteger for parsing with randomize.Struct + c.DBType = c.DBType + *c.ArrType + case "USER-DEFINED": + if c.UDTName == "hstore" { + c.Type = "types.HStore" + c.DBType = "hstore" + } else { + c.Type = "string" + fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName) + } default: c.Type = "string" } @@ -322,13 +358,35 @@ 
func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { return c } -// isValidated checks if the database type is in the validatedTypes list. -func isValidated(typ string) bool { - for _, v := range validatedTypes { - if v == typ { - return true - } +// getArrayType returns the correct boil.Array type for each database type +func getArrayType(c bdb.Column) string { + switch *c.ArrType { + case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial": + return "types.Int64Array" + case "bytea": + return "types.BytesArray" + case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": + return "types.StringArray" + case "boolean": + return "types.BoolArray" + case "decimal", "numeric", "double precision", "real": + return "types.Float64Array" + default: + return "types.StringArray" } +} - return false +// RightQuote is the quoting character for the right side of the identifier +func (p *PostgresDriver) RightQuote() byte { + return '"' +} + +// LeftQuote is the quoting character for the left side of the identifier +func (p *PostgresDriver) LeftQuote() byte { + return '"' +} + +// IndexPlaceholders returns true to indicate PSQL supports indexed placeholders +func (p *PostgresDriver) IndexPlaceholders() bool { + return true } diff --git a/bdb/interface.go b/bdb/interface.go index 9e4bfc70f..0cfee7369 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -6,10 +6,10 @@ import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - TableNames(exclude []string) ([]string, error) - Columns(tableName string) ([]Column, error) - PrimaryKeyInfo(tableName string) (*PrimaryKey, error) - ForeignKeyInfo(tableName string) ([]ForeignKey, error) + TableNames(schema string, whitelist, blacklist []string) ([]string, error) + Columns(schema, tableName string) ([]Column, error) + PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) + ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) // TranslateColumnType takes a Database column type and returns a go column type. TranslateColumnType(Column) Column @@ -22,23 +22,32 @@ type Interface interface { Open() error // Close the database connection Close() + + // Dialect helpers, these provide the values that will go into + // a queries.Dialect, so the query builder knows how to support + // your database driver properly. + LeftQuote() byte + RightQuote() byte + IndexPlaceholders() bool } // Tables returns the metadata for all tables, minus the tables -// specified in the exclude slice. -func Tables(db Interface, exclude ...string) ([]Table, error) { +// specified in the blacklist. 
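+// If a whitelist is supplied, only those tables are loaded and the blacklist
+// is ignored. A typical call for a Postgres database looks like the following,
+// where driver is any Interface implementation (the schema name is illustrative):
+//
+//	tables, err := Tables(driver, "public", nil, nil)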
+func Tables(db Interface, schema string, whitelist, blacklist []string) ([]Table, error) { var err error - names, err := db.TableNames(exclude) + names, err := db.TableNames(schema, whitelist, blacklist) if err != nil { return nil, errors.Wrap(err, "unable to get table names") } var tables []Table for _, name := range names { - t := Table{Name: name} + t := Table{ + Name: name, + } - if t.Columns, err = db.Columns(name); err != nil { + if t.Columns, err = db.Columns(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name) } @@ -46,11 +55,11 @@ func Tables(db Interface, exclude ...string) ([]Table, error) { t.Columns[i] = db.TranslateColumnType(c) } - if t.PKey, err = db.PrimaryKeyInfo(name); err != nil { + if t.PKey, err = db.PrimaryKeyInfo(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) } - if t.FKeys, err = db.ForeignKeyInfo(name); err != nil { + if t.FKeys, err = db.ForeignKeyInfo(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) } diff --git a/bdb/interface_test.go b/bdb/interface_test.go index fd3e59f60..48be0886b 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -6,20 +6,23 @@ import ( "github.com/vattle/sqlboiler/strmangle" ) -type mockDriver struct{} +type testMockDriver struct{} -func (m mockDriver) TranslateColumnType(c Column) Column { return c } -func (m mockDriver) UseLastInsertID() bool { return false } -func (m mockDriver) Open() error { return nil } -func (m mockDriver) Close() {} +func (m testMockDriver) TranslateColumnType(c Column) Column { return c } +func (m testMockDriver) UseLastInsertID() bool { return false } +func (m testMockDriver) Open() error { return nil } +func (m testMockDriver) Close() {} -func (m mockDriver) TableNames(exclude []string) ([]string, error) { +func (m testMockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { + if len(whitelist) > 0 { + return whitelist, nil + } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} - return strmangle.SetComplement(tables, exclude), nil + return strmangle.SetComplement(tables, blacklist), nil } // Columns returns a list of mock columns -func (m mockDriver) Columns(tableName string) ([]Column, error) { +func (m testMockDriver) Columns(schema, tableName string) ([]Column, error) { return map[string][]Column{ "pilots": { {Name: "id", Type: "int", DBType: "integer"}, @@ -61,7 +64,7 @@ func (m mockDriver) Columns(tableName string) ([]Column, error) { } // ForeignKeyInfo returns a list of mock foreignkeys -func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) { +func (m testMockDriver) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) { return map[string][]ForeignKey{ "jets": { {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, @@ -81,7 +84,7 @@ func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) { } // PrimaryKeyInfo returns mock primary key info for the passed in table name -func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { +func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) { return map[string]*PrimaryKey{ "pilots": {Name: "pilot_id_pkey", Columns: []string{"id"}}, "airports": {Name: "airport_id_pkey", Columns: []string{"id"}}, @@ -93,10 +96,25 @@ func (m 
mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { }[tableName], nil } +// RightQuote is the quoting character for the right side of the identifier +func (m testMockDriver) RightQuote() byte { + return '"' +} + +// LeftQuote is the quoting character for the left side of the identifier +func (m testMockDriver) LeftQuote() byte { + return '"' +} + +// IndexPlaceholders returns true to indicate fake support of indexed placeholders +func (m testMockDriver) IndexPlaceholders() bool { + return false +} + func TestTables(t *testing.T) { t.Parallel() - tables, err := Tables(mockDriver{}) + tables, err := Tables(testMockDriver{}, "public", nil, nil) if err != nil { t.Error(err) } diff --git a/bdb/keys.go b/bdb/keys.go index 35a8a84df..1ba1a2655 100644 --- a/bdb/keys.go +++ b/bdb/keys.go @@ -3,6 +3,7 @@ package bdb import ( "fmt" "regexp" + "strings" ) var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`) @@ -79,3 +80,30 @@ func SQLColDefinitions(cols []Column, names []string) SQLColumnDefs { return ret } + +// AutoIncPrimaryKey returns the auto-increment primary key column name or an +// empty string. Primary key columns with default values are presumed +// to be auto-increment, because pkeys need to be unique and a static +// default value would cause collisions. +func AutoIncPrimaryKey(cols []Column, pkey *PrimaryKey) *Column { + if pkey == nil { + return nil + } + + for _, pkeyColumn := range pkey.Columns { + for _, c := range cols { + if c.Name != pkeyColumn { + continue + } + + if c.Default != "auto_increment" || c.Nullable || + !(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) { + continue + } + + return &c + } + } + + return nil +} diff --git a/bdb/table.go b/bdb/table.go index 82e872d7e..28fa6248c 100644 --- a/bdb/table.go +++ b/bdb/table.go @@ -4,8 +4,11 @@ import "fmt" // Table metadata from the database schema. type Table struct { - Name string - Columns []Column + Name string + // For dbs with real schemas, like Postgres. 
+ // Example value: "schema_name"."table_name" + SchemaName string + Columns []Column PKey *PrimaryKey FKeys []ForeignKey diff --git a/config.go b/config.go index 4af2518c6..da2813448 100644 --- a/config.go +++ b/config.go @@ -3,10 +3,12 @@ package main // Config for the running of the commands type Config struct { DriverName string + Schema string PkgName string OutFolder string BaseDir string - ExcludeTables []string + WhitelistTables []string + BlacklistTables []string Tags []string Debug bool NoTests bool @@ -14,6 +16,7 @@ type Config struct { NoAutoTimestamps bool Postgres PostgresConfig + MySQL MySQLConfig } // PostgresConfig configures a postgres database @@ -25,3 +28,13 @@ type PostgresConfig struct { DBName string SSLMode string } + +// MySQLConfig configures a mysql database +type MySQLConfig struct { + User string + Pass string + Host string + Port int + DBName string + SSLMode string +} diff --git a/imports.go b/imports.go index b753edaad..bf9693055 100644 --- a/imports.go +++ b/imports.go @@ -153,7 +153,8 @@ var defaultTemplateImports = imports{ thirdParty: importList{ `"github.com/pkg/errors"`, `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/qm"`, + `"github.com/vattle/sqlboiler/queries"`, + `"github.com/vattle/sqlboiler/queries/qm"`, `"github.com/vattle/sqlboiler/strmangle"`, }, } @@ -162,7 +163,8 @@ var defaultSingletonTemplateImports = map[string]imports{ "boil_queries": { thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/qm"`, + `"github.com/vattle/sqlboiler/queries"`, + `"github.com/vattle/sqlboiler/queries/qm"`, }, }, "boil_types": { @@ -180,29 +182,38 @@ var defaultTestTemplateImports = imports{ }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/randomize"`, + `"github.com/vattle/sqlboiler/randomize"`, `"github.com/vattle/sqlboiler/strmangle"`, }, } var defaultSingletonTestTemplateImports = map[string]imports{ - "boil_viper_test": { + "boil_main_test": { standard: importList{ `"database/sql"`, + `"flag"`, + `"fmt"`, + `"math/rand"`, `"os"`, `"path/filepath"`, + `"testing"`, + `"time"`, }, thirdParty: importList{ + `"github.com/kat-co/vala"`, + `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, + `"github.com/vattle/sqlboiler/boil"`, }, }, "boil_queries_test": { standard: importList{ - `"crypto/md5"`, + `"bytes"`, `"fmt"`, - `"os"`, - `"strconv"`, + `"io"`, + `"io/ioutil"`, `"math/rand"`, + `"regexp"`, }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, @@ -218,25 +229,40 @@ var defaultSingletonTestTemplateImports = map[string]imports{ var defaultTestMainImports = map[string]imports{ "postgres": { standard: importList{ - `"testing"`, - `"os"`, - `"os/exec"`, - `"flag"`, + `"bytes"`, + `"database/sql"`, `"fmt"`, + `"io"`, `"io/ioutil"`, + `"os"`, + `"os/exec"`, + `"strings"`, + }, + thirdParty: importList{ + `"github.com/pkg/errors"`, + `"github.com/spf13/viper"`, + `"github.com/vattle/sqlboiler/bdb/drivers"`, + `"github.com/vattle/sqlboiler/randomize"`, + `_ "github.com/lib/pq"`, + }, + }, + "mysql": { + standard: importList{ `"bytes"`, `"database/sql"`, - `"path/filepath"`, - `"time"`, - `"math/rand"`, + `"fmt"`, + `"io"`, + `"io/ioutil"`, + `"os"`, + `"os/exec"`, + `"strings"`, }, thirdParty: importList{ - `"github.com/kat-co/vala"`, `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, - `"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, - `_ "github.com/lib/pq"`, + 
`"github.com/vattle/sqlboiler/randomize"`, + `_ "github.com/go-sql-driver/mysql"`, }, }, } @@ -246,51 +272,75 @@ var defaultTestMainImports = map[string]imports{ // TranslateColumnType to see the type assignments. var importsBasedOnType = map[string]imports{ "null.Float32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Float64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.String": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Bool": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Time": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + }, + "null.JSON": { + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + }, + "null.Bytes": { + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "time.Time": { standard: importList{`"time"`}, }, + "types.JSON": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, + "types.BytesArray": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, + "types.Int64Array": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, + "types.Float64Array": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, + "types.BoolArray": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, + "types.Hstore": { + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, + }, } diff --git a/imports_test.go b/imports_test.go index 79863fa2c..b07fca79b 100644 --- a/imports_test.go +++ b/imports_test.go @@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v4"`, + `"gopkg.in/nullbio/null.v5"`, }, } @@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v4"`, + `"gopkg.in/nullbio/null.v5"`, }, } @@ -124,7 +124,7 @@ func TestCombineImports(t 
*testing.T) { a := imports{ standard: importList{"fmt"}, - thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v4"}, + thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"}, } b := imports{ standard: importList{"os"}, @@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) { if c.standard[0] != "fmt" && c.standard[1] != "os" { t.Errorf("Wanted: fmt, os got: %#v", c.standard) } - if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" { - t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty) + if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" { + t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty) } } diff --git a/main.go b/main.go index db3043b6e..2fd63ffba 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( - "errors" "fmt" "os" "path/filepath" @@ -61,9 +60,11 @@ func main() { // Set up the cobra root command flags rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") + rootCmd.PersistentFlags().StringP("schema", "s", "public", "The name of your database schema, for databases that support real schemas") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") - rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders") - rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") + rootCmd.PersistentFlags().StringP("basedir", "", "", "The base directory has the templates and templates_test folders") + rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package") + rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package") rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files") @@ -72,6 +73,9 @@ func main() { viper.SetDefault("postgres.sslmode", "require") viper.SetDefault("postgres.port", "5432") + viper.SetDefault("mysql.sslmode", "true") + viper.SetDefault("mysql.port", "3306") + viper.BindPFlags(rootCmd.PersistentFlags()) viper.AutomaticEnv() @@ -79,7 +83,7 @@ func main() { if e, ok := err.(commandFailure); ok { fmt.Printf("Error: %v\n\n", string(e)) rootCmd.Help() - } else if !cmdConfig.Debug { + } else if !viper.GetBool("debug") { fmt.Printf("Error: %v\n", err) } else { fmt.Printf("Error: %+v\n", err) @@ -107,6 +111,7 @@ func preRun(cmd *cobra.Command, args []string) error { cmdConfig = &Config{ DriverName: driverName, OutFolder: viper.GetString("output"), + Schema: viper.GetString("schema"), PkgName: viper.GetString("pkgname"), Debug: viper.GetBool("debug"), NoTests: viper.GetBool("no-tests"), @@ -115,12 +120,20 @@ func preRun(cmd *cobra.Command, args []string) error { } // BUG: https://github.com/spf13/viper/issues/200 - // Look up the value of ExcludeTables & Tags directly from PFlags in Cobra if we + // Look up the value of blacklist, whitelist & tags directly from PFlags in Cobra if we // detect a malformed value coming out of viper. 
// Once the bug is fixed we'll be able to move this into the init above - cmdConfig.ExcludeTables = viper.GetStringSlice("exclude") - if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") { - cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude") + cmdConfig.BlacklistTables = viper.GetStringSlice("blacklist") + if len(cmdConfig.BlacklistTables) == 1 && strings.HasPrefix(cmdConfig.BlacklistTables[0], "[") { + cmdConfig.BlacklistTables, err = cmd.PersistentFlags().GetStringSlice("blacklist") + if err != nil { + return err + } + } + + cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist") + if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") { + cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist") if err != nil { return err } @@ -134,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error { } } - if viper.IsSet("postgres.dbname") { + if driverName == "postgres" { cmdConfig.Postgres = PostgresConfig{ User: viper.GetString("postgres.user"), Pass: viper.GetString("postgres.pass"), @@ -144,10 +157,17 @@ func preRun(cmd *cobra.Command, args []string) error { SSLMode: viper.GetString("postgres.sslmode"), } - // Set the default SSLMode value + // BUG: https://github.com/spf13/viper/issues/71 + // Despite setting defaults, nested values don't get defaults + // Set them manually if cmdConfig.Postgres.SSLMode == "" { - viper.Set("postgres.sslmode", "require") - cmdConfig.Postgres.SSLMode = viper.GetString("postgres.sslmode") + cmdConfig.Postgres.SSLMode = "require" + viper.Set("postgres.sslmode", cmdConfig.Postgres.SSLMode) + } + + if cmdConfig.Postgres.Port == 0 { + cmdConfig.Postgres.Port = 5432 + viper.Set("postgres.port", cmdConfig.Postgres.Port) } err = vala.BeginValidation().Validate( @@ -161,8 +181,45 @@ func preRun(cmd *cobra.Command, args []string) error { if err != nil { return commandFailure(err.Error()) } - } else if driverName == "postgres" { - return errors.New("postgres driver requires a postgres section in your config file") + } + + if driverName == "mysql" { + cmdConfig.MySQL = MySQLConfig{ + User: viper.GetString("mysql.user"), + Pass: viper.GetString("mysql.pass"), + Host: viper.GetString("mysql.host"), + Port: viper.GetInt("mysql.port"), + DBName: viper.GetString("mysql.dbname"), + SSLMode: viper.GetString("mysql.sslmode"), + } + + // MySQL doesn't have schemas, just databases + cmdConfig.Schema = cmdConfig.MySQL.DBName + + // BUG: https://github.com/spf13/viper/issues/71 + // Despite setting defaults, nested values don't get defaults + // Set them manually + if cmdConfig.MySQL.SSLMode == "" { + cmdConfig.MySQL.SSLMode = "true" + viper.Set("mysql.sslmode", cmdConfig.MySQL.SSLMode) + } + + if cmdConfig.MySQL.Port == 0 { + cmdConfig.MySQL.Port = 3306 + viper.Set("mysql.port", cmdConfig.MySQL.Port) + } + + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(cmdConfig.MySQL.User, "mysql.user"), + vala.StringNotEmpty(cmdConfig.MySQL.Host, "mysql.host"), + vala.Not(vala.Equals(cmdConfig.MySQL.Port, 0, "mysql.port")), + vala.StringNotEmpty(cmdConfig.MySQL.DBName, "mysql.dbname"), + vala.StringNotEmpty(cmdConfig.MySQL.SSLMode, "mysql.sslmode"), + ).Check() + + if err != nil { + return commandFailure(err.Error()) + } } cmdState, err = New(cmdConfig) diff --git a/boil/_fixtures/00.sql b/queries/_fixtures/00.sql similarity index 100% rename from boil/_fixtures/00.sql rename to queries/_fixtures/00.sql diff --git a/boil/_fixtures/01.sql 
b/queries/_fixtures/01.sql similarity index 100% rename from boil/_fixtures/01.sql rename to queries/_fixtures/01.sql diff --git a/boil/_fixtures/02.sql b/queries/_fixtures/02.sql similarity index 100% rename from boil/_fixtures/02.sql rename to queries/_fixtures/02.sql diff --git a/boil/_fixtures/03.sql b/queries/_fixtures/03.sql similarity index 100% rename from boil/_fixtures/03.sql rename to queries/_fixtures/03.sql diff --git a/boil/_fixtures/04.sql b/queries/_fixtures/04.sql similarity index 100% rename from boil/_fixtures/04.sql rename to queries/_fixtures/04.sql diff --git a/boil/_fixtures/05.sql b/queries/_fixtures/05.sql similarity index 100% rename from boil/_fixtures/05.sql rename to queries/_fixtures/05.sql diff --git a/boil/_fixtures/06.sql b/queries/_fixtures/06.sql similarity index 100% rename from boil/_fixtures/06.sql rename to queries/_fixtures/06.sql diff --git a/boil/_fixtures/07.sql b/queries/_fixtures/07.sql similarity index 100% rename from boil/_fixtures/07.sql rename to queries/_fixtures/07.sql diff --git a/boil/_fixtures/08.sql b/queries/_fixtures/08.sql similarity index 100% rename from boil/_fixtures/08.sql rename to queries/_fixtures/08.sql diff --git a/boil/_fixtures/09.sql b/queries/_fixtures/09.sql similarity index 100% rename from boil/_fixtures/09.sql rename to queries/_fixtures/09.sql diff --git a/boil/_fixtures/10.sql b/queries/_fixtures/10.sql similarity index 100% rename from boil/_fixtures/10.sql rename to queries/_fixtures/10.sql diff --git a/boil/_fixtures/11.sql b/queries/_fixtures/11.sql similarity index 100% rename from boil/_fixtures/11.sql rename to queries/_fixtures/11.sql diff --git a/boil/_fixtures/12.sql b/queries/_fixtures/12.sql similarity index 100% rename from boil/_fixtures/12.sql rename to queries/_fixtures/12.sql diff --git a/boil/_fixtures/13.sql b/queries/_fixtures/13.sql similarity index 100% rename from boil/_fixtures/13.sql rename to queries/_fixtures/13.sql diff --git a/boil/_fixtures/14.sql b/queries/_fixtures/14.sql similarity index 100% rename from boil/_fixtures/14.sql rename to queries/_fixtures/14.sql diff --git a/boil/_fixtures/15.sql b/queries/_fixtures/15.sql similarity index 100% rename from boil/_fixtures/15.sql rename to queries/_fixtures/15.sql diff --git a/boil/eager_load.go b/queries/eager_load.go similarity index 98% rename from boil/eager_load.go rename to queries/eager_load.go index e6eafd80a..5d4c9b1d4 100644 --- a/boil/eager_load.go +++ b/queries/eager_load.go @@ -1,15 +1,16 @@ -package boil +package queries import ( "database/sql" "reflect" "github.com/pkg/errors" + "github.com/vattle/sqlboiler/boil" "github.com/vattle/sqlboiler/strmangle" ) type loadRelationshipState struct { - exec Executor + exec boil.Executor loaded map[string]struct{} toLoad []string } diff --git a/boil/eager_load_test.go b/queries/eager_load_test.go similarity index 92% rename from boil/eager_load_test.go rename to queries/eager_load_test.go index 86ada1aa2..282bff019 100644 --- a/boil/eager_load_test.go +++ b/queries/eager_load_test.go @@ -1,6 +1,10 @@ -package boil +package queries -import "testing" +import ( + "testing" + + "github.com/vattle/sqlboiler/boil" +) var loadFunctionCalled bool var loadFunctionNestedCalled int @@ -32,12 +36,12 @@ type testNestedRSlice struct { type testNestedLSlice struct { } -func (testLStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error { +func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error { loadFunctionCalled = true return nil } -func 
(testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { +func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error { switch x := obj.(type) { case *testNestedStruct: x.R = &testNestedRStruct{ @@ -54,7 +58,7 @@ func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj inter return nil } -func (testNestedLSlice) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { +func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error { switch x := obj.(type) { case *testNestedSlice: diff --git a/boil/helpers.go b/queries/helpers.go similarity index 97% rename from boil/helpers.go rename to queries/helpers.go index 43e3ff2ce..59ad8a3ff 100644 --- a/boil/helpers.go +++ b/queries/helpers.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "fmt" diff --git a/boil/helpers_test.go b/queries/helpers_test.go similarity index 96% rename from boil/helpers_test.go rename to queries/helpers_test.go index 756834547..c87bc7602 100644 --- a/boil/helpers_test.go +++ b/queries/helpers_test.go @@ -1,11 +1,11 @@ -package boil +package queries import ( "reflect" "testing" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" ) type testObj struct { diff --git a/boil/qm/query_mods.go b/queries/qm/query_mods.go similarity index 66% rename from boil/qm/query_mods.go rename to queries/qm/query_mods.go index 5cbf3c63d..b2e7e14f6 100644 --- a/boil/qm/query_mods.go +++ b/queries/qm/query_mods.go @@ -1,12 +1,12 @@ package qm -import "github.com/vattle/sqlboiler/boil" +import "github.com/vattle/sqlboiler/queries" // QueryMod to modify the query object -type QueryMod func(q *boil.Query) +type QueryMod func(q *queries.Query) // Apply the query mods to the Query object -func Apply(q *boil.Query, mods ...QueryMod) { +func Apply(q *queries.Query, mods ...QueryMod) { for _, mod := range mods { mod(q) } @@ -14,8 +14,8 @@ func Apply(q *boil.Query, mods ...QueryMod) { // SQL allows you to execute a plain SQL statement func SQL(sql string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.SetSQL(q, sql, args...) + return func(q *queries.Query) { + queries.SetSQL(q, sql, args...) } } @@ -25,29 +25,29 @@ func SQL(sql string, args ...interface{}) QueryMod { // Relationship name plurality is important, if your relationship is // singular, you need to specify the singular form and vice versa. func Load(relationships ...string) QueryMod { - return func(q *boil.Query) { - boil.SetLoad(q, relationships...) + return func(q *queries.Query) { + queries.SetLoad(q, relationships...) } } // InnerJoin on another table func InnerJoin(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendInnerJoin(q, clause, args...) + return func(q *queries.Query) { + queries.AppendInnerJoin(q, clause, args...) } } // Select specific columns opposed to all columns func Select(columns ...string) QueryMod { - return func(q *boil.Query) { - boil.AppendSelect(q, columns...) + return func(q *queries.Query) { + queries.AppendSelect(q, columns...) } } // Where allows you to specify a where clause for your statement func Where(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) 
} } @@ -55,24 +55,24 @@ func Where(clause string, args ...interface{}) QueryMod { // And is a duplicate of the Where function, but allows for more natural looking // query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?"))) func And(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) } } // Or allows you to specify a where clause separated by an OR for your statement func Or(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) - boil.SetLastWhereAsOr(q) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) + queries.SetLastWhereAsOr(q) } } // WhereIn allows you to specify a "x IN (set)" clause for your where statement // Example clauses: "column in ?", "(column1,column2) in ?" func WhereIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) } } @@ -81,65 +81,65 @@ func WhereIn(clause string, args ...interface{}) QueryMod { // allows for more natural looking query mod chains, for example: // (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?")) func AndIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) } } // OrIn allows you to specify an IN clause separated by // an OR for your where statement func OrIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) - boil.SetLastInAsOr(q) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) + queries.SetLastInAsOr(q) } } // GroupBy allows you to specify a group by clause for your statement func GroupBy(clause string) QueryMod { - return func(q *boil.Query) { - boil.AppendGroupBy(q, clause) + return func(q *queries.Query) { + queries.AppendGroupBy(q, clause) } } // OrderBy allows you to specify a order by clause for your statement func OrderBy(clause string) QueryMod { - return func(q *boil.Query) { - boil.AppendOrderBy(q, clause) + return func(q *queries.Query) { + queries.AppendOrderBy(q, clause) } } // Having allows you to specify a having clause for your statement func Having(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendHaving(q, clause, args...) + return func(q *queries.Query) { + queries.AppendHaving(q, clause, args...) 
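The query mod helpers above now close over a `*queries.Query` rather than the old `*boil.Query`; each mod is just a closure that mutates the query via the `queries.Append*`/`Set*` helpers. A minimal sketch of how mods compose, assuming the query value comes from generated model code; the column and relationship names are hypothetical:

```go
package example

import (
	"github.com/vattle/sqlboiler/queries"
	"github.com/vattle/sqlboiler/queries/qm"
)

// applyFilters shows how query mods compose: each qm.QueryMod simply mutates
// the *queries.Query it is given. The query itself is assumed to come from
// generated model code; "age" and "Licenses" are hypothetical names.
func applyFilters(q *queries.Query) {
	qm.Apply(q,
		qm.Select("id", "name"),
		qm.Where("age > ?", 30),
		qm.OrderBy("name"),
		qm.Limit(5),
		qm.Load("Licenses"), // recursive eager loading, resolved by generated Load* methods
	)
}
```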
} } // From allows to specify the table for your statement func From(from string) QueryMod { - return func(q *boil.Query) { - boil.AppendFrom(q, from) + return func(q *queries.Query) { + queries.AppendFrom(q, from) } } // Limit the number of returned rows func Limit(limit int) QueryMod { - return func(q *boil.Query) { - boil.SetLimit(q, limit) + return func(q *queries.Query) { + queries.SetLimit(q, limit) } } // Offset into the results func Offset(offset int) QueryMod { - return func(q *boil.Query) { - boil.SetOffset(q, offset) + return func(q *queries.Query) { + queries.SetOffset(q, offset) } } // For inserts a concurrency locking clause at the end of your statement func For(clause string) QueryMod { - return func(q *boil.Query) { - boil.SetFor(q, clause) + return func(q *queries.Query) { + queries.SetFor(q, clause) } } diff --git a/boil/query.go b/queries/query.go similarity index 61% rename from boil/query.go rename to queries/query.go index 2ea37e938..81237a85a 100644 --- a/boil/query.go +++ b/queries/query.go @@ -1,8 +1,10 @@ -package boil +package queries import ( "database/sql" "fmt" + + "github.com/vattle/sqlboiler/boil" ) // joinKind is the type of join @@ -18,8 +20,9 @@ const ( // Query holds the state for the built up query type Query struct { - executor Executor - plainSQL plainSQL + executor boil.Executor + dialect *Dialect + rawSQL rawSQL load []string delete bool update map[string]interface{} @@ -37,6 +40,20 @@ type Query struct { forlock string } +// Dialect holds values that direct the query builder +// how to build compatible queries for each database. +// Each database driver needs to implement functions +// that provide these values. +type Dialect struct { + // The left quote character for SQL identifiers + LQ byte + // The right quote character for SQL identifiers + RQ byte + // Bool flag indicating whether indexed + // placeholders ($1) are used, or ? placeholders. + IndexPlaceholders bool +} + type where struct { clause string orSeparator bool @@ -54,7 +71,7 @@ type having struct { args []interface{} } -type plainSQL struct { +type rawSQL struct { sql string args []interface{} } @@ -65,65 +82,92 @@ type join struct { args []interface{} } -// SQL makes a plainSQL query, usually for use with bind -func SQL(exec Executor, query string, args ...interface{}) *Query { +// Raw makes a raw query, usually for use with bind +func Raw(exec boil.Executor, query string, args ...interface{}) *Query { return &Query{ executor: exec, - plainSQL: plainSQL{ + rawSQL: rawSQL{ sql: query, args: args, }, } } -// SQLG makes a plainSQL query using the global Executor, usually for use with bind -func SQLG(query string, args ...interface{}) *Query { - return SQL(GetDB(), query, args...) +// RawG makes a raw query using the global boil.Executor, usually for use with bind +func RawG(query string, args ...interface{}) *Query { + return Raw(boil.GetDB(), query, args...) } -// ExecQuery executes a query that does not need a row returned -func ExecQuery(q *Query) (sql.Result, error) { +// Exec executes a query that does not need a row returned +func (q *Query) Exec() (sql.Result, error) { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.Exec(qs, args...) 
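With the `plainSQL` field renamed to `rawSQL`, the old `SQL`/`SQLG` constructors become `Raw`/`RawG`, and the `ExecQuery*` package functions become `Exec`, `QueryRow` and `Query` methods on `*Query` (debug output still goes through `boil.DebugMode`/`boil.DebugWriter`). A hedged usage sketch; the table name and the `$1` placeholder style (Postgres) are assumptions:

```go
package example

import (
	"database/sql"

	"github.com/vattle/sqlboiler/queries"
)

// rawQueries sketches the renamed entry points. Raw bypasses the query
// builder entirely and sends the SQL as-is; the "pilots" table is made up.
func rawQueries(db *sql.DB) error {
	q := queries.Raw(db, `SELECT id, name FROM pilots WHERE id = $1`, 5)

	var id int
	var name string
	if err := q.QueryRow().Scan(&id, &name); err != nil {
		return err
	}

	// Exec for statements that return no rows.
	_, err := queries.Raw(db, `DELETE FROM pilots WHERE id = $1`, 5).Exec()
	return err
}
```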
} -// ExecQueryOne executes the query for the One finisher and returns a row -func ExecQueryOne(q *Query) *sql.Row { +// QueryRow executes the query for the One finisher and returns a row +func (q *Query) QueryRow() *sql.Row { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.QueryRow(qs, args...) } -// ExecQueryAll executes the query for the All finisher and returns multiple rows -func ExecQueryAll(q *Query) (*sql.Rows, error) { +// Query executes the query for the All finisher and returns multiple rows +func (q *Query) Query() (*sql.Rows, error) { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.Query(qs, args...) } +// ExecP executes a query that does not need a row returned +// It will panic on error +func (q *Query) ExecP() sql.Result { + res, err := q.Exec() + if err != nil { + panic(boil.WrapErr(err)) + } + + return res +} + +// QueryP executes the query for the All finisher and returns multiple rows +// It will panic on error +func (q *Query) QueryP() *sql.Rows { + rows, err := q.Query() + if err != nil { + panic(boil.WrapErr(err)) + } + + return rows +} + // SetExecutor on the query. -func SetExecutor(q *Query, exec Executor) { +func SetExecutor(q *Query, exec boil.Executor) { q.executor = exec } // GetExecutor on the query. -func GetExecutor(q *Query) Executor { +func GetExecutor(q *Query) boil.Executor { return q.executor } +// SetDialect on the query. +func SetDialect(q *Query, dialect *Dialect) { + q.dialect = dialect +} + // SetSQL on the query. func SetSQL(q *Query, sql string, args ...interface{}) { - q.plainSQL = plainSQL{sql: sql, args: args} + q.rawSQL = rawSQL{sql: sql, args: args} } // SetLoad on the query. @@ -131,6 +175,11 @@ func SetLoad(q *Query, relationships ...string) { q.load = append([]string(nil), relationships...) } +// SetSelect on the query. +func SetSelect(q *Query, sel []string) { + q.selectCols = sel +} + // SetCount on the query. 
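The new `Dialect` struct is how the query builder learns each database's quote characters and placeholder style; drivers report these values and generated code passes them in through `SetDialect`. A small sketch of the two dialects this change targets (MySQL conventionally uses backticks and `?` placeholders):

```go
package example

import "github.com/vattle/sqlboiler/queries"

// Dialect values as this change configures them: Postgres quotes identifiers
// with double quotes and uses indexed $1 placeholders, while MySQL uses
// backticks and plain ? placeholders.
var (
	postgresDialect = queries.Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
	mysqlDialect    = queries.Dialect{LQ: '`', RQ: '`', IndexPlaceholders: false}
)

// configure shows the setter generated code is expected to call before a
// query is built.
func configure(q *queries.Query) {
	queries.SetDialect(q, &postgresDialect)
}
```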
func SetCount(q *Query) { q.count = true diff --git a/boil/query_builders.go b/queries/query_builders.go similarity index 74% rename from boil/query_builders.go rename to queries/query_builders.go index 7ea79952f..8997eaa92 100644 --- a/boil/query_builders.go +++ b/queries/query_builders.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "bytes" @@ -20,8 +20,8 @@ func buildQuery(q *Query) (string, []interface{}) { var args []interface{} switch { - case len(q.plainSQL.sql) != 0: - return q.plainSQL.sql, q.plainSQL.args + case len(q.rawSQL.sql) != 0: + return q.rawSQL.sql, q.rawSQL.args case q.delete: buf, args = buildDeleteQuery(q) case len(q.update) > 0: @@ -34,8 +34,8 @@ func buildQuery(q *Query) (string, []interface{}) { // Cache the generated query for query object re-use bufStr := buf.String() - q.plainSQL.sql = bufStr - q.plainSQL.args = args + q.rawSQL.sql = bufStr + q.rawSQL.args = args return bufStr, args } @@ -57,8 +57,8 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { // Don't identQuoteSlice - writeAsStatements does this buf.WriteString(strings.Join(selectColsWithAs, ", ")) } else if hasSelectCols { - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), ", ")) - } else if hasJoins { + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", ")) + } else if hasJoins && !q.count { selectColsWithStars := writeStars(q) buf.WriteString(strings.Join(selectColsWithStars, ", ")) } else { @@ -70,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteByte(')') } - fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) if len(q.joins) > 0 { argsLen := len(args) @@ -82,7 +82,12 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause) args = append(args, j.args...) 
} - resp, _ := convertQuestionMarks(joinBuf.String(), argsLen+1) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1) + } else { + resp = joinBuf.String() + } fmt.Fprintf(buf, resp) strmangle.PutBuffer(joinBuf) } @@ -110,7 +115,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := strmangle.GetBuffer() buf.WriteString("DELETE FROM ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) where, whereArgs := whereClause(q, 1) if len(whereArgs) != 0 { @@ -135,7 +140,7 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := strmangle.GetBuffer() buf.WriteString("UPDATE ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) cols := make(sort.StringSlice, len(q.update)) var args []interface{} @@ -150,13 +155,13 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { for i := 0; i < len(cols); i++ { args = append(args, q.update[cols[i]]) - cols[i] = strmangle.IdentQuote(cols[i]) + cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i]) } buf.WriteString(fmt.Sprintf( " SET (%s) = (%s)", strings.Join(cols, ", "), - strmangle.Placeholders(len(cols), 1, 1)), + strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)), ) where, whereArgs := whereClause(q, len(args)+1) @@ -178,11 +183,40 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { return buf, args } -// BuildUpsertQuery builds a SQL statement string using the upsertData provided. -func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { - conflict = strmangle.IdentQuoteSlice(conflict) - whitelist = strmangle.IdentQuoteSlice(whitelist) - ret = strmangle.IdentQuoteSlice(ret) +// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided. +func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string { + whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) + + buf := strmangle.GetBuffer() + defer strmangle.PutBuffer(buf) + + fmt.Fprintf( + buf, + "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ", + tableName, + strings.Join(whitelist, ", "), + strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1), + ) + + for i, v := range update { + if i != 0 { + buf.WriteByte(',') + } + quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) + buf.WriteString(quoted) + buf.WriteString(" = VALUES(") + buf.WriteString(quoted) + buf.WriteByte(')') + } + + return buf.String() +} + +// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided. 
+func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { + conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict) + whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) + ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret) buf := strmangle.GetBuffer() defer strmangle.PutBuffer(buf) @@ -192,7 +226,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ", tableName, strings.Join(whitelist, ", "), - strmangle.Placeholders(len(whitelist), 1, 1), + strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1), ) if !updateOnConflict || len(update) == 0 { @@ -206,7 +240,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf if i != 0 { buf.WriteByte(',') } - quoted := strmangle.IdentQuote(v) + quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) buf.WriteString(quoted) buf.WriteString(" = EXCLUDED.") buf.WriteString(quoted) @@ -237,7 +271,12 @@ func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) { fmt.Fprintf(havingBuf, j.clause) *args = append(*args, j.args...) } - resp, _ := convertQuestionMarks(havingBuf.String(), argsLen+1) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1) + } else { + resp = havingBuf.String() + } fmt.Fprintf(buf, resp) strmangle.PutBuffer(havingBuf) } @@ -264,7 +303,7 @@ func writeStars(q *Query) []string { for i, f := range q.from { toks := strings.Split(f, " ") if len(toks) == 1 { - cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0])) + cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0])) continue } @@ -276,7 +315,7 @@ func writeStars(q *Query) []string { if len(alias) != 0 { name = alias } - cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name)) + cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name)) } return cols @@ -292,7 +331,7 @@ func writeAsStatements(q *Query) []string { toks := strings.Split(col, ".") if len(toks) == 1 { - cols[i] = strmangle.IdentQuote(col) + cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col) continue } @@ -301,7 +340,7 @@ func writeAsStatements(q *Query) []string { asParts[j] = strings.Trim(tok, `"`) } - cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, ".")) + cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, ".")) } return cols @@ -335,7 +374,13 @@ func whereClause(q *Query, startAt int) (string, []interface{}) { args = append(args, where.args...) } - resp, _ := convertQuestionMarks(buf.String(), startAt) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(buf.String(), startAt) + } else { + resp = buf.String() + } + return resp, args } @@ -374,7 +419,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // column name side, however if this case is being hit then the regexp // probably needs adjustment, or the user is passing in invalid clauses. 
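The old single `BuildUpsertQuery` is split into per-database builders, each taking the `Dialect` so quoting and placeholders match the target engine (note the table name is passed through unquoted, so callers are expected to hand in an already-quoted name). A sketch of both, with hypothetical table and column names; the commented SQL is only the approximate shape of the output:

```go
package example

import (
	"fmt"

	"github.com/vattle/sqlboiler/queries"
)

// upsertSketch exercises the split upsert builders with both dialects.
func upsertSketch() {
	pg := queries.Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}
	my := queries.Dialect{LQ: '`', RQ: '`', IndexPlaceholders: false}

	cols := []string{"id", "name"}

	// Roughly: INSERT INTO "pilots" ("id", "name") VALUES ($1,$2)
	//          ON CONFLICT ("id") DO UPDATE SET "name" = EXCLUDED."name"
	fmt.Println(queries.BuildUpsertQueryPostgres(
		pg, `"pilots"`, true, nil, []string{"name"}, []string{"id"}, cols,
	))

	// Roughly: INSERT INTO pilots (`id`, `name`) VALUES (?,?)
	//          ON DUPLICATE KEY UPDATE `name` = VALUES(`name`)
	fmt.Println(queries.BuildUpsertQueryMySQL(my, "pilots", []string{"name"}, cols))
}
```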
if matches == nil { - clause, count := convertInQuestionMarks(in.clause, startAt, 1, ln) + clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln) buf.WriteString(clause) startAt = startAt + count } else { @@ -384,11 +429,24 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // of the clause to determine how many columns they are using. // This number determines the groupAt for the convert function. cols := strings.Split(leftSide, ",") - cols = strmangle.IdentQuoteSlice(cols) + cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols) groupAt := len(cols) - leftClause, leftCount := convertQuestionMarks(strings.Join(cols, ","), startAt) - rightClause, rightCount := convertInQuestionMarks(rightSide, startAt+leftCount, groupAt, ln-leftCount) + var leftClause string + var leftCount int + if q.dialect.IndexPlaceholders { + leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt) + } else { + // Count the number of cols that are question marks, so we know + // how much to offset convertInQuestionMarks by + for _, v := range cols { + if v == "?" { + leftCount++ + } + } + leftClause = strings.Join(cols, ",") + } + rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount) buf.WriteString(leftClause) buf.WriteString(" IN ") buf.WriteString(rightClause) @@ -406,7 +464,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // It uses groupAt to determine how many placeholders should be in each group, // for example, groupAt 2 would result in: (($1,$2),($3,$4)) // and groupAt 1 would result in ($1,$2,$3,$4) -func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, int) { +func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) { if startAt == 0 || len(clause) == 0 { panic("Not a valid start number.") } @@ -428,7 +486,7 @@ func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, paramBuf.WriteString(clause[:foundAt]) paramBuf.WriteByte('(') - paramBuf.WriteString(strmangle.Placeholders(total, startAt, groupAt)) + paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt)) paramBuf.WriteByte(')') paramBuf.WriteString(clause[foundAt+1:]) diff --git a/boil/query_builders_test.go b/queries/query_builders_test.go similarity index 95% rename from boil/query_builders_test.go rename to queries/query_builders_test.go index b5035d87c..9af45da09 100644 --- a/boil/query_builders_test.go +++ b/queries/query_builders_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "bytes" @@ -97,6 +97,7 @@ func TestBuildQuery(t *testing.T) { for i, test := range tests { filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i)) + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} out, args := buildQuery(test.q) if *writeGoldenFiles { @@ -149,6 +150,7 @@ func TestWriteStars(t *testing.T) { } for i, test := range tests { + test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} selects := writeStars(&test.In) if !reflect.DeepEqual(selects, test.Out) { t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects) @@ -275,6 +277,7 @@ func TestWhereClause(t *testing.T) { } for i, test := range tests { + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} result, _ := whereClause(&test.q, 1) if result != test.expect { t.Errorf("%d) Mismatch between expect 
and result:\n%s\n%s\n", i, test.expect, result) @@ -407,6 +410,7 @@ func TestInClause(t *testing.T) { } for i, test := range tests { + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} result, args := inClause(&test.q, 1) if result != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) @@ -489,7 +493,7 @@ func TestConvertInQuestionMarks(t *testing.T) { } for i, test := range tests { - res, count := convertInQuestionMarks(test.clause, test.start, test.group, test.total) + res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total) if res != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res) } @@ -497,6 +501,14 @@ func TestConvertInQuestionMarks(t *testing.T) { t.Errorf("%d) Expected %d, got %d", i, test.total, count) } } + + res, count := convertInQuestionMarks(false, "?", 1, 3, 9) + if res != "((?,?,?),(?,?,?),(?,?,?))" { + t.Errorf("Mismatch between expected and result: %s", res) + } + if count != 9 { + t.Errorf("Expected 9 results, got %d", count) + } } func TestWriteAsStatements(t *testing.T) { @@ -512,6 +524,7 @@ func TestWriteAsStatements(t *testing.T) { `a.clown.run`, `COUNT(a)`, }, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } expect := []string{ diff --git a/boil/query_test.go b/queries/query_test.go similarity index 91% rename from boil/query_test.go rename to queries/query_test.go index f61818929..06bccf45e 100644 --- a/boil/query_test.go +++ b/queries/query_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql" @@ -36,12 +36,12 @@ func TestSetSQL(t *testing.T) { q := &Query{} SetSQL(q, "select * from thing", 5, 3) - if len(q.plainSQL.args) != 2 { - t.Errorf("Expected len 2, got %d", len(q.plainSQL.args)) + if len(q.rawSQL.args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.rawSQL.args)) } - if q.plainSQL.sql != "select * from thing" { - t.Errorf("Was not expected string, got %s", q.plainSQL.sql) + if q.rawSQL.sql != "select * from thing" { + t.Errorf("Was not expected string, got %s", q.rawSQL.sql) } } @@ -290,6 +290,17 @@ func TestFrom(t *testing.T) { } } +func TestSetSelect(t *testing.T) { + t.Parallel() + + q := &Query{selectCols: []string{"hello"}} + SetSelect(q, nil) + + if q.selectCols != nil { + t.Errorf("want nil") + } +} + func TestSetCount(t *testing.T) { t.Parallel() @@ -362,24 +373,24 @@ func TestAppendSelect(t *testing.T) { func TestSQL(t *testing.T) { t.Parallel() - q := SQL(&sql.DB{}, "thing", 5) - if q.plainSQL.sql != "thing" { - t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) + q := Raw(&sql.DB{}, "thing", 5) + if q.rawSQL.sql != "thing" { + t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } - if q.plainSQL.args[0].(int) != 5 { - t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) + if q.rawSQL.args[0].(int) != 5 { + t.Errorf("Expected 5, got %v", q.rawSQL.args[0]) } } func TestSQLG(t *testing.T) { t.Parallel() - q := SQLG("thing", 5) - if q.plainSQL.sql != "thing" { - t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) + q := RawG("thing", 5) + if q.rawSQL.sql != "thing" { + t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } - if q.plainSQL.args[0].(int) != 5 { - t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) + if q.rawSQL.args[0].(int) != 5 { + t.Errorf("Expected 5, got %v", q.rawSQL.args[0]) } } diff --git a/boil/reflect.go b/queries/reflect.go similarity index 82% rename from boil/reflect.go rename to queries/reflect.go index 
27376edc4..1a1e6cbc5 100644 --- a/boil/reflect.go +++ b/queries/reflect.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql" @@ -8,6 +8,7 @@ import ( "sync" "github.com/pkg/errors" + "github.com/vattle/sqlboiler/boil" "github.com/vattle/sqlboiler/strmangle" ) @@ -40,7 +41,7 @@ const ( // It panics on error. See boil.Bind() documentation. func (q *Query) BindP(obj interface{}) { if err := q.Bind(obj); err != nil { - panic(WrapErr(err)) + panic(boil.WrapErr(err)) } } @@ -100,7 +101,7 @@ func (q *Query) Bind(obj interface{}) error { return err } - rows, err := ExecQueryAll(q) + rows, err := q.Query() if err != nil { return errors.Wrap(err, "bind failed to execute query") } @@ -322,8 +323,10 @@ func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.V v := (mapping >> uint(i*8)) & sentinel if v == sentinel { - if val.Kind() != reflect.Ptr { + if addressOf && val.Kind() != reflect.Ptr { return val.Addr() + } else if !addressOf && val.Kind() == reflect.Ptr { + return reflect.Indirect(val) } return val } @@ -404,74 +407,3 @@ func makeCacheKey(typ string, cols []string) string { return mapKey } - -// GetStructValues returns the values (as interface) of the matching columns in obj -func GetStructValues(obj interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(columns)) - val := reflect.Indirect(reflect.ValueOf(obj)) - - for i, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i] = field.Interface() - } - - return ret -} - -// GetSliceValues returns the values (as interface) of the matching columns in obj. -func GetSliceValues(slice []interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(slice)*len(columns)) - - for i, obj := range slice { - val := reflect.Indirect(reflect.ValueOf(obj)) - for j, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i*len(columns)+j] = field.Interface() - } - } - - return ret -} - -// GetStructPointers returns a slice of pointers to the matching columns in obj -func GetStructPointers(obj interface{}, columns ...string) []interface{} { - val := reflect.ValueOf(obj).Elem() - - var ln int - var getField func(reflect.Value, int) reflect.Value - - if len(columns) == 0 { - ln = val.NumField() - getField = func(v reflect.Value, i int) reflect.Value { - return v.Field(i) - } - } else { - ln = len(columns) - getField = func(v reflect.Value, i int) reflect.Value { - return v.FieldByName(strmangle.TitleCase(columns[i])) - } - } - - ret := make([]interface{}, ln) - for i := 0; i < ln; i++ { - field := getField(val, i) - - if !field.IsValid() { - // Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually - // produce an Invalid value, so we make a hopefully safe assumption here. 
- panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i]))) - } - - ret[i] = field.Addr().Interface() - } - - return ret -} diff --git a/boil/reflect_test.go b/queries/reflect_test.go similarity index 64% rename from boil/reflect_test.go rename to queries/reflect_test.go index 0a05fd680..8d83ed35b 100644 --- a/boil/reflect_test.go +++ b/queries/reflect_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql/driver" @@ -6,10 +6,8 @@ import ( "strconv" "strings" "testing" - "time" "gopkg.in/DATA-DOG/go-sqlmock.v1" - "gopkg.in/nullbio/null.v4" ) func bin64(i uint64) string { @@ -44,7 +42,8 @@ func TestBindStruct(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -83,7 +82,8 @@ func TestBindSlice(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -133,7 +133,8 @@ func TestBindPtrSlice(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -265,6 +266,76 @@ func TestPtrFromMapping(t *testing.T) { } } +func TestValuesFromMapping(t *testing.T) { + t.Parallel() + + type NestedPtrs struct { + Int int + IntP *int + NestedPtrsP *NestedPtrs + } + + val := &NestedPtrs{ + Int: 5, + IntP: new(int), + NestedPtrsP: &NestedPtrs{ + Int: 6, + IntP: new(int), + }, + } + + mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)} + v := ValuesFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping) + + if got := v[0].(int); got != 5 { + t.Error("flat int was wrong:", got) + } + if got := v[1].(int); got != 0 { + t.Error("flat pointer was wrong:", got) + } + if got := v[2].(int); got != 6 { + t.Error("nested int was wrong:", got) + } + if got := v[3].(int); got != 0 { + t.Error("nested pointer was wrong:", got) + } +} + +func TestPtrsFromMapping(t *testing.T) { + t.Parallel() + + type NestedPtrs struct { + Int int + IntP *int + NestedPtrsP *NestedPtrs + } + + val := &NestedPtrs{ + Int: 5, + IntP: new(int), + NestedPtrsP: &NestedPtrs{ + Int: 6, + IntP: new(int), + }, + } + + mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)} + v := PtrsFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping) + + if got := *v[0].(*int); got != 5 { + t.Error("flat int was wrong:", got) + } + if got := *v[1].(*int); got != 0 { + t.Error("flat pointer was wrong:", got) + } + if got := *v[2].(*int); got != 6 { + t.Error("nested int was wrong:", got) + } + if got := *v[3].(*int); got != 0 { + t.Error("nested pointer was wrong:", got) + } +} + func TestGetBoilTag(t *testing.T) { t.Parallel() @@ -369,7 +440,8 @@ func TestBindSingular(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -412,8 +484,9 @@ func TestBind_InnerJoin(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, - joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}}, + from: []string{"fun"}, + joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } 
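`Bind` now lives on `*queries.Query`, and the previously commented-out inner-join bind test is restored further down. A hedged sketch of binding a joined result set into nested structs via `boil:"...,bind"` tags, in the style that test expects; the tables, aliases and join condition here are hypothetical:

```go
package example

import (
	"database/sql"

	"github.com/vattle/sqlboiler/queries"
)

// joined maps aliased columns such as "p.id" and "j.id" onto nested structs
// using the bind tags. Table and alias names are made up for illustration.
type joined struct {
	Pilot struct {
		ID int
	} `boil:"p,bind"`
	Jet struct {
		ID int
	} `boil:"j,bind"`
}

func bindJoined(db *sql.DB) ([]joined, error) {
	var out []joined
	err := queries.Raw(db,
		`SELECT p.id AS "p.id", j.id AS "j.id"
		 FROM pilots p INNER JOIN jets j ON j.pilot_id = p.id`,
	).Bind(&out)
	return out, err
}
```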
db, mock, err := sqlmock.New() @@ -454,249 +527,59 @@ func TestBind_InnerJoin(t *testing.T) { } } -// func TestBind_InnerJoinSelect(t *testing.T) { -// t.Parallel() -// -// testResults := []*struct { -// Happy struct { -// ID int -// } `boil:"h,bind"` -// Fun struct { -// ID int -// } `boil:",bind"` -// }{} -// -// query := &Query{ -// selectCols: []string{"fun.id", "h.id"}, -// from: []string{"fun"}, -// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}}, -// } -// -// db, mock, err := sqlmock.New() -// if err != nil { -// t.Error(err) -// } -// -// ret := sqlmock.NewRows([]string{"fun.id", "h.id"}) -// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11))) -// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13))) -// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret) -// -// SetExecutor(query, db) -// err = query.Bind(&testResults) -// if err != nil { -// t.Error(err) -// } -// -// if len(testResults) != 2 { -// t.Fatal("wrong number of results:", len(testResults)) -// } -// if id := testResults[0].Happy.ID; id != 11 { -// t.Error("wrong ID:", id) -// } -// if id := testResults[0].Fun.ID; id != 10 { -// t.Error("wrong ID:", id) -// } -// -// if id := testResults[1].Happy.ID; id != 13 { -// t.Error("wrong ID:", id) -// } -// if id := testResults[1].Fun.ID; id != 12 { -// t.Error("wrong ID:", id) -// } -// -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Error(err) -// } -// } - -// func TestBindPtrs_Easy(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// ID int `boil:"identifier"` -// Date time.Time -// }{} -// -// cols := []string{"identifier", "date"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*time.Time) != &testStruct.Date { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_Recursive(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// Happy struct { -// ID int `boil:"identifier"` -// } -// Fun struct { -// ID int -// } `boil:",bind"` -// }{} -// -// cols := []string{"id", "fun.id"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_RecursiveTags(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// Happy struct { -// ID int `boil:"identifier"` -// } `boil:",bind"` -// Fun struct { -// ID int `boil:"identification"` -// } `boil:",bind"` -// }{} -// -// cols := []string{"happy.identifier", "fun.identification"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Happy.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_Ignore(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// ID int `boil:"-"` -// Happy struct { -// ID int -// } `boil:",bind"` -// }{} -// -// cols := []string{"id"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) 
-// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Happy.ID { -// t.Error("id is the wrong pointer") -// } -// } - -func TestGetStructValues(t *testing.T) { +func TestBind_InnerJoinSelect(t *testing.T) { t.Parallel() - timeThing := time.Now() - o := struct { - TitleThing string - Name string - ID int - Stuff int - Things int - Time time.Time - NullBool null.Bool - }{ - TitleThing: "patrick", - Stuff: 10, - Things: 0, - Time: timeThing, - NullBool: null.NewBool(true, false), - } + testResults := []*struct { + Happy struct { + ID int + } `boil:"h,bind"` + Fun struct { + ID int + } `boil:",bind"` + }{} - vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") - if vals[0].(string) != "patrick" { - t.Errorf("Want test, got %s", vals[0]) - } - if vals[1].(string) != "" { - t.Errorf("Want empty string, got %s", vals[1]) - } - if vals[2].(int) != 0 { - t.Errorf("Want 0, got %d", vals[2]) - } - if vals[3].(int) != 10 { - t.Errorf("Want 10, got %d", vals[3]) - } - if vals[4].(int) != 0 { - t.Errorf("Want 0, got %d", vals[4]) - } - if !vals[5].(time.Time).Equal(timeThing) { - t.Errorf("Want %s, got %s", o.Time, vals[5]) - } - if !vals[6].(null.Bool).IsZero() { - t.Errorf("Want %v, got %v", o.NullBool, vals[6]) + query := &Query{ + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, + selectCols: []string{"fun.id", "h.id"}, + from: []string{"fun"}, + joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}}, } -} -func TestGetSliceValues(t *testing.T) { - t.Parallel() - - o := []struct { - ID int - Name string - }{ - {5, "a"}, - {6, "b"}, + db, mock, err := sqlmock.New() + if err != nil { + t.Error(err) } - in := make([]interface{}, len(o)) - in[0] = o[0] - in[1] = o[1] + ret := sqlmock.NewRows([]string{"fun.id", "h.id"}) + ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11))) + ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13))) + mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret) - vals := GetSliceValues(in, "id", "name") - if got := vals[0].(int); got != 5 { - t.Error(got) + SetExecutor(query, db) + err = query.Bind(&testResults) + if err != nil { + t.Error(err) } - if got := vals[1].(string); got != "a" { - t.Error(got) + + if len(testResults) != 2 { + t.Fatal("wrong number of results:", len(testResults)) } - if got := vals[2].(int); got != 6 { - t.Error(got) + if id := testResults[0].Happy.ID; id != 11 { + t.Error("wrong ID:", id) } - if got := vals[3].(string); got != "b" { - t.Error(got) + if id := testResults[0].Fun.ID; id != 10 { + t.Error("wrong ID:", id) } -} -func TestGetStructPointers(t *testing.T) { - t.Parallel() - - o := struct { - Title string - ID *int - }{ - Title: "patrick", + if id := testResults[1].Happy.ID; id != 13 { + t.Error("wrong ID:", id) } - - ptrs := GetStructPointers(&o, "title", "id") - *ptrs[0].(*string) = "test" - if o.Title != "test" { - t.Errorf("Expected test, got %s", o.Title) + if id := testResults[1].Fun.ID; id != 12 { + t.Error("wrong ID:", id) } - x := 5 - *ptrs[1].(**int) = &x - if *o.ID != 5 { - t.Errorf("Expected 5, got %d", *o.ID) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) } } diff --git a/randomize/random.go b/randomize/random.go new file mode 100644 index 000000000..27395c6e1 --- /dev/null +++ b/randomize/random.go @@ -0,0 +1,121 @@ +package randomize + +import ( + "crypto/md5" + "fmt" + "math/rand" +) + +const 
alphabetAll = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const alphabetLowerAlpha = "abcdefghijklmnopqrstuvwxyz" + +func randStr(s *Seed, ln int) string { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(alphabetAll[s.nextInt()%len(alphabetAll)]) + } + + return string(str) +} + +func randByteSlice(s *Seed, ln int) []byte { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(s.nextInt() % 256) + } + + return str +} + +func randPoint() string { + a := rand.Intn(100) + b := a + 1 + return fmt.Sprintf("(%d,%d)", a, b) +} + +func randBox() string { + a := rand.Intn(100) + b := a + 1 + c := a + 2 + d := a + 3 + return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d) +} + +func randCircle() string { + a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100) + return fmt.Sprintf("((%d,%d),%d)", a, b, c) +} + +func randNetAddr() string { + return fmt.Sprintf( + "%d.%d.%d.%d", + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + ) +} + +func randMacAddr() string { + buf := make([]byte, 6) + _, err := rand.Read(buf) + if err != nil { + panic(err) + } + + // Set the local bit + buf[0] |= 2 + return fmt.Sprintf( + "%02x:%02x:%02x:%02x:%02x:%02x", + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], + ) +} + +func randLsn() string { + a := rand.Int63n(9000000) + b := rand.Int63n(9000000) + return fmt.Sprintf("%d/%d", a, b) +} + +func randTxID() string { + // Order of integers is relevant + a := rand.Intn(200) + 100 + b := a + 100 + c := a + d := a + 50 + return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d) +} + +func randMoney(s *Seed) string { + return fmt.Sprintf("%d.00", s.nextInt()) +} + +// StableDBName takes a database name in, and generates +// a random string using the database name as the rand Seed. +// getDBNameHash is used to generate unique test database names. +func StableDBName(input string) string { + return randStrFromSource(stableSource(input), 40) +} + +// stableSource takes an input value, and produces a random +// seed from it that will produce very few collisions in +// a 40 character random string made from a different alphabet. 
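`StableDBName` hashes its input into a deterministic seed and draws a 40-character lowercase name from it, so repeated runs against the same database name always produce the same test database name. A tiny usage sketch:

```go
package example

import "github.com/vattle/sqlboiler/randomize"

// testDBName returns a stable, collision-resistant name for a test database.
// Same input, same output on every run.
func testDBName(dbname string) string {
	return randomize.StableDBName(dbname) // 40 characters drawn from a-z
}
```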
+func stableSource(input string) *rand.Rand { + sum := md5.Sum([]byte(input)) + var seed int64 + for i, byt := range sum { + seed ^= int64(byt) << uint((i*4)%64) + } + return rand.New(rand.NewSource(seed)) +} + +func randStrFromSource(r *rand.Rand, length int) string { + ln := len(alphabetLowerAlpha) + + output := make([]rune, length) + for i := 0; i < length; i++ { + output[i] = rune(alphabetLowerAlpha[r.Intn(ln)]) + } + + return string(output) +} diff --git a/randomize/random_test.go b/randomize/random_test.go new file mode 100644 index 000000000..ee910ce61 --- /dev/null +++ b/randomize/random_test.go @@ -0,0 +1,19 @@ +package randomize + +import "testing" + +func TestStableDBName(t *testing.T) { + t.Parallel() + + db := "awesomedb" + + one, two := StableDBName(db), StableDBName(db) + + if len(one) != 40 { + t.Error("want 40 characters:", len(one), one) + } + + if one != two { + t.Error("it should always produce the same value") + } +} diff --git a/boil/randomize/randomize.go b/randomize/randomize.go similarity index 55% rename from boil/randomize/randomize.go rename to randomize/randomize.go index 89e516acb..739198067 100644 --- a/boil/randomize/randomize.go +++ b/randomize/randomize.go @@ -2,41 +2,58 @@ package randomize import ( + "database/sql" + "fmt" "reflect" "regexp" "sort" "strconv" + "strings" "sync/atomic" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" "github.com/pkg/errors" "github.com/satori/go.uuid" "github.com/vattle/sqlboiler/strmangle" + "github.com/vattle/sqlboiler/types" ) var ( - typeNullFloat32 = reflect.TypeOf(null.Float32{}) - typeNullFloat64 = reflect.TypeOf(null.Float64{}) - typeNullInt = reflect.TypeOf(null.Int{}) - typeNullInt8 = reflect.TypeOf(null.Int8{}) - typeNullInt16 = reflect.TypeOf(null.Int16{}) - typeNullInt32 = reflect.TypeOf(null.Int32{}) - typeNullInt64 = reflect.TypeOf(null.Int64{}) - typeNullUint = reflect.TypeOf(null.Uint{}) - typeNullUint8 = reflect.TypeOf(null.Uint8{}) - typeNullUint16 = reflect.TypeOf(null.Uint16{}) - typeNullUint32 = reflect.TypeOf(null.Uint32{}) - typeNullUint64 = reflect.TypeOf(null.Uint64{}) - typeNullString = reflect.TypeOf(null.String{}) - typeNullBool = reflect.TypeOf(null.Bool{}) - typeNullTime = reflect.TypeOf(null.Time{}) - typeTime = reflect.TypeOf(time.Time{}) - - rgxValidTime = regexp.MustCompile(`[2-9]+`) - - validatedTypes = []string{"uuid", "interval"} + typeNullFloat32 = reflect.TypeOf(null.Float32{}) + typeNullFloat64 = reflect.TypeOf(null.Float64{}) + typeNullInt = reflect.TypeOf(null.Int{}) + typeNullInt8 = reflect.TypeOf(null.Int8{}) + typeNullInt16 = reflect.TypeOf(null.Int16{}) + typeNullInt32 = reflect.TypeOf(null.Int32{}) + typeNullInt64 = reflect.TypeOf(null.Int64{}) + typeNullUint = reflect.TypeOf(null.Uint{}) + typeNullUint8 = reflect.TypeOf(null.Uint8{}) + typeNullUint16 = reflect.TypeOf(null.Uint16{}) + typeNullUint32 = reflect.TypeOf(null.Uint32{}) + typeNullUint64 = reflect.TypeOf(null.Uint64{}) + typeNullString = reflect.TypeOf(null.String{}) + typeNullBool = reflect.TypeOf(null.Bool{}) + typeNullTime = reflect.TypeOf(null.Time{}) + typeNullBytes = reflect.TypeOf(null.Bytes{}) + typeNullJSON = reflect.TypeOf(null.JSON{}) + typeTime = reflect.TypeOf(time.Time{}) + typeJSON = reflect.TypeOf(types.JSON{}) + typeInt64Array = reflect.TypeOf(types.Int64Array{}) + typeBytesArray = reflect.TypeOf(types.BytesArray{}) + typeBoolArray = reflect.TypeOf(types.BoolArray{}) + typeFloat64Array = reflect.TypeOf(types.Float64Array{}) + typeStringArray = reflect.TypeOf(types.StringArray{}) + 
typeHStore = reflect.TypeOf(types.HStore{}) + rgxValidTime = regexp.MustCompile(`[2-9]+`) + + validatedTypes = []string{ + "inet", "line", "uuid", "interval", + "json", "jsonb", "box", "cidr", "circle", + "lseg", "macaddr", "path", "pg_lsn", "point", + "polygon", "txid_snapshot", "money", "hstore", + } ) // Seed is an atomic counter for pseudo-randomization structs. Using full @@ -163,7 +180,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo field.Set(reflect.ValueOf(value)) return nil } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value = null.NewString(randBox(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "cidr" || fieldType == "inet" { + value = null.NewString(randNetAddr(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "macaddr" { + value = null.NewString(randMacAddr(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "circle" { + value = null.NewString(randCircle(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "pg_lsn" { + value = null.NewString(randLsn(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "point" { + value = null.NewString(randPoint(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "txid_snapshot" { + value = null.NewString(randTxID(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "money" { + value = null.NewString(randMoney(s), true) + field.Set(reflect.ValueOf(value)) + return nil + } + case typeNullJSON: + value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) + field.Set(reflect.ValueOf(value)) + return nil + case typeHStore: + value := types.HStore{} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + field.Set(reflect.ValueOf(value)) + return nil } + } else { switch kind { case reflect.String: @@ -177,6 +246,59 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo field.Set(reflect.ValueOf(value)) return nil } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value = randBox() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "cidr" || fieldType == "inet" { + value = randNetAddr() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "macaddr" { + value = randMacAddr() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "circle" { + value = randCircle() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "pg_lsn" { + value = randLsn() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "point" { + value = randPoint() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "txid_snapshot" { + value = randTxID() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "money" { + value = randMoney(s) + field.Set(reflect.ValueOf(value)) + return nil + } + } + switch typ { + case typeJSON: + value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) + field.Set(reflect.ValueOf(value)) + return nil + case typeHStore: + value := types.HStore{} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + value[randStr(s, 3)] = sql.NullString{String: 
randStr(s, 3), Valid: s.nextInt()%3 == 0} + field.Set(reflect.ValueOf(value)) + return nil } } } @@ -191,8 +313,11 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo isNull = false } - // Retrieve the value to be returned - if kind == reflect.Struct { + // If it's a Postgres array, treat it like one + if strings.HasPrefix(fieldType, "ARRAY") { + value = getArrayRandValue(s, typ, fieldType) + // Retrieve the value to be returned + } else if kind == reflect.Struct { if isNull { value = getStructNullValue(typ) } else { @@ -215,6 +340,69 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo return nil } +func getArrayRandValue(s *Seed, typ reflect.Type, fieldType string) interface{} { + fieldType = strings.TrimLeft(fieldType, "ARRAY") + switch typ { + case typeInt64Array: + return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())} + case typeFloat64Array: + return types.Float64Array{float64(s.nextInt()), float64(s.nextInt())} + case typeBoolArray: + return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0} + case typeStringArray: + if fieldType == "interval" { + value := strconv.Itoa((s.nextInt()%26)+2) + " days" + return types.StringArray{value, value} + } + if fieldType == "uuid" { + value := uuid.NewV4().String() + return types.StringArray{value, value} + } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value := randBox() + return types.StringArray{value, value} + } + if fieldType == "cidr" || fieldType == "inet" { + value := randNetAddr() + return types.StringArray{value, value} + } + if fieldType == "macaddr" { + value := randMacAddr() + return types.StringArray{value, value} + } + if fieldType == "circle" { + value := randCircle() + return types.StringArray{value, value} + } + if fieldType == "pg_lsn" { + value := randLsn() + return types.StringArray{value, value} + } + if fieldType == "point" { + value := randPoint() + return types.StringArray{value, value} + } + if fieldType == "txid_snapshot" { + value := randTxID() + return types.StringArray{value, value} + } + if fieldType == "money" { + value := randMoney(s) + return types.StringArray{value, value} + } + if fieldType == "json" || fieldType == "jsonb" { + value := []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) + return types.StringArray{string(value)} + } + return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)} + case typeBytesArray: + return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)} + } + + return nil +} + // getStructNullValue for the matching type. 
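The randomizer now recognizes the new `types` package (Postgres arrays, JSON, hstore) and the `ARRAY`-prefixed field types. A sketch of the value shapes it produces, using only type literals that appear in this change; the concrete contents are illustrative:

```go
package example

import (
	"database/sql"

	"github.com/vattle/sqlboiler/types"
)

// Value shapes matching what getArrayRandValue and the hstore branch build:
// arrays are short fixed-length slices, hstore is a map of nullable strings.
var (
	ints  = types.Int64Array{1, 2}
	bools = types.BoolArray{true, false, true}
	strs  = types.StringArray{"abcd", "efgh", "ijkl"}
	bytes = types.BytesArray{{0x01, 0x02}, {0x03, 0x04}}

	store = types.HStore{
		"key": sql.NullString{String: "val", Valid: true},
	}
)
```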
func getStructNullValue(typ reflect.Type) interface{} { switch typ { @@ -250,6 +438,8 @@ func getStructNullValue(typ reflect.Type) interface{} { return null.NewUint32(0, false) case typeNullUint64: return null.NewUint64(0, false) + case typeNullBytes: + return null.NewBytes(nil, false) } return nil @@ -292,6 +482,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint32(uint32(s.nextInt()), true) case typeNullUint64: return null.NewUint64(uint64(s.nextInt()), true) + case typeNullBytes: + return null.NewBytes(randByteSlice(s, 16), true) } return nil @@ -378,23 +570,3 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac return nil } - -const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func randStr(s *Seed, ln int) string { - str := make([]byte, ln) - for i := 0; i < ln; i++ { - str[i] = byte(alphabet[s.nextInt()%len(alphabet)]) - } - - return string(str) -} - -func randByteSlice(s *Seed, ln int) []byte { - str := make([]byte, ln) - for i := 0; i < ln; i++ { - str[i] = byte(s.nextInt() % 256) - } - - return str -} diff --git a/boil/randomize/randomize_test.go b/randomize/randomize_test.go similarity index 99% rename from boil/randomize/randomize_test.go rename to randomize/randomize_test.go index ee028b9ab..a3ba08304 100644 --- a/boil/randomize/randomize_test.go +++ b/randomize/randomize_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" ) func TestRandomizeStruct(t *testing.T) { diff --git a/sqlboiler.go b/sqlboiler.go index aa4d92402..0c52dcea1 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -15,7 +15,8 @@ import ( "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb/drivers" - "github.com/vattle/sqlboiler/boil" + "github.com/vattle/sqlboiler/queries" + "github.com/vattle/sqlboiler/strmangle" ) const ( @@ -32,8 +33,9 @@ const ( type State struct { Config *Config - Driver bdb.Interface - Tables []bdb.Table + Driver bdb.Interface + Tables []bdb.Table + Dialect queries.Dialect Templates *templateList TestTemplates *templateList @@ -59,7 +61,7 @@ func New(config *Config) (*State, error) { return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables(config.ExcludeTables) + err = s.initTables(config.Schema, config.WhitelistTables, config.BlacklistTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } @@ -69,8 +71,7 @@ func New(config *Config) (*State, error) { if err != nil { return nil, errors.Wrap(err, "unable to json marshal tables") } - boil.DebugWriter.Write(b) - fmt.Fprintln(boil.DebugWriter) + fmt.Printf("%s\n", b) } err = s.initOutFolder() @@ -96,11 +97,15 @@ func New(config *Config) (*State, error) { func (s *State) Run(includeTests bool) error { singletonData := &templateData{ Tables: s.Tables, + Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, + Dialect: s.Dialect, + LQ: strmangle.QuoteCharacter(s.Dialect.LQ), + RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } @@ -127,12 +132,16 @@ func (s *State) Run(includeTests bool) error { data := &templateData{ Tables: s.Tables, Table: table, + Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, NoHooks: s.Config.NoHooks, 
NoAutoTimestamps: s.Config.NoAutoTimestamps, Tags: s.Config.Tags, + Dialect: s.Dialect, + LQ: strmangle.QuoteCharacter(s.Dialect.LQ), + RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } @@ -227,6 +236,15 @@ func (s *State) initDriver(driverName string) error { s.Config.Postgres.Port, s.Config.Postgres.SSLMode, ) + case "mysql": + s.Driver = drivers.NewMySQLDriver( + s.Config.MySQL.User, + s.Config.MySQL.Pass, + s.Config.MySQL.DBName, + s.Config.MySQL.Host, + s.Config.MySQL.Port, + s.Config.MySQL.SSLMode, + ) case "mock": s.Driver = &drivers.MockDriver{} } @@ -235,13 +253,17 @@ func (s *State) initDriver(driverName string) error { return errors.New("An invalid driver name was provided") } + s.Dialect.LQ = s.Driver.LeftQuote() + s.Dialect.RQ = s.Driver.RightQuote() + s.Dialect.IndexPlaceholders = s.Driver.IndexPlaceholders() + return nil } // initTables retrieves all "public" schema table names from the database. -func (s *State) initTables(exclude []string) error { +func (s *State) initTables(schema string, whitelist, blacklist []string) error { var err error - s.Tables, err = bdb.Tables(s.Driver, exclude...) + s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, blacklist) if err != nil { return errors.Wrap(err, "unable to fetch table data") } diff --git a/sqlboiler_test.go b/sqlboiler_test.go index a170361fc..367e429b5 100644 --- a/sqlboiler_test.go +++ b/sqlboiler_test.go @@ -37,10 +37,10 @@ func TestNew(t *testing.T) { }() config := &Config{ - DriverName: "mock", - PkgName: "models", - OutFolder: out, - ExcludeTables: []string{"hangars"}, + DriverName: "mock", + PkgName: "models", + OutFolder: out, + BlacklistTables: []string{"hangars"}, } state, err = New(config) diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 822891963..5f62aca47 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -22,6 +22,7 @@ var uppercaseWords = map[string]struct{}{ "id": {}, "uid": {}, "uuid": {}, + "json": {}, } func init() { @@ -33,9 +34,21 @@ func init() { boilRuleset = newBoilRuleset() } +// SchemaTable returns a table name with a schema prefixed if +// using a database that supports real schemas, for example, +// for Postgres: "schema_name"."table_name", versus +// simply "table_name" for MySQL (because it does not support real schemas) +func SchemaTable(lq, rq string, driver string, schema string, table string) string { + if driver == "postgres" && schema != "public" { + return fmt.Sprintf(`%s%s%s.%s%s%s`, lq, schema, rq, lq, table, rq) + } + + return fmt.Sprintf(`%s%s%s`, lq, table, rq) +} + // IdentQuote attempts to quote simple identifiers in SQL tatements -func IdentQuote(s string) string { - if strings.ToLower(s) == "null" { +func IdentQuote(lq byte, rq byte, s string) string { + if strings.ToLower(s) == "null" || s == "?" { return s } @@ -52,28 +65,28 @@ func IdentQuote(s string) string { buf.WriteByte('.') } - if strings.HasPrefix(split, `"`) || strings.HasSuffix(split, `"`) || split == "*" { + if split[0] == lq || split[len(split)-1] == rq || split == "*" { buf.WriteString(split) continue } - buf.WriteByte('"') + buf.WriteByte(lq) buf.WriteString(split) - buf.WriteByte('"') + buf.WriteByte(rq) } return buf.String() } // IdentQuoteSlice applies IdentQuote to a slice. 
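Not part of the patch, but a quick sketch of what the new quote-aware helpers above produce, assuming the `strmangle` import path used throughout this diff (the schema and table names are made up):

```go
package main

import (
	"fmt"

	"github.com/vattle/sqlboiler/strmangle"
)

func main() {
	// Postgres with a non-default schema gets a schema-qualified, quoted name.
	fmt.Println(strmangle.SchemaTable(`"`, `"`, "postgres", "accounting", "invoices"))
	// => "accounting"."invoices"

	// The default "public" schema, and MySQL (which has no real schemas),
	// collapse to a plain quoted table name.
	fmt.Println(strmangle.SchemaTable("`", "`", "mysql", "mydb", "invoices"))
	// => `invoices`

	// IdentQuote now takes the dialect's quote characters instead of
	// hard-coding double quotes.
	fmt.Println(strmangle.IdentQuote('`', '`', "users.id"))
	// => `users`.`id`
}
```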
-func IdentQuoteSlice(s []string) []string { +func IdentQuoteSlice(lq byte, rq byte, s []string) []string { if len(s) == 0 { return s } strs := make([]string, len(s)) for i, str := range s { - strs[i] = IdentQuote(str) + strs[i] = IdentQuote(lq, rq, str) } return strs @@ -105,6 +118,16 @@ func Identifier(in int) string { return cols.String() } +// QuoteCharacter returns a string that allows the quote character +// to be embedded into a Go string that uses double quotes: +func QuoteCharacter(q byte) string { + if q == '"' { + return `\"` + } + + return string(q) +} + // Plural converts singular words to plural words (eg: person to people) func Plural(name string) string { buf := GetBuffer() @@ -368,7 +391,8 @@ func PrefixStringSlice(str string, strs []string) []string { // Placeholders generates the SQL statement placeholders for in queries. // For example, ($1,$2,$3),($4,$5,$6) etc. // It will start counting placeholders at "start". -func Placeholders(count int, start int, group int) string { +// If indexPlaceholders is false, it will convert to ? instead of $1 etc. +func Placeholders(indexPlaceholders bool, count int, start int, group int) string { buf := GetBuffer() defer PutBuffer(buf) @@ -387,7 +411,11 @@ func Placeholders(count int, start int, group int) string { buf.WriteByte(',') } } - buf.WriteString(fmt.Sprintf("$%d", start+i)) + if indexPlaceholders { + buf.WriteString(fmt.Sprintf("$%d", start+i)) + } else { + buf.WriteByte('?') + } } if group > 1 { buf.WriteByte(')') @@ -399,14 +427,19 @@ func Placeholders(count int, start int, group int) string { // SetParamNames takes a slice of columns and returns a comma separated // list of parameter names for a template statement SET clause. // eg: "col1"=$1, "col2"=$2, "col3"=$3 -func SetParamNames(columns []string) string { +func SetParamNames(lq, rq string, start int, columns []string) string { buf := GetBuffer() defer PutBuffer(buf) for i, c := range columns { - buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, i+1)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, i+start)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(columns)-1 { - buf.WriteString(", ") + buf.WriteByte(',') } } @@ -415,16 +448,17 @@ func SetParamNames(columns []string) string { // WhereClause returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" -func WhereClause(start int, cols []string) string { - if start == 0 { - panic("0 is not a valid start number for whereClause") - } - +func WhereClause(lq, rq string, start int, cols []string) string { buf := GetBuffer() defer PutBuffer(buf) for i, c := range cols { - buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, start+i)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(cols)-1 { buf.WriteString(" AND ") } diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index c3626e2ad..6d802d4c0 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -29,7 +29,7 @@ func TestIdentQuote(t *testing.T) { } for _, test := range tests { - if got := IdentQuote(test.In); got != test.Out { + if got := IdentQuote('"', '"', test.In); got != test.Out { t.Errorf("want: %s, got: %s", test.Out, got) } } @@ -38,7 +38,7 @@ func TestIdentQuote(t *testing.T) { func TestIdentQuoteSlice(t *testing.T) { t.Parallel() - ret := IdentQuoteSlice([]string{`thing`, `null`}) + ret 
:= IdentQuoteSlice('"', '"', []string{`thing`, `null`}) if ret[0] != `"thing"` { t.Error(ret[0]) } @@ -69,34 +69,60 @@ func TestIdentifier(t *testing.T) { } } +func TestQuoteCharacter(t *testing.T) { + t.Parallel() + + if QuoteCharacter('[') != "[" { + t.Error("want just the normal quote character") + } + if QuoteCharacter('`') != "`" { + t.Error("want just the normal quote character") + } + if QuoteCharacter('"') != `\"` { + t.Error("want an escaped character") + } +} + func TestPlaceholders(t *testing.T) { t.Parallel() - x := Placeholders(1, 2, 1) + x := Placeholders(true, 1, 2, 1) want := "$2" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(5, 1, 1) + x = Placeholders(true, 5, 1, 1) want = "$1,$2,$3,$4,$5" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(6, 1, 2) + x = Placeholders(false, 5, 1, 1) + want = "?,?,?,?,?" + if want != x { + t.Errorf("want %s, got %s", want, x) + } + + x = Placeholders(true, 6, 1, 2) + want = "($1,$2),($3,$4),($5,$6)" + if want != x { + t.Errorf("want %s, got %s", want, x) + } + + x = Placeholders(true, 6, 1, 2) want = "($1,$2),($3,$4),($5,$6)" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(9, 1, 3) - want = "($1,$2,$3),($4,$5,$6),($7,$8,$9)" + x = Placeholders(false, 9, 1, 3) + want = "(?,?,?),(?,?,?),(?,?,?)" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(7, 1, 3) + x = Placeholders(true, 7, 1, 3) want = "($1,$2,$3),($4,$5,$6),($7)" if want != x { t.Errorf("want %s, got %s", want, x) @@ -291,6 +317,28 @@ func TestPrefixStringSlice(t *testing.T) { } } +func TestSetParamNames(t *testing.T) { + t.Parallel() + + tests := []struct { + Cols []string + Start int + Should string + }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=?,"col2"=?`}, + {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, + {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4,"col2"=$5`}, + {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4,"col2"=$5,"col3"=$6`}, + } + + for i, test := range tests { + r := SetParamNames(`"`, `"`, test.Start, test.Cols) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } +} + func TestWhereClause(t *testing.T) { t.Parallel() @@ -299,13 +347,14 @@ func TestWhereClause(t *testing.T) { Start int Should string }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=? AND "col2"=?`}, {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`}, {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`}, } for i, test := range tests { - r := WhereClause(test.Start, test.Cols) + r := WhereClause(`"`, `"`, test.Start, test.Cols) if r != test.Should { t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) } diff --git a/templates.go b/templates.go index e760711d7..12108bbd2 100644 --- a/templates.go +++ b/templates.go @@ -8,21 +8,45 @@ import ( "text/template" "github.com/vattle/sqlboiler/bdb" + "github.com/vattle/sqlboiler/queries" "github.com/vattle/sqlboiler/strmangle" ) // templateData for sqlboiler templates type templateData struct { - Tables []bdb.Table - Table bdb.Table - DriverName string - UseLastInsertID bool - PkgName string + Tables []bdb.Table + Table bdb.Table + + // Controls what names are output + PkgName string + Schema string + + // Controls which code is output (mysql vs postgres ...) 
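The expected outputs in the sketch below are taken from the updated tests above; the only assumption is that the snippet is run against the `strmangle` package at this revision:

```go
package main

import (
	"fmt"

	"github.com/vattle/sqlboiler/strmangle"
)

func main() {
	// Indexed placeholders (Postgres style): 6 placeholders grouped in pairs,
	// counting from $1.
	fmt.Println(strmangle.Placeholders(true, 6, 1, 2))
	// => ($1,$2),($3,$4),($5,$6)

	// Question-mark placeholders (MySQL style).
	fmt.Println(strmangle.Placeholders(false, 5, 1, 1))
	// => ?,?,?,?,?

	// The SET and WHERE builders now take the quote characters and a start
	// index; start == 0 switches them to ? placeholders.
	fmt.Println(strmangle.SetParamNames(`"`, `"`, 4, []string{"col1", "col2"}))
	// => "col1"=$4,"col2"=$5
	fmt.Println(strmangle.WhereClause("`", "`", 0, []string{"col1", "col2"}))
	// => `col1`=? AND `col2`=?
}
```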
+ DriverName string + UseLastInsertID bool + + // Turn off auto timestamps or hook generation NoHooks bool NoAutoTimestamps bool - Tags []string + // Tags control which + Tags []string + + // StringFuncs are usable in templates with stringMap StringFuncs map[string]func(string) string + + // Dialect controls quoting + Dialect queries.Dialect + LQ string + RQ string +} + +func (t templateData) Quotes(s string) string { + return fmt.Sprintf("%s%s%s", t.LQ, s, t.RQ) +} + +func (t templateData) SchemaTable(table string) string { + return strmangle.SchemaTable(t.LQ, t.RQ, t.DriverName, t.Schema, table) } type templateList struct { @@ -113,7 +137,7 @@ var templateStringMappers = map[string]func(string) string{ // add a function pointer here. var templateFunctions = template.FuncMap{ // String ops - "quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) }, + "quoteWrap": func(s string) string { return fmt.Sprintf(`"%s"`, s) }, "id": strmangle.Identifier, // Pluralization @@ -150,6 +174,7 @@ var templateFunctions = template.FuncMap{ // dbdrivers ops "filterColumnsByDefault": bdb.FilterColumnsByDefault, + "autoIncPrimaryKey": bdb.AutoIncPrimaryKey, "sqlColDefinitions": bdb.SQLColDefinitions, "columnNames": bdb.ColumnNames, "columnDBTypes": bdb.ColumnDBTypes, diff --git a/templates/00_struct.tpl b/templates/00_struct.tpl index 70d1e2961..60b64b2e6 100644 --- a/templates/00_struct.tpl +++ b/templates/00_struct.tpl @@ -1,5 +1,5 @@ {{- define "relationship_to_one_struct_helper" -}} - {{.Function.Name}} *{{.ForeignTable.NameGo}} + {{.Function.Name}} *{{.ForeignTable.NameGo}} {{- end -}} {{- $dot := . -}} @@ -8,30 +8,30 @@ {{- $modelNameCamel := $tableNameSingular | camelCase -}} // {{$modelName}} is an object representing the database table. type {{$modelName}} struct { - {{range $column := .Table.Columns -}} - {{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"` - {{end -}} - {{- if .Table.IsJoinTable -}} - {{- else}} - R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` - L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` - {{end -}} + {{range $column := .Table.Columns -}} + {{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"` + {{end -}} + {{- if .Table.IsJoinTable -}} + {{- else}} + R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` + L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` + {{end -}} } {{- if .Table.IsJoinTable -}} {{- else}} // {{$modelNameCamel}}R is where relationships are stored. type {{$modelNameCamel}}R struct { - {{range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_struct_helper" $rel}} - {{end -}} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}} - {{else -}} - {{- $rel := textsFromRelationship $dot.Tables $dot.Table . 
-}} - {{$rel.Function.Name}} {{$rel.ForeignTable.Slice}} + {{range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_struct_helper" $rel}} + {{end -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}} + {{else -}} + {{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}} + {{$rel.Function.Name}} {{$rel.ForeignTable.Slice}} {{end -}}{{/* if ForeignColumnUnique */}} {{- end -}}{{/* range tomany */}} } diff --git a/templates/01_types.tpl b/templates/01_types.tpl index 790b56026..b3459d8bd 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -3,31 +3,36 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} var ( - {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} - {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} ) type ( - {{$tableNameSingular}}Slice []*{{$tableNameSingular}} - {{if eq .NoHooks false -}} - {{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error - {{- end}} + // {{$tableNameSingular}}Slice is an alias for a slice of pointers to {{$tableNameSingular}}. + // This should generally be used opposed to []{{$tableNameSingular}}. 
+ {{$tableNameSingular}}Slice []*{{$tableNameSingular}} + {{if eq .NoHooks false -}} + // {{$tableNameSingular}}Hook is the signature for custom {{$tableNameSingular}} hook methods + {{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error + {{- end}} - {{$varNameSingular}}Query struct { - *boil.Query - } + {{$varNameSingular}}Query struct { + *queries.Query + } ) -// Cache for insert and update +// Cache for insert, update and upsert var ( - {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) - {{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type) - {{$varNameSingular}}InsertCacheMut sync.RWMutex - {{$varNameSingular}}InsertCache = make(map[string]insertCache) - {{$varNameSingular}}UpdateCacheMut sync.RWMutex - {{$varNameSingular}}UpdateCache = make(map[string]updateCache) + {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) + {{$varNameSingular}}Mapping = queries.MakeStructMapping({{$varNameSingular}}Type) + {{$varNameSingular}}InsertCacheMut sync.RWMutex + {{$varNameSingular}}InsertCache = make(map[string]insertCache) + {{$varNameSingular}}UpdateCacheMut sync.RWMutex + {{$varNameSingular}}UpdateCache = make(map[string]updateCache) + {{$varNameSingular}}UpsertCacheMut sync.RWMutex + {{$varNameSingular}}UpsertCache = make(map[string]insertCache) ) // Force time package dependency for automated UpdatedAt/CreatedAt. diff --git a/templates/02_hooks.tpl b/templates/02_hooks.tpl index 9e152d1cc..e87283f4d 100644 --- a/templates/02_hooks.tpl +++ b/templates/02_hooks.tpl @@ -14,123 +14,124 @@ var {{$varNameSingular}}AfterUpsertHooks []{{$tableNameSingular}}Hook // doBeforeInsertHooks executes all "before insert" hooks. func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeInsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeInsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeUpdateHooks executes all "before Update" hooks. func (o *{{$tableNameSingular}}) doBeforeUpdateHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeUpdateHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeUpdateHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeDeleteHooks executes all "before Delete" hooks. func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeDeleteHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeDeleteHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeUpsertHooks executes all "before Upsert" hooks. func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterInsertHooks executes all "after Insert" hooks. 
func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterInsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterInsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterSelectHooks executes all "after Select" hooks. func (o *{{$tableNameSingular}}) doAfterSelectHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterSelectHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterSelectHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterUpdateHooks executes all "after Update" hooks. func (o *{{$tableNameSingular}}) doAfterUpdateHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterUpdateHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterUpdateHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterDeleteHooks executes all "after Delete" hooks. func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterDeleteHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterDeleteHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterUpsertHooks executes all "after Upsert" hooks. func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterUpsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterUpsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } +// Add{{$tableNameSingular}}Hook registers your hook function for all future operations. 
func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) { - switch hookPoint { - case boil.BeforeInsertHook: - {{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook) - case boil.BeforeUpdateHook: - {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) - case boil.BeforeDeleteHook: - {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) - case boil.BeforeUpsertHook: - {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) - case boil.AfterInsertHook: - {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) - case boil.AfterSelectHook: - {{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook) - case boil.AfterUpdateHook: - {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) - case boil.AfterDeleteHook: - {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) - case boil.AfterUpsertHook: - {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) - } + switch hookPoint { + case boil.BeforeInsertHook: + {{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook) + case boil.BeforeUpdateHook: + {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) + case boil.BeforeDeleteHook: + {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) + case boil.BeforeUpsertHook: + {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) + case boil.AfterInsertHook: + {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) + case boil.AfterSelectHook: + {{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook) + case boil.AfterUpdateHook: + {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) + case boil.AfterDeleteHook: + {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) + case boil.AfterUpsertHook: + {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) + } } {{- end}} diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index 6d259550e..429a27625 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -2,114 +2,115 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} // OneP returns a single {{$varNameSingular}} record from the query, and panics on error. func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) { - o, err := q.One() - if err != nil { - panic(boil.WrapErr(err)) - } + o, err := q.One() + if err != nil { + panic(boil.WrapErr(err)) + } - return o + return o } // One returns a single {{$varNameSingular}} record from the query. 
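The hook template above generates an `Add<Model>Hook` registration function per table. A hedged usage sketch, assuming a generated `models` package with a `Pilot` table and a `Name` column (none of which are part of this patch):

```go
package main

import (
	"errors"

	"github.com/vattle/sqlboiler/boil"

	"github.com/myuser/myapp/models" // hypothetical generated package
)

func main() {
	// Register a before-insert hook; it runs for every future Pilot insert.
	// The hook signature matches the generated PilotHook type shown above.
	models.AddPilotHook(boil.BeforeInsertHook, func(exec boil.Executor, p *models.Pilot) error {
		if p.Name == "" { // Name is an assumed column on the hypothetical pilots table
			return errors.New("pilot name must not be empty")
		}
		return nil
	})
}
```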
func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { - o := &{{$tableNameSingular}}{} + o := &{{$tableNameSingular}}{} - boil.SetLimit(q.Query, 1) + queries.SetLimit(q.Query, 1) - err := q.Bind(o) - if err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, sql.ErrNoRows - } - return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}") - } + err := q.Bind(o) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, sql.ErrNoRows + } + return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}") + } - {{if not .NoHooks -}} - if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { - return o, err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil { + return o, err + } + {{- end}} - return o, nil + return o, nil } // AllP returns all {{$tableNameSingular}} records from the query, and panics on error. func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice { - o, err := q.All() - if err != nil { - panic(boil.WrapErr(err)) - } + o, err := q.All() + if err != nil { + panic(boil.WrapErr(err)) + } - return o + return o } // All returns all {{$tableNameSingular}} records from the query. func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) { - var o {{$tableNameSingular}}Slice - - err := q.Bind(&o) - if err != nil { - return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice") - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}AfterSelectHooks) != 0 { - for _, obj := range o { - if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { - return o, err - } - } - } - {{- end}} - - return o, nil + var o {{$tableNameSingular}}Slice + + err := q.Bind(&o) + if err != nil { + return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice") + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}AfterSelectHooks) != 0 { + for _, obj := range o { + if err := obj.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil { + return o, err + } + } + } + {{- end}} + + return o, nil } // CountP returns the count of all {{$tableNameSingular}} records in the query, and panics on error. func (q {{$varNameSingular}}Query) CountP() int64 { - c, err := q.Count() - if err != nil { - panic(boil.WrapErr(err)) - } + c, err := q.Count() + if err != nil { + panic(boil.WrapErr(err)) + } - return c + return c } // Count returns the count of all {{$tableNameSingular}} records in the query. func (q {{$varNameSingular}}Query) Count() (int64, error) { - var count int64 + var count int64 - boil.SetCount(q.Query) + queries.SetSelect(q.Query, nil) + queries.SetCount(q.Query) - err := boil.ExecQueryOne(q.Query).Scan(&count) - if err != nil { - return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") - } + err := q.Query.QueryRow().Scan(&count) + if err != nil { + return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") + } - return count, nil + return count, nil } // Exists checks if the row exists in the table, and panics on error. func (q {{$varNameSingular}}Query) ExistsP() bool { - e, err := q.Exists() - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := q.Exists() + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } // Exists checks if the row exists in the table. 
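These finishers are what callers use on a generated query. Another hedged sketch against the same hypothetical `models` package; the `Pilots` constructor name and connection string are assumptions, while the `Count`/`One` behavior follows the template above:

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq"

	"github.com/myuser/myapp/models" // hypothetical generated package
)

func main() {
	db, err := sql.Open("postgres", "dbname=myapp sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}

	// Count clears the select list and issues a COUNT through the
	// dialect-aware query builder, as in the template above.
	n, err := models.Pilots(db).Count()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("pilots:", n)

	// One applies LIMIT 1 via queries.SetLimit and binds a single struct.
	pilot, err := models.Pilots(db).One()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(pilot)
}
```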
func (q {{$varNameSingular}}Query) Exists() (bool, error) { - var count int64 + var count int64 - boil.SetCount(q.Query) - boil.SetLimit(q.Query, 1) + queries.SetCount(q.Query) + queries.SetLimit(q.Query, 1) - err := boil.ExecQueryOne(q.Query).Scan(&count) - if err != nil { - return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") - } + err := q.Query.QueryRow().Scan(&count) + if err != nil { + return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") + } - return count > 0, nil + return count > 0, nil } diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 5081f397d..b95662651 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -1,30 +1,34 @@ {{- define "relationship_to_one_helper" -}} -{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} // {{.Function.Name}}G pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) + return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) } // {{.Function.Name}} pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) { - queryMods := []qm.QueryMod{ - qm.Where("{{.ForeignTable.ColumnName}}=$1", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), - } + queryMods := []qm.QueryMod{ + qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), + } - queryMods = append(queryMods, mods...) + queryMods = append(queryMods, mods...) - query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{.ForeignTable.Name}}") + query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) + queries.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") - return query + return query } + {{- end -}}{{/* end with */}} +{{end -}}{{/* end define */}} -{{end -}} +{{- /* Begin execution of template for one-to-one relationship */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_helper" $rel -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} +{{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} {{- end -}} {{- end -}} diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index e886258ef..0cf62d8f6 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -1,46 +1,51 @@ +{{- /* Begin execution of template for many-to-one or many-to-many relationship helper */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . 
-}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- $varNameSingular := .ForeignTable | singular | camelCase -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} -{{- template "relationship_to_one_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} - {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- $varNameSingular := .ForeignTable | singular | camelCase -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one relationship. */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- /* Begin execution of template for many-to-many relationship. */ -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...) + return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...) } // {{$rel.Function.Name}} retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} with an executor {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - queryMods := []qm.QueryMod{ - qm.Select(`"{{id 0}}".*`), - } + queryMods := []qm.QueryMod{ + qm.Select("{{id 0 | $dot.Quotes}}.*"), + } - if len(mods) != 0 { - queryMods = append(queryMods, mods...) - } + if len(mods) != 0 { + queryMods = append(queryMods, mods...) + } - {{if .ToJoinTable -}} - queryMods = append(queryMods, - qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), - qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), - ) - {{else -}} - queryMods = append(queryMods, - qm.Where(`"{{id 0}}"."{{.ForeignColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), - ) - {{end}} + {{if .ToJoinTable -}} + queryMods = append(queryMods, + qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), + qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + ) + {{else -}} + queryMods = append(queryMods, + qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + ) + {{end}} - query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) 
- boil.SetFrom(query.Query, `"{{.ForeignTable}}" as "{{id 0}}"`) - return query + query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) + queries.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") + return query } {{end -}}{{- /* if unique foreign key */ -}} {{- end -}}{{- /* range relationships */ -}} -{{- end -}}{{- /* outer if join table */ -}} +{{- end -}}{{- /* if isJoinTable */ -}} diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index 8d6a514ee..0911029a6 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -1,91 +1,93 @@ {{- define "relationship_to_one_eager_helper" -}} - {{- $varNameSingular := .Dot.Table.Name | singular | camelCase -}} - {{- $noHooks := .Dot.NoHooks -}} - {{- with .Rel -}} - {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} - {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} + {{- with .Rel -}} + {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} // Load{{.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { - var slice []*{{.LocalTable.NameGo}} - var object *{{.LocalTable.NameGo}} + var slice []*{{.LocalTable.NameGo}} + var object *{{.LocalTable.NameGo}} - count := 1 - if singular { - object = {{$arg}}.(*{{.LocalTable.NameGo}}) - } else { - slice = *{{$arg}}.(*{{$slice}}) - count = len(slice) - } + count := 1 + if singular { + object = {{$arg}}.(*{{.LocalTable.NameGo}}) + } else { + slice = *{{$arg}}.(*{{$slice}}) + count = len(slice) + } - args := make([]interface{}, count) - if singular { - args[0] = object.{{.LocalTable.ColumnNameGo}} - } else { - for i, obj := range slice { - args[i] = obj.{{.LocalTable.ColumnNameGo}} - } - } + args := make([]interface{}, count) + if singular { + args[0] = object.{{.LocalTable.ColumnNameGo}} + } else { + for i, obj := range slice { + args[i] = obj.{{.LocalTable.ColumnNameGo}} + } + } - query := fmt.Sprintf( - `select * from "{{.ForeignKey.ForeignTable}}" where "{{.ForeignKey.ForeignColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), - ) + query := fmt.Sprintf( + "select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) - if boil.DebugMode { - fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) - } + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } - results, err := e.Query(query, args...) - if err != nil { - return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") - } - defer results.Close() + results, err := e.Query(query, args...) 
+ if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") + } + defer results.Close() - var resultSlice []*{{.ForeignTable.NameGo}} - if err = boil.Bind(results, &resultSlice); err != nil { - return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") - } + var resultSlice []*{{.ForeignTable.NameGo}} + if err = queries.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") + } - {{if not $noHooks -}} - if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { - for _, obj := range resultSlice { - if err := obj.doAfterSelectHooks(e); err != nil { - return err - } - } - } - {{- end}} + {{if not $dot.NoHooks -}} + if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { + for _, obj := range resultSlice { + if err := obj.doAfterSelectHooks(e); err != nil { + return err + } + } + } + {{- end}} - if singular && len(resultSlice) != 0 { - if object.R == nil { - object.R = &{{$varNameSingular}}R{} - } - object.R.{{.Function.Name}} = resultSlice[0] - return nil - } + if singular && len(resultSlice) != 0 { + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } + object.R.{{.Function.Name}} = resultSlice[0] + return nil + } - for _, foreign := range resultSlice { - for _, local := range slice { - if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{.Function.Name}} = foreign - break - } - } - } + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{.Function.Name}} = foreign + break + } + } + } - return nil + return nil } - {{- end -}} -{{end -}} + {{- end -}}{{- /* end with */ -}} +{{end -}}{{- /* end define */ -}} + +{{- /* Begin execution of template for one-to-one eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_eager_helper" (preserveDot $dot $rel) -}} -{{- end -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{end}} diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index b8276c61c..aae8c4e3c 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -1,136 +1,141 @@ +{{- /* Begin execution of template for many-to-one or many-to-many eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} -{{- $dot := . -}} -{{- range .Table.ToManyRelationships -}} -{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} -{{- else -}} - {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} - {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} - {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} - {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} + {{- $dot := . 
-}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one eager load */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- /* Begin execution of template for many-to-many eager load */ -}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} + {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} + {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { - var slice []*{{$txt.LocalTable.NameGo}} - var object *{{$txt.LocalTable.NameGo}} + var slice []*{{$txt.LocalTable.NameGo}} + var object *{{$txt.LocalTable.NameGo}} - count := 1 - if singular { - object = {{$arg}}.(*{{$txt.LocalTable.NameGo}}) - } else { - slice = *{{$arg}}.(*{{$slice}}) - count = len(slice) - } + count := 1 + if singular { + object = {{$arg}}.(*{{$txt.LocalTable.NameGo}}) + } else { + slice = *{{$arg}}.(*{{$slice}}) + count = len(slice) + } - args := make([]interface{}, count) - if singular { - args[0] = object.{{.Column | titleCase}} - } else { - for i, obj := range slice { - args[i] = obj.{{.Column | titleCase}} - } - } + args := make([]interface{}, count) + if singular { + args[0] = object.{{.Column | titleCase}} + } else { + for i, obj := range slice { + args[i] = obj.{{.Column | titleCase}} + } + } - {{if .ToJoinTable -}} - query := fmt.Sprintf( - `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from "{{.ForeignTable}}" as "{{id 0}}" inner join "{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), - ) - {{else -}} - query := fmt.Sprintf( - `select * from "{{.ForeignTable}}" where "{{.ForeignColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), - ) - {{end -}} + {{if .ToJoinTable -}} + {{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}} + query := fmt.Sprintf( + "select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) + {{else -}} + query := fmt.Sprintf( + "select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) + {{end -}} - if boil.DebugMode { - fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) - } + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } - results, err := e.Query(query, args...) - if err != nil { - return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") - } - defer results.Close() + results, err := e.Query(query, args...) 
+ if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") + } + defer results.Close() - var resultSlice []*{{$txt.ForeignTable.NameGo}} - {{if .ToJoinTable -}} - {{- $foreignTable := getTable $dot.Tables .ForeignTable -}} - {{- $joinTable := getTable $dot.Tables .JoinTable -}} - {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} - var localJoinCols []{{$localCol.Type}} - for results.Next() { - one := new({{$txt.ForeignTable.NameGo}}) - var localJoinCol {{$localCol.Type}} + var resultSlice []*{{$txt.ForeignTable.NameGo}} + {{if .ToJoinTable -}} + {{- $foreignTable := getTable $dot.Tables .ForeignTable -}} + {{- $joinTable := getTable $dot.Tables .JoinTable -}} + {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} + var localJoinCols []{{$localCol.Type}} + for results.Next() { + one := new({{$txt.ForeignTable.NameGo}}) + var localJoinCol {{$localCol.Type}} - err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) - if err = results.Err(); err != nil { - return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") - } + err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } - resultSlice = append(resultSlice, one) - localJoinCols = append(localJoinCols, localJoinCol) - } + resultSlice = append(resultSlice, one) + localJoinCols = append(localJoinCols, localJoinCol) + } - if err = results.Err(); err != nil { - return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") - } - {{else -}} - if err = boil.Bind(results, &resultSlice); err != nil { - return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") - } - {{end}} + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } + {{else -}} + if err = queries.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") + } + {{end}} - {{if not $dot.NoHooks -}} - if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 { - for _, obj := range resultSlice { - if err := obj.doAfterSelectHooks(e); err != nil { - return err - } - } - } + {{if not $dot.NoHooks -}} + if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 { + for _, obj := range resultSlice { + if err := obj.doAfterSelectHooks(e); err != nil { + return err + } + } + } - {{- end}} - if singular { - if object.R == nil { - object.R = &{{$varNameSingular}}R{} - } - object.R.{{$txt.Function.Name}} = resultSlice - return nil - } + {{- end}} + if singular { + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } + object.R.{{$txt.Function.Name}} = resultSlice + return nil + } - {{if .ToJoinTable -}} - for i, foreign := range resultSlice { - localJoinCol := localJoinCols[i] - for _, local := range slice { - if local.{{$txt.Function.LocalAssignment}} == localJoinCol { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) - break - } - } - } - {{else -}} - for _, foreign := range resultSlice { - for _, local := range slice { - if local.{{$txt.Function.LocalAssignment}} == 
foreign.{{$txt.Function.ForeignAssignment}} { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) - break - } - } - } - {{end}} + {{if .ToJoinTable -}} + for i, foreign := range resultSlice { + localJoinCol := localJoinCols[i] + for _, local := range slice { + if local.{{$txt.Function.LocalAssignment}} == localJoinCol { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) + break + } + } + } + {{else -}} + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) + break + } + } + } + {{end}} - return nil + return nil } {{end -}}{{/* if ForeignColumnUnique */}} {{- end -}}{{/* range tomany */}} -{{- end -}}{{/* if isjointable */}} +{{- end -}}{{/* if IsJoinTable */}} diff --git a/templates/08_relationship_to_one_setops.tpl b/templates/08_relationship_to_one_setops.tpl index 8f8457fe1..e4e93d6fa 100644 --- a/templates/08_relationship_to_one_setops.tpl +++ b/templates/08_relationship_to_one_setops.tpl @@ -1,101 +1,105 @@ {{- define "relationship_to_one_setops_helper" -}} -{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} -{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} - + {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} // Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related. // Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}. 
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec boil.Executor, insert bool, related *{{.ForeignTable.NameGo}}) error { - var err error - if insert { - if err = related.Insert(exec); err != nil { - return errors.Wrap(err, "failed to insert into foreign table") - } - } + var err error + if insert { + if err = related.Insert(exec); err != nil { + return errors.Wrap(err, "failed to insert into foreign table") + } + } - oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} - {{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}} - if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal - return errors.Wrap(err, "failed to update local table") - } + oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} + {{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}} + if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal + return errors.Wrap(err, "failed to update local table") + } - if {{.Function.Receiver}}.R == nil { - {{.Function.Receiver}}.R = &{{$localNameSingular}}R{ - {{.Function.Name}}: related, - } - } else { - {{.Function.Receiver}}.R.{{.Function.Name}} = related - } + if {{.Function.Receiver}}.R == nil { + {{.Function.Receiver}}.R = &{{$localNameSingular}}R{ + {{.Function.Name}}: related, + } + } else { + {{.Function.Receiver}}.R.{{.Function.Name}} = related + } - {{if (or .ForeignKey.Unique .Function.OneToOne) -}} - if related.R == nil { - related.R = &{{$varNameSingular}}R{ - {{.Function.ForeignName}}: {{.Function.Receiver}}, - } - } else { - related.R.{{.Function.ForeignName}} = {{.Function.Receiver}} - } - {{else -}} - if related.R == nil { - related.R = &{{$varNameSingular}}R{ - {{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}}, - } - } else { - related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}}) - } - {{end -}} + {{if (or .ForeignKey.Unique .Function.OneToOne) -}} + if related.R == nil { + related.R = &{{$varNameSingular}}R{ + {{.Function.ForeignName}}: {{.Function.Receiver}}, + } + } else { + related.R.{{.Function.ForeignName}} = {{.Function.Receiver}} + } + {{else -}} + if related.R == nil { + related.R = &{{$varNameSingular}}R{ + {{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}}, + } + } else { + related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}}) + } + {{end -}} - {{if .ForeignKey.Nullable}} - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true - {{end -}} - return nil + {{if .ForeignKey.Nullable}} + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true + {{end -}} + return nil } -{{- if .ForeignKey.Nullable}} + {{- if .ForeignKey.Nullable}} // Remove{{.Function.Name}} relationship. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil. // Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional). 
func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(exec boil.Executor, related *{{.ForeignTable.NameGo}}) error { - var err error + var err error - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false - if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true - return errors.Wrap(err, "failed to update local table") - } + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false + if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true + return errors.Wrap(err, "failed to update local table") + } - {{.Function.Receiver}}.R.{{.Function.Name}} = nil - if related == nil || related.R == nil { - return nil - } + {{.Function.Receiver}}.R.{{.Function.Name}} = nil + if related == nil || related.R == nil { + return nil + } - {{if .ForeignKey.Unique -}} - related.R.{{.Function.ForeignName}} = nil - {{else -}} - for i, ri := range related.R.{{.Function.ForeignName}} { - if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} { - continue - } + {{if .ForeignKey.Unique -}} + related.R.{{.Function.ForeignName}} = nil + {{else -}} + for i, ri := range related.R.{{.Function.ForeignName}} { + if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} { + continue + } - ln := len(related.R.{{.Function.ForeignName}}) - if ln > 1 && i < ln-1 { - related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1] - } - related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1] - break - } - {{end -}} + ln := len(related.R.{{.Function.ForeignName}}) + if ln > 1 && i < ln-1 { + related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1] + } + related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1] + break + } + {{end -}} - return nil + return nil } -{{end -}} -{{- end -}} + {{- end -}}{{/* if foreignkey nullable */}} + {{end -}}{{/* end with */}} +{{- end -}}{{/* end define */}} + +{{- /* Begin execution of template for one-to-one setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_setops_helper" $rel -}} -{{- end -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{- end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index d2d292202..b42842874 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -1,91 +1,93 @@ +{{- /* Begin execution of template for many-to-one or many-to-many setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- $varNameSingular := .ForeignTable | singular | camelCase -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} -{{- template "relationship_to_one_setops_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} - {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . 
-}} - {{- $localNameSingular := .Table | singular | camelCase -}} - {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} - + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- $varNameSingular := .ForeignTable | singular | camelCase -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one setops */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $localNameSingular := .Table | singular | camelCase -}} + {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} // Add{{$rel.Function.Name}} adds the given related objects to the existing relationships // of the {{$table.Name | singular}}, optionally inserting them as new records. // Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}. // Sets related.R.{{$rel.Function.ForeignName}} appropriately. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { - var err error - for _, rel := range related { - {{if not .ToJoinTable -}} - rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} - {{if .ForeignColumnNullable -}} - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true - {{end -}} - {{end -}} - if insert { - if err = rel.Insert(exec); err != nil { - return errors.Wrap(err, "failed to insert into foreign table") - } - }{{if not .ToJoinTable}} else { - if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { - return errors.Wrap(err, "failed to update foreign table") - } - }{{end -}} - } - - {{if .ToJoinTable -}} - for _, rel := range related { - query := `insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err = exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to insert into join table") - } - } - {{end -}} - - if {{$rel.Function.Receiver}}.R == nil { - {{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{ - {{$rel.Function.Name}}: related, - } - } else { - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...) 
- } - - {{if .ToJoinTable -}} - for _, rel := range related { - if rel.R == nil { - rel.R = &{{$foreignNameSingular}}R{ - {{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}}, - } - } else { - rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}}) - } - } - {{else -}} - for _, rel := range related { - if rel.R == nil { - rel.R = &{{$foreignNameSingular}}R{ - {{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}}, - } - } else { - rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}} - } - } - {{end -}} - - return nil + var err error + for _, rel := range related { + {{if not .ToJoinTable -}} + rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} + {{if .ForeignColumnNullable -}} + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true + {{end -}} + {{end -}} + if insert { + if err = rel.Insert(exec); err != nil { + return errors.Wrap(err, "failed to insert into foreign table") + } + }{{if not .ToJoinTable}} else { + if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { + return errors.Wrap(err, "failed to update foreign table") + } + }{{end -}} + } + + {{if .ToJoinTable -}} + for _, rel := range related { + query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err = exec.Exec(query, values...) + if err != nil { + return errors.Wrap(err, "failed to insert into join table") + } + } + {{end -}} + + if {{$rel.Function.Receiver}}.R == nil { + {{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{ + {{$rel.Function.Name}}: related, + } + } else { + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...) + } + + {{if .ToJoinTable -}} + for _, rel := range related { + if rel.R == nil { + rel.R = &{{$foreignNameSingular}}R{ + {{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}}, + } + } else { + rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}}) + } + } + {{else -}} + for _, rel := range related { + if rel.R == nil { + rel.R = &{{$foreignNameSingular}}R{ + {{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}}, + } + } else { + rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}} + } + } + {{end -}} + + return nil } -{{- if (or .ForeignColumnNullable .ToJoinTable)}} + {{- if (or .ForeignColumnNullable .ToJoinTable)}} // Set{{$rel.Function.Name}} removes all previously related items of the // {{$table.Name | singular}} replacing them completely with the passed // in related items, optionally inserting them as new records. @@ -93,126 +95,126 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Replaces {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} with related. // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. 
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { - {{if .ToJoinTable -}} - query := `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1` - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - {{else -}} - query := `update "{{.ForeignTable}}" set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - {{end -}} - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err := exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to remove relationships before set") - } - - {{if .ToJoinTable -}} - remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil - {{else -}} - if {{$rel.Function.Receiver}}.R != nil { - for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false - if rel.R == nil { - continue - } - - rel.R.{{$rel.Function.ForeignName}} = nil - } - - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil - } - {{end -}} - - return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...) + {{if .ToJoinTable -}} + query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + {{else -}} + query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + {{end -}} + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err := exec.Exec(query, values...) + if err != nil { + return errors.Wrap(err, "failed to remove relationships before set") + } + + {{if .ToJoinTable -}} + remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil + {{else -}} + if {{$rel.Function.Receiver}}.R != nil { + for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false + if rel.R == nil { + continue + } + + rel.R.{{$rel.Function.ForeignName}} = nil + } + + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil + } + {{end -}} + + return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...) } // Remove{{$rel.Function.Name}} relationships from objects passed in. // Removes related items from R.{{$rel.Function.Name}} (uses pointer comparison, removal does not keep order) // Sets related.R.{{$rel.Function.ForeignName}}. 
func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Function.Name}}(exec boil.Executor, related ...*{{$rel.ForeignTable.NameGo}}) error { - var err error - {{if .ToJoinTable -}} - query := fmt.Sprintf( - `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, - strmangle.Placeholders(len(related), 1, 1), - ) - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err = exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to remove relationships before set") - } - {{else -}} - for _, rel := range related { - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false - {{if not .ToJoinTable -}} - if rel.R != nil { - rel.R.{{$rel.Function.ForeignName}} = nil - } - {{end -}} - if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { - return err - } - } - {{end -}} - - {{if .ToJoinTable -}} - remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) - {{end -}} - if {{$rel.Function.Receiver}}.R == nil { - return nil - } - - for _, rel := range related { - for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { - if rel != ri { - continue - } - - ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}) - if ln > 1 && i < ln-1 { - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1] - } - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1] - break - } - } - - return nil + var err error + {{if .ToJoinTable -}} + query := fmt.Sprintf( + "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), + ) + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err = exec.Exec(query, values...) 
+ if err != nil { + return errors.Wrap(err, "failed to remove relationships before set") + } + {{else -}} + for _, rel := range related { + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false + {{if not .ToJoinTable -}} + if rel.R != nil { + rel.R.{{$rel.Function.ForeignName}} = nil + } + {{end -}} + if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { + return err + } + } + {{end -}} + + {{if .ToJoinTable -}} + remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) + {{end -}} + if {{$rel.Function.Receiver}}.R == nil { + return nil + } + + for _, rel := range related { + for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { + if rel != ri { + continue + } + + ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}) + if ln > 1 && i < ln-1 { + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1] + } + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1] + break + } + } + + return nil } -{{if .ToJoinTable -}} + {{if .ToJoinTable -}} func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) { - for _, rel := range related { - if rel.R == nil { - continue - } - for i, ri := range rel.R.{{$rel.Function.ForeignName}} { - if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} { - continue - } - - ln := len(rel.R.{{$rel.Function.ForeignName}}) - if ln > 1 && i < ln-1 { - rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1] - } - rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1] - break - } - } + for _, rel := range related { + if rel.R == nil { + continue + } + for i, ri := range rel.R.{{$rel.Function.ForeignName}} { + if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} { + continue + } + + ln := len(rel.R.{{$rel.Function.ForeignName}}) + if ln > 1 && i < ln-1 { + rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1] + } + rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1] + break + } + } } -{{end -}}{{- /* if join table */ -}} -{{- end -}}{{- /* if nullable foreign key */ -}} -{{- end -}}{{- /* if unique foreign key */ -}} -{{- end -}}{{- /* range relationships */ -}} -{{- end -}}{{- /* outer if join table */ -}} + {{end -}}{{- /* if ToJoinTable */ -}} + {{- end -}}{{- /* if nullable foreign key */ -}} + {{- end -}}{{- /* if unique foreign key */ -}} + {{- end -}}{{- /* range relationships */ -}} +{{- end -}}{{- /* if IsJoinTable */ -}} diff --git a/templates/10_all.tpl b/templates/10_all.tpl index a9470b18a..e1ed9ddd5 100644 --- a/templates/10_all.tpl +++ b/templates/10_all.tpl @@ -3,11 +3,11 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} // {{$tableNamePlural}}G retrieves all records. func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{$tableNamePlural}}(boil.GetDB(), mods...) + return {{$tableNamePlural}}(boil.GetDB(), mods...) } // {{$tableNamePlural}} retrieves all the records using an executor. 
func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - mods = append(mods, qm.From("{{.Table.Name}}")) - return {{$varNameSingular}}Query{NewQuery(exec, mods...)} + mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}")) + return {{$varNameSingular}}Query{NewQuery(exec, mods...)} } diff --git a/templates/11_find.tpl b/templates/11_find.tpl index 987afcf99..5fb5ffba9 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -4,53 +4,53 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} -// {{$tableNameSingular}}FindG retrieves a single record by ID. +// Find{{$tableNameSingular}}G retrieves a single record by ID. func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { - return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) + return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) } -// {{$tableNameSingular}}FindGP retrieves a single record by ID, and panics on error. +// Find{{$tableNameSingular}}GP retrieves a single record by ID, and panics on error. func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { - retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) - if err != nil { - panic(boil.WrapErr(err)) - } + retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) + if err != nil { + panic(boil.WrapErr(err)) + } - return retobj + return retobj } -// {{$tableNameSingular}}Find retrieves a single record by ID with an executor. +// Find{{$tableNameSingular}} retrieves a single record by ID with an executor. // If selectCols is empty Find will return all columns. func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { - {{$varNameSingular}}Obj := &{{$tableNameSingular}}{} - - sel := "*" - if len(selectCols) > 0 { - sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",") - } - query := fmt.Sprintf( - `select %s from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}}`, sel, - ) - - q := boil.SQL(exec, query, {{$pkNames | join ", "}}) - - err := q.Bind({{$varNameSingular}}Obj) - if err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, sql.ErrNoRows - } - return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") - } - - return {{$varNameSingular}}Obj, nil + {{$varNameSingular}}Obj := &{{$tableNameSingular}}{} + + sel := "*" + if len(selectCols) > 0 { + sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") + } + query := fmt.Sprintf( + "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, + ) + + q := queries.Raw(exec, query, {{$pkNames | join ", "}}) + + err := q.Bind({{$varNameSingular}}Obj) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, sql.ErrNoRows + } + return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") + } + + return {{$varNameSingular}}Obj, nil } -// {{$tableNameSingular}}FindP retrieves a single record by ID with an executor, and panics on error. 
+// Find{{$tableNameSingular}}P retrieves a single record by ID with an executor, and panics on error. func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { - retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) - if err != nil { - panic(boil.WrapErr(err)) - } + retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) + if err != nil { + panic(boil.WrapErr(err)) + } - return retobj + return retobj } diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index aa7958dca..c22c18245 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -1,24 +1,25 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // InsertG a single record. See Insert for whitelist behavior description. func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error { - return o.Insert(boil.GetDB(), whitelist...) + return o.Insert(boil.GetDB(), whitelist...) } // InsertGP a single record, and panics on error. See Insert for whitelist // behavior description. func (o *{{$tableNameSingular}}) InsertGP(whitelist ... string) { - if err := o.Insert(boil.GetDB(), whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Insert(boil.GetDB(), whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // InsertP a single record using an executor, and panics on error. See Insert // for whitelist behavior description. func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... string) { - if err := o.Insert(exec, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Insert(exec, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // Insert a single record using an executor. @@ -27,115 +28,131 @@ func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... strin // - All columns without a default value are included (i.e. name, age) // - All columns with a default, but non-zero are included (i.e. health = 75) func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") - } - - var err error - {{- template "timestamp_insert_helper" . 
}} - - {{if not .NoHooks -}} - if err := o.doBeforeInsertHooks(exec); err != nil { - return err - } - {{- end}} - - nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) - - key := makeCacheKey(whitelist, nzDefaults) - {{$varNameSingular}}InsertCacheMut.RLock() - cache, cached := {{$varNameSingular}}InsertCache[key] - {{$varNameSingular}}InsertCacheMut.RUnlock() - - if !cached { - wl, returnColumns := strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - nzDefaults, - whitelist, - ) - - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) - if err != nil { - return err - } - cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) - if err != nil { - return err - } - cache.query = fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1)) - - if len(cache.retMapping) != 0 { - {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) - {{else -}} - cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) - {{end -}} - } - } - - value := reflect.Indirect(reflect.ValueOf(o)) - vals := boil.ValuesFromMapping(value, cache.valueMapping) - {{if .UseLastInsertID}} - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, vals) - } - - result, err := exec.Exec(ins, vals...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") - } - - if len(cache.retMapping) == 0 { - {{if not .NoHooks -}} - return o.doAfterInsertHooks(exec) - {{else -}} - return nil - {{end -}} - } - - lastID, err := result.LastInsertId() - if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 { - return ErrSyncFail - } - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.retQuery) - fmt.Fprintln(boil.DebugWriter, lastID) - } - - err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") - } - {{else}} - if len(cache.retMapping) != 0 { - err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) - } else { - _, err = exec.Exec(cache.query, vals...) - } - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, vals) - } - - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") - } - {{end}} - - if !cached { - {{$varNameSingular}}InsertCacheMut.Lock() - {{$varNameSingular}}InsertCache[key] = cache - {{$varNameSingular}}InsertCacheMut.Unlock() - } - - {{if not .NoHooks -}} - return o.doAfterInsertHooks(exec) - {{- else -}} - return nil - {{- end}} + if o == nil { + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") + } + + var err error + {{- template "timestamp_insert_helper" . 
}} + + {{if not .NoHooks -}} + if err := o.doBeforeInsertHooks(exec); err != nil { + return err + } + {{- end}} + + nzDefaults := queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) + + key := makeCacheKey(whitelist, nzDefaults) + {{$varNameSingular}}InsertCacheMut.RLock() + cache, cached := {{$varNameSingular}}InsertCache[key] + {{$varNameSingular}}InsertCacheMut.RUnlock() + + if !cached { + wl, returnColumns := strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + nzDefaults, + whitelist, + ) + + cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) + if err != nil { + return err + } + cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) + if err != nil { + return err + } + cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) + + if len(cache.retMapping) != 0 { + {{if .UseLastInsertID -}} + cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns)) + {{else -}} + cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) + {{end -}} + } + } + + value := reflect.Indirect(reflect.ValueOf(o)) + vals := queries.ValuesFromMapping(value, cache.valueMapping) + {{if .UseLastInsertID}} + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, vals) + } + + result, err := exec.Exec(cache.query, vals...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") + } + + var lastID int64 + var identifierCols []interface{} + + if len(cache.retMapping) == 0 { + goto CacheNoHooks + } + + lastID, err = result.LastInsertId() + if err != nil { + return ErrSyncFail + } + + if lastID != 0 { + {{- $colName := index .Table.PKey.Columns 0 -}} + {{- $col := .Table.GetColumn $colName -}} + o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID) + identifierCols = []interface{}{lastID} + } else { + identifierCols = []interface{}{ + {{range .Table.PKey.Columns -}} + o.{{. | singular | titleCase}}, + {{end -}} + } + } + + if lastID != 0 && len(cache.retMapping) == 1 { + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.retQuery) + fmt.Fprintln(boil.DebugWriter, identifierCols...) + } + + err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + } + } + {{else}} + if len(cache.retMapping) != 0 { + err = exec.QueryRow(cache.query, vals...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...) + } else { + _, err = exec.Exec(cache.query, vals...) 
+ } + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, vals) + } + + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") + } + {{end}} +{{if .UseLastInsertID -}} +CacheNoHooks: +{{- end}} + if !cached { + {{$varNameSingular}}InsertCacheMut.Lock() + {{$varNameSingular}}InsertCache[key] = cache + {{$varNameSingular}}InsertCacheMut.Unlock() + } + + {{if not .NoHooks -}} + return o.doAfterInsertHooks(exec) + {{- else -}} + return nil + {{- end}} } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 8483d113e..7581b4566 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -3,28 +3,29 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // UpdateG a single {{$tableNameSingular}} record. See Update for // whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error { - return o.Update(boil.GetDB(), whitelist...) + return o.Update(boil.GetDB(), whitelist...) } // UpdateGP a single {{$tableNameSingular}} record. // UpdateGP takes a whitelist of column names that should be updated. // Panics on error. See Update for whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateGP(whitelist ...string) { - if err := o.Update(boil.GetDB(), whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Update(boil.GetDB(), whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateP uses an executor to update the {{$tableNameSingular}}, and panics on error. // See Update for whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... string) { - err := o.Update(exec, whitelist...) - if err != nil { - panic(boil.WrapErr(err)) - } + err := o.Update(exec, whitelist...) + if err != nil { + panic(boil.WrapErr(err)) + } } // Update uses an executor to update the {{$tableNameSingular}}. @@ -35,146 +36,147 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin // Update does not automatically update the record in case of default values. Use .Reload() // to refresh the records. func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error { - {{- template "timestamp_update_helper" . 
-}} - - var err error - {{if not .NoHooks -}} - if err = o.doBeforeUpdateHooks(exec); err != nil { - return err - } - {{end -}} - - key := makeCacheKey(whitelist, nil) - {{$varNameSingular}}UpdateCacheMut.RLock() - cache, cached := {{$varNameSingular}}UpdateCache[key] - {{$varNameSingular}}UpdateCacheMut.RUnlock() - - if !cached { - wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - - cache.query = fmt.Sprintf(`UPDATE "{{.Table.Name}}" SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) - if err != nil { - return err - } - } - - if len(cache.valueMapping) == 0 { - return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") - } - - values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, values) - } - - result, err := exec.Exec(cache.query, values...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") - } - - if r, err := result.RowsAffected(); err == nil && r != 1 { - return errors.Errorf("failed to update single row, updated %d rows", r) - } - - if !cached { - {{$varNameSingular}}UpdateCacheMut.Lock() - {{$varNameSingular}}UpdateCache[key] = cache - {{$varNameSingular}}UpdateCacheMut.Unlock() - } - - {{if not .NoHooks -}} - return o.doAfterUpdateHooks(exec) - {{- else -}} - return nil - {{- end}} + {{- template "timestamp_update_helper" . -}} + + var err error + {{if not .NoHooks -}} + if err = o.doBeforeUpdateHooks(exec); err != nil { + return err + } + {{end -}} + + key := makeCacheKey(whitelist, nil) + {{$varNameSingular}}UpdateCacheMut.RLock() + cache, cached := {{$varNameSingular}}UpdateCache[key] + {{$varNameSingular}}UpdateCacheMut.RUnlock() + + if !cached { + wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) + + cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), + strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), + ) + cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) + if err != nil { + return err + } + } + + if len(cache.valueMapping) == 0 { + return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") + } + + values := queries.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, values) + } + + result, err := exec.Exec(cache.query, values...) 
+ if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") + } + + if r, err := result.RowsAffected(); err == nil && r != 1 { + return errors.Errorf("failed to update single row, updated %d rows", r) + } + + if !cached { + {{$varNameSingular}}UpdateCacheMut.Lock() + {{$varNameSingular}}UpdateCache[key] = cache + {{$varNameSingular}}UpdateCacheMut.Unlock() + } + + {{if not .NoHooks -}} + return o.doAfterUpdateHooks(exec) + {{- else -}} + return nil + {{- end}} } // UpdateAllP updates all rows with matching column names, and panics on error. func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { - if err := q.UpdateAll(cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := q.UpdateAll(cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAll updates all rows with the specified column values. func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { - boil.SetUpdate(q.Query, cols) + queries.SetUpdate(q.Query, cols) - _, err := boil.ExecQuery(q.Query) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") - } + _, err := q.Query.Exec() + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") + } - return nil + return nil } // UpdateAllG updates all rows with the specified column values. func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error { - return o.UpdateAll(boil.GetDB(), cols) + return o.UpdateAll(boil.GetDB(), cols) } // UpdateAllGP updates all rows with the specified column values, and panics on error. func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) { - if err := o.UpdateAll(boil.GetDB(), cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.UpdateAll(boil.GetDB(), cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAllP updates all rows with the specified column values, and panics on error. func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) { - if err := o.UpdateAll(exec, cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.UpdateAll(exec, cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAll updates all rows with the specified column values, using an executor. func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error { - ln := int64(len(o)) - if ln == 0 { - return nil - } - - if len(cols) == 0 { - return errors.New("{{.PkgName}}: update all requires at least one column argument") - } - - colNames := make([]string, len(cols)) - args := make([]interface{}, len(cols)) - - i := 0 - for name, value := range cols { - colNames[i] = strmangle.IdentQuote(name) - args[i] = value - i++ - } - - // Append all of the primary key values for each column - args = append(args, o.inPrimaryKeyArgs()...) - - sql := fmt.Sprintf( - `UPDATE {{.Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, - strings.Join(colNames, ", "), - strmangle.Placeholders(len(colNames), 1, 1), - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args...) - } - - result, err := exec.Exec(sql, args...) 
- if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") - } - - if r, err := result.RowsAffected(); err == nil && r != ln { - return errors.Errorf("failed to update %d rows, only affected %d", ln, r) - } - - return nil + ln := int64(len(o)) + if ln == 0 { + return nil + } + + if len(cols) == 0 { + return errors.New("{{.PkgName}}: update all requires at least one column argument") + } + + colNames := make([]string, len(cols)) + args := make([]interface{}, len(cols)) + + i := 0 + for name, value := range cols { + colNames[i] = name + args[i] = value + i++ + } + + // Append all of the primary key values for each column + args = append(args, o.inPrimaryKeyArgs()...) + + sql := fmt.Sprintf( + "UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args...) + } + + result, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") + } + + if r, err := result.RowsAffected(); err == nil && r != ln { + return errors.Errorf("failed to update %d rows, only affected %d", ln, r) + } + + return nil } diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index 55a50c126..6505f9d5b 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -1,85 +1,188 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // UpsertG attempts an insert, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) UpsertG(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { - return o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...) +func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) } // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. -func (o *{{$tableNameSingular}}) UpsertGP(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { - if err := o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } +func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // UpsertP attempts an insert using an executor, and does an update or ignore on conflict. // UpsertP panics on error. 
-func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { - if err := o.Upsert(exec, updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } +func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // Upsert attempts an insert using an executor, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") - } - - {{- template "timestamp_upsert_helper" . }} - - {{if not .NoHooks -}} - if err := o.doBeforeUpsertHooks(exec); err != nil { - return err - } - {{- end}} - - var err error - var ret []string - whitelist, ret = strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), - whitelist, - ) - update := strmangle.UpdateColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - updateColumns, - ) - conflict := conflictColumns - if len(conflict) == 0 { - conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) - copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) - } - - query := boil.BuildUpsertQuery("{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) - } - - {{- if .UseLastInsertID}} - return errors.New("don't know how to do this yet") - {{- else}} - if len(ret) != 0 { - err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) - } else { - _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) - } - {{- end}} - - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") - } - - {{if not .NoHooks -}} - if err := o.doAfterUpsertHooks(exec); err != nil { - return err - } - {{- end}} - - return nil +func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + if o == nil { + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") + } + + {{- template "timestamp_upsert_helper" . 
}}
+
+	{{if not .NoHooks -}}
+	if err := o.doBeforeUpsertHooks(exec); err != nil {
+		return err
+	}
+	{{- end}}
+
+	// Build cache key in-line uglily - mysql vs postgres problems
+	buf := strmangle.GetBuffer()
+	{{if ne .DriverName "mysql" -}}
+	if updateOnConflict {
+		buf.WriteByte('t')
+	} else {
+		buf.WriteByte('f')
+	}
+	buf.WriteByte('.')
+	for _, c := range conflictColumns {
+		buf.WriteString(c)
+	}
+	buf.WriteByte('.')
+	{{end -}}
+	for _, c := range updateColumns {
+		buf.WriteString(c)
+	}
+	buf.WriteByte('.')
+	for _, c := range whitelist {
+		buf.WriteString(c)
+	}
+	key := buf.String()
+	strmangle.PutBuffer(buf)
+
+	{{$varNameSingular}}UpsertCacheMut.RLock()
+	cache, cached := {{$varNameSingular}}UpsertCache[key]
+	{{$varNameSingular}}UpsertCacheMut.RUnlock()
+
+	var err error
+
+	if !cached {
+		var ret []string
+		whitelist, ret = strmangle.InsertColumnSet(
+			{{$varNameSingular}}Columns,
+			{{$varNameSingular}}ColumnsWithDefault,
+			{{$varNameSingular}}ColumnsWithoutDefault,
+			queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o),
+			whitelist,
+		)
+		update := strmangle.UpdateColumnSet(
+			{{$varNameSingular}}Columns,
+			{{$varNameSingular}}PrimaryKeyColumns,
+			updateColumns,
+		)
+
+		{{if ne .DriverName "mysql" -}}
+		conflict := conflictColumns
+		if len(conflict) == 0 {
+			conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns))
+			copy(conflict, {{$varNameSingular}}PrimaryKeyColumns)
+		}
+		cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist)
+		{{- else -}}
+		cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist)
+		cache.retQuery = fmt.Sprintf(
+			"SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}",
+			strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","),
+		)
+		{{- end}}
+
+		cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, whitelist)
+		if err != nil {
+			return err
+		}
+		if len(ret) != 0 {
+			cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, ret)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	value := reflect.Indirect(reflect.ValueOf(o))
+	values := queries.ValuesFromMapping(value, cache.valueMapping)
+	var returns []interface{}
+	if len(cache.retMapping) != 0 {
+		returns = queries.PtrsFromMapping(value, cache.retMapping)
+	}
+
+	if boil.DebugMode {
+		fmt.Fprintln(boil.DebugWriter, cache.query)
+		fmt.Fprintln(boil.DebugWriter, values)
+	}
+
+	{{- if .UseLastInsertID}}
+	result, err := exec.Exec(cache.query, values...)
+	if err != nil {
+		return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}")
+	}
+
+	if len(cache.retMapping) == 0 {
+		{{if not .NoHooks -}}
+		return o.doAfterUpsertHooks(exec)
+		{{else -}}
+		return nil
+		{{end -}}
+	}
+
+	lastID, err := result.LastInsertId()
+	if err != nil {
+		return ErrSyncFail
+	}
+
+	var identifierCols []interface{}
+	if lastID != 0 {
+		{{- $colName := index .Table.PKey.Columns 0 -}}
+		{{- $col := .Table.GetColumn $colName -}}
+		o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID)
+		identifierCols = []interface{}{lastID}
+	} else {
+		identifierCols = []interface{}{
+			{{range .Table.PKey.Columns -}}
+			o.{{. | singular | titleCase}},
+			{{end -}}
+		}
+	}
+
+	if lastID != 0 && len(cache.retMapping) == 1 {
+		if boil.DebugMode {
+			fmt.Fprintln(boil.DebugWriter, cache.retQuery)
+			fmt.Fprintln(boil.DebugWriter, identifierCols...)
+ } + + err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(returns...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + } + } + {{- else}} + if len(cache.retMapping) != 0 { + err = exec.QueryRow(cache.query, values...).Scan(returns...) + } else { + _, err = exec.Exec(cache.query, values...) + } + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") + } + {{- end}} + + if !cached { + {{$varNameSingular}}UpsertCacheMut.Lock() + {{$varNameSingular}}UpsertCache[key] = cache + {{$varNameSingular}}UpsertCacheMut.Unlock() + } + + {{if not .NoHooks -}} + return o.doAfterUpsertHooks(exec) + {{- else -}} + return nil + {{- end}} } diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index a3c8b5145..18205bdd1 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -1,161 +1,162 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // DeleteP deletes a single {{$tableNameSingular}} record with an executor. // DeleteP will match against the primary key column to find the record to delete. // Panics on error. func (o *{{$tableNameSingular}}) DeleteP(exec boil.Executor) { - if err := o.Delete(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Delete(exec); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteG deletes a single {{$tableNameSingular}} record. // DeleteG will match against the primary key column to find the record to delete. func (o *{{$tableNameSingular}}) DeleteG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion") + } - return o.Delete(boil.GetDB()) + return o.Delete(boil.GetDB()) } // DeleteGP deletes a single {{$tableNameSingular}} record. // DeleteGP will match against the primary key column to find the record to delete. // Panics on error. func (o *{{$tableNameSingular}}) DeleteGP() { - if err := o.DeleteG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteG(); err != nil { + panic(boil.WrapErr(err)) + } } // Delete deletes a single {{$tableNameSingular}} record with an executor. // Delete will match against the primary key column to find the record to delete. func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete") + } - {{if not .NoHooks -}} - if err := o.doBeforeDeleteHooks(exec); err != nil { - return err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doBeforeDeleteHooks(exec); err != nil { + return err + } + {{- end}} - args := o.inPrimaryKeyArgs() + args := o.inPrimaryKeyArgs() - sql := `DELETE FROM {{.Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` + sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}" - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args...) - } + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args...) + } - _, err := exec.Exec(sql, args...) 
- if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}") - } + _, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}") + } - {{if not .NoHooks -}} - if err := o.doAfterDeleteHooks(exec); err != nil { - return err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doAfterDeleteHooks(exec); err != nil { + return err + } + {{- end}} - return nil + return nil } // DeleteAllP deletes all rows, and panics on error. func (q {{$varNameSingular}}Query) DeleteAllP() { - if err := q.DeleteAll(); err != nil { - panic(boil.WrapErr(err)) - } + if err := q.DeleteAll(); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAll deletes all matching rows. func (q {{$varNameSingular}}Query) DeleteAll() error { - if q.Query == nil { - return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") - } + if q.Query == nil { + return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") + } - boil.SetDelete(q.Query) + queries.SetDelete(q.Query) - _, err := boil.ExecQuery(q.Query) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") - } + _, err := q.Query.Exec() + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") + } - return nil + return nil } -// DeleteAll deletes all rows in the slice, and panics on error. +// DeleteAllGP deletes all rows in the slice, and panics on error. func (o {{$tableNameSingular}}Slice) DeleteAllGP() { - if err := o.DeleteAllG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteAllG(); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAllG deletes all rows in the slice. func (o {{$tableNameSingular}}Slice) DeleteAllG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") - } - return o.DeleteAll(boil.GetDB()) + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") + } + return o.DeleteAll(boil.GetDB()) } // DeleteAllP deletes all rows in the slice, using an executor, and panics on error. func (o {{$tableNameSingular}}Slice) DeleteAllP(exec boil.Executor) { - if err := o.DeleteAll(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteAll(exec); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAll deletes all rows in the slice, using an executor. func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") - } - - if len(o) == 0 { - return nil - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}BeforeDeleteHooks) != 0 { - for _, obj := range o { - if err := obj.doBeforeDeleteHooks(exec); err != nil { - return err - } - } - } - {{- end}} - - args := o.inPrimaryKeyArgs() - - sql := fmt.Sprintf( - `DELETE FROM {{.Table.Name}} WHERE (%s) IN (%s)`, - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args) - } - - _, err := exec.Exec(sql, args...) 
- if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice") - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}AfterDeleteHooks) != 0 { - for _, obj := range o { - if err := obj.doAfterDeleteHooks(exec); err != nil { - return err - } - } - } - {{- end}} - - return nil + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") + } + + if len(o) == 0 { + return nil + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}BeforeDeleteHooks) != 0 { + for _, obj := range o { + if err := obj.doBeforeDeleteHooks(exec); err != nil { + return err + } + } + } + {{- end}} + + args := o.inPrimaryKeyArgs() + + sql := fmt.Sprintf( + "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args) + } + + _, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice") + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}AfterDeleteHooks) != 0 { + for _, obj := range o { + if err := obj.doAfterDeleteHooks(exec); err != nil { + return err + } + } + } + {{- end}} + + return nil } diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index 38e83067c..60a1f369c 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -1,85 +1,94 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // ReloadGP refetches the object from the database and panics on error. func (o *{{$tableNameSingular}}) ReloadGP() { - if err := o.ReloadG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadG(); err != nil { + panic(boil.WrapErr(err)) + } } // ReloadP refetches the object from the database with an executor. Panics on error. func (o *{{$tableNameSingular}}) ReloadP(exec boil.Executor) { - if err := o.Reload(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Reload(exec); err != nil { + panic(boil.WrapErr(err)) + } } // ReloadG refetches the object from the database using the primary keys. func (o *{{$tableNameSingular}}) ReloadG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload") + } - return o.Reload(boil.GetDB()) + return o.Reload(boil.GetDB()) } // Reload refetches the object from the database // using the primary keys with an executor. func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error { - ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) - if err != nil { - return err - } + ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." 
| join ", "}}) + if err != nil { + return err + } - *o = *ret - return nil + *o = *ret + return nil } +// ReloadAllGP refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. +// Panics on error. func (o *{{$tableNameSingular}}Slice) ReloadAllGP() { - if err := o.ReloadAllG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadAllG(); err != nil { + panic(boil.WrapErr(err)) + } } +// ReloadAllP refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. +// Panics on error. func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) { - if err := o.ReloadAll(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadAll(exec); err != nil { + panic(boil.WrapErr(err)) + } } +// ReloadAllG refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. func (o *{{$tableNameSingular}}Slice) ReloadAllG() error { - if o == nil { - return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") - } + if o == nil { + return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") + } - return o.ReloadAll(boil.GetDB()) + return o.ReloadAll(boil.GetDB()) } // ReloadAll refetches every row with matching primary key column values // and overwrites the original object slice with the newly updated slice. func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { - if o == nil || len(*o) == 0 { - return nil - } + if o == nil || len(*o) == 0 { + return nil + } - {{$varNamePlural}} := {{$tableNameSingular}}Slice{} - args := o.inPrimaryKeyArgs() + {{$varNamePlural}} := {{$tableNameSingular}}Slice{} + args := o.inPrimaryKeyArgs() - sql := fmt.Sprintf( - `SELECT {{.Table.Name}}.* FROM {{.Table.Name}} WHERE (%s) IN (%s)`, - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) + sql := fmt.Sprintf( + "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) - q := boil.SQL(exec, sql, args...) + q := queries.Raw(exec, sql, args...) - err := q.Bind(&{{$varNamePlural}}) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice") - } + err := q.Bind(&{{$varNamePlural}}) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice") + } - *o = {{$varNamePlural}} + *o = {{$varNamePlural}} - return nil + return nil } diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 899bec096..48709aec8 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -2,48 +2,49 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists. 
func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { - var exists bool + var exists bool - sql := `select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}} limit 1)` + sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)" - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) - } + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) + } - row := exec.QueryRow(sql, {{$pkNames | join ", "}}) + row := exec.QueryRow(sql, {{$pkNames | join ", "}}) - err := row.Scan(&exists) - if err != nil { - return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") - } + err := row.Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") + } - return exists, nil + return exists, nil } // {{$tableNameSingular}}ExistsG checks if the {{$tableNameSingular}} row exists. func {{$tableNameSingular}}ExistsG({{$pkArgs}}) (bool, error) { - return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) + return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) } // {{$tableNameSingular}}ExistsGP checks if the {{$tableNameSingular}} row exists. Panics on error. func {{$tableNameSingular}}ExistsGP({{$pkArgs}}) bool { - e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } // {{$tableNameSingular}}ExistsP checks if the {{$tableNameSingular}} row exists. Panics on error. 
func {{$tableNameSingular}}ExistsP(exec boil.Executor, {{$pkArgs}}) bool { - e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}}) - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}}) + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } diff --git a/templates/18_helpers.tpl b/templates/18_helpers.tpl index a9dd023d0..36618441a 100644 --- a/templates/18_helpers.tpl +++ b/templates/18_helpers.tpl @@ -1,23 +1,23 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} func (o {{$tableNameSingular}}) inPrimaryKeyArgs() []interface{} { - var args []interface{} + var args []interface{} - {{- range $key, $value := .Table.PKey.Columns }} - args = append(args, o.{{titleCase $value}}) - {{ end -}} + {{- range $key, $value := .Table.PKey.Columns }} + args = append(args, o.{{titleCase $value}}) + {{ end -}} - return args + return args } func (o {{$tableNameSingular}}Slice) inPrimaryKeyArgs() []interface{} { - var args []interface{} + var args []interface{} - for i := 0; i < len(o); i++ { - {{- range $key, $value := .Table.PKey.Columns }} - args = append(args, o[i].{{titleCase $value}}) - {{ end -}} - } + for i := 0; i < len(o); i++ { + {{- range $key, $value := .Table.PKey.Columns }} + args = append(args, o[i].{{titleCase $value}}) + {{ end -}} + } - return args + return args } diff --git a/templates/19_auto_timestamps.tpl b/templates/19_auto_timestamps.tpl index d600ccb77..fcb00d1cd 100644 --- a/templates/19_auto_timestamps.tpl +++ b/templates/19_auto_timestamps.tpl @@ -1,82 +1,82 @@ {{- define "timestamp_insert_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "created_at" "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if eq $col.Name "created_at" -}} - {{- if $col.Nullable}} - if o.CreatedAt.Time.IsZero() { - o.CreatedAt.Time = currTime - o.CreatedAt.Valid = true - } - {{- else}} - if o.CreatedAt.IsZero() { - o.CreatedAt = currTime - } - {{- end -}} - {{- end -}} - {{- if eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - if o.UpdatedAt.Time.IsZero() { - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - } - {{- else}} - if o.UpdatedAt.IsZero() { - o.UpdatedAt = currTime - } - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "created_at" "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "created_at" -}} + {{- if $col.Nullable}} + if o.CreatedAt.Time.IsZero() { + o.CreatedAt.Time = currTime + o.CreatedAt.Valid = true + } + {{- else}} + if o.CreatedAt.IsZero() { + o.CreatedAt = currTime + } + {{- end -}} + {{- end -}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + if o.UpdatedAt.Time.IsZero() { + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + } + {{- else}} + if o.UpdatedAt.IsZero() { + o.UpdatedAt = currTime + } + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{- end -}} {{- define "timestamp_update_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if 
eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - {{- else}} - o.UpdatedAt = currTime - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{end -}} {{- define "timestamp_upsert_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "created_at" "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if eq $col.Name "created_at" -}} - {{- if $col.Nullable}} - if o.CreatedAt.Time.IsZero() { - o.CreatedAt.Time = currTime - o.CreatedAt.Valid = true - } - {{- else}} - if o.CreatedAt.IsZero() { - o.CreatedAt = currTime - } - {{- end -}} - {{- end -}} - {{- if eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - {{- else}} - o.UpdatedAt = currTime - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "created_at" "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "created_at" -}} + {{- if $col.Nullable}} + if o.CreatedAt.Time.IsZero() { + o.CreatedAt.Time = currTime + o.CreatedAt.Valid = true + } + {{- else}} + if o.CreatedAt.IsZero() { + o.CreatedAt = currTime + } + {{- end -}} + {{- end -}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{end -}} diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 6cb607df2..d0879cb96 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,12 +1,19 @@ +var dialect = queries.Dialect{ + LQ: 0x{{printf "%x" .Dialect.LQ}}, + RQ: 0x{{printf "%x" .Dialect.RQ}}, + IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, +} + // NewQueryG initializes a new Query using the passed in QueryMods -func NewQueryG(mods ...qm.QueryMod) *boil.Query { +func NewQueryG(mods ...qm.QueryMod) *queries.Query { return NewQuery(boil.GetDB(), mods...) } // NewQuery initializes a new Query using the passed in QueryMods -func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { - q := &boil.Query{} - boil.SetExecutor(q, exec) +func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query { + q := &queries.Query{} + queries.SetExecutor(q, exec) + queries.SetDialect(q, &dialect) qm.Apply(q, mods...) return q diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index 143a18c39..ebb88918f 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -6,33 +6,32 @@ type M map[string]interface{} // fails or there was a primary key configuration that was not resolvable. 
var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert") -type insertCache struct{ - query string - retQuery string - valueMapping []uint64 - retMapping []uint64 +type insertCache struct { + query string + retQuery string + valueMapping []uint64 + retMapping []uint64 } -type updateCache struct{ - query string - valueMapping []uint64 +type updateCache struct { + query string + valueMapping []uint64 } func makeCacheKey(wl, nzDefaults []string) string { - buf := strmangle.GetBuffer() + buf := strmangle.GetBuffer() - for _, w := range wl { - buf.WriteString(w) - } - if len(nzDefaults) != 0 { - buf.WriteByte('.') - } - for _, nz := range nzDefaults { - buf.WriteString(nz) - } + for _, w := range wl { + buf.WriteString(w) + } + if len(nzDefaults) != 0 { + buf.WriteByte('.') + } + for _, nz := range nzDefaults { + buf.WriteString(nz) + } - str := buf.String() - strmangle.PutBuffer(buf) - return str + str := buf.String() + strmangle.PutBuffer(buf) + return str } - diff --git a/templates_test/all.tpl b/templates_test/all.tpl index 001c32299..532801395 100644 --- a/templates_test/all.tpl +++ b/templates_test/all.tpl @@ -3,11 +3,11 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}(t *testing.T) { - t.Parallel() + t.Parallel() - query := {{$tableNamePlural}}(nil) + query := {{$tableNamePlural}}(nil) - if query.Query == nil { - t.Error("expected a query, got nothing") - } + if query.Query == nil { + t.Error("expected a query, got nothing") + } } diff --git a/templates_test/delete.tpl b/templates_test/delete.tpl index f6dfaca4d..f745ea48d 100644 --- a/templates_test/delete.tpl +++ b/templates_test/delete.tpl @@ -3,93 +3,93 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Delete(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$varNameSingular}}.Delete(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$varNameSingular}}.Delete(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } func test{{$tableNamePlural}}QueryDeleteAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - 
t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } func test{{$tableNamePlural}}SliceDeleteAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - - if err = slice.DeleteAll(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + + if err = slice.DeleteAll(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } diff --git a/templates_test/exists.tpl b/templates_test/exists.tpl index 36089214d..30bbcfbf3 100644 --- a/templates_test/exists.tpl +++ b/templates_test/exists.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Exists(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, 
{{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - {{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}} - e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}}) - if err != nil { - t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err) - } - if e != true { - t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.") - } + {{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}} + e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}}) + if err != nil { + t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err) + } + if e != true { + t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.") + } } diff --git a/templates_test/find.tpl b/templates_test/find.tpl index 2da3fdadc..cd3ea9677 100644 --- a/templates_test/find.tpl +++ b/templates_test/find.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Find(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - {{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}}) - if err != nil { - t.Error(err) - } + {{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." 
$varNameSingular) | join ", "}}) + if err != nil { + t.Error(err) + } - if {{$varNameSingular}}Found == nil { - t.Error("want a record, got nil") - } + if {{$varNameSingular}}Found == nil { + t.Error("want a record, got nil") + } } diff --git a/templates_test/finishers.tpl b/templates_test/finishers.tpl index f7cff0d3a..fa8b129d2 100644 --- a/templates_test/finishers.tpl +++ b/templates_test/finishers.tpl @@ -3,111 +3,111 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Bind(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil { + t.Error(err) + } } func test{{$tableNamePlural}}One(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if x, err := {{$tableNamePlural}}(tx).One(); err != nil { - t.Error(err) - } else if x == nil { - t.Error("expected to get a non nil record") - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if x, err := {{$tableNamePlural}}(tx).One(); err != nil { + t.Error(err) + } else if x == nil { + t.Error("expected to get a non nil record") + } } func test{{$tableNamePlural}}All(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}}One := &{{$tableNameSingular}}{} - {{$varNameSingular}}Two := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - if err = randomize.Struct(seed, {{$varNameSingular}}Two, 
{{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}One.Insert(tx); err != nil { - t.Error(err) - } - if err = {{$varNameSingular}}Two.Insert(tx); err != nil { - t.Error(err) - } - - slice, err := {{$tableNamePlural}}(tx).All() - if err != nil { - t.Error(err) - } - - if len(slice) != 2 { - t.Error("want 2 records, got:", len(slice)) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}}One := &{{$tableNameSingular}}{} + {{$varNameSingular}}Two := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}One.Insert(tx); err != nil { + t.Error(err) + } + if err = {{$varNameSingular}}Two.Insert(tx); err != nil { + t.Error(err) + } + + slice, err := {{$tableNamePlural}}(tx).All() + if err != nil { + t.Error(err) + } + + if len(slice) != 2 { + t.Error("want 2 records, got:", len(slice)) + } } func test{{$tableNamePlural}}Count(t *testing.T) { - t.Parallel() - - var err error - seed := randomize.NewSeed() - {{$varNameSingular}}One := &{{$tableNameSingular}}{} - {{$varNameSingular}}Two := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}One.Insert(tx); err != nil { - t.Error(err) - } - if err = {{$varNameSingular}}Two.Insert(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 2 { - t.Error("want 2 records, got:", count) - } + t.Parallel() + + var err error + seed := randomize.NewSeed() + {{$varNameSingular}}One := &{{$tableNameSingular}}{} + {{$varNameSingular}}Two := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}One.Insert(tx); err != nil { + t.Error(err) + } + if err = {{$varNameSingular}}Two.Insert(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 2 { + 
t.Error("want 2 records, got:", count) + } } diff --git a/templates_test/helpers.tpl b/templates_test/helpers.tpl index c20dbd0d9..09f9bff86 100644 --- a/templates_test/helpers.tpl +++ b/templates_test/helpers.tpl @@ -5,57 +5,57 @@ var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}} func test{{$tableNamePlural}}InPrimaryKeyArgs(t *testing.T) { - t.Parallel() + t.Parallel() - var err error - var o {{$tableNameSingular}} - o = {{$tableNameSingular}}{} + var err error + var o {{$tableNameSingular}} + o = {{$tableNameSingular}}{} - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Could not randomize struct: %s", err) - } + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Could not randomize struct: %s", err) + } - args := o.inPrimaryKeyArgs() + args := o.inPrimaryKeyArgs() - if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) { - t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args)) - } + if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) { + t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args)) + } - {{range $key, $value := .Table.PKey.Columns}} - if o.{{titleCase $value}} != args[{{$key}}] { - t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}]) - } - {{- end}} + {{range $key, $value := .Table.PKey.Columns}} + if o.{{titleCase $value}} != args[{{$key}}] { + t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}]) + } + {{- end}} } func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) { - t.Parallel() - - var err error - o := make({{$tableNameSingular}}Slice, 3) - - seed := randomize.NewSeed() - for i := range o { - o[i] = &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Could not randomize struct: %s", err) - } - } - - args := o.inPrimaryKeyArgs() - - if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 { - t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) - } - - argC := 0 - for i := 0; i < 3; i++ { - {{range $key, $value := .Table.PKey.Columns}} - if o[i].{{titleCase $value}} != args[argC] { - t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) - } - argC++ - {{- end}} - } + t.Parallel() + + var err error + o := make({{$tableNameSingular}}Slice, 3) + + seed := randomize.NewSeed() + for i := range o { + o[i] = &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Could not randomize struct: %s", err) + } + } + + args := o.inPrimaryKeyArgs() + + if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 { + t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) + } + + argC := 0 + for i := 0; i < 3; i++ { + {{range $key, $value := .Table.PKey.Columns}} + if o[i].{{titleCase $value}} != args[argC] { + t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) + } + argC++ + {{- end}} + } } diff --git a/templates_test/hooks.tpl b/templates_test/hooks.tpl index dd6a4bb64..22dc84c7e 100644 --- 
a/templates_test/hooks.tpl +++ b/templates_test/hooks.tpl @@ -4,142 +4,142 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func {{$varNameSingular}}BeforeInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterSelectHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func test{{$tableNamePlural}}Hooks(t *testing.T) { - t.Parallel() - - var err error - - empty := &{{$tableNameSingular}}{} - o := &{{$tableNameSingular}}{} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err) - } - - Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook) - if err = o.doBeforeInsertHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeInsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook) - if err = o.doAfterInsertHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterInsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook) - if err = o.doAfterSelectHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterSelectHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook) - if err = 
o.doBeforeUpdateHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook) - if err = o.doAfterUpdateHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterUpdateHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook) - if err = o.doBeforeDeleteHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook) - if err = o.doAfterDeleteHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterDeleteHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) - if err = o.doBeforeUpsertHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook) - if err = o.doAfterUpsertHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterUpsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} + t.Parallel() + + var err error + + empty := &{{$tableNameSingular}}{} + o := &{{$tableNameSingular}}{} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err) + } + + Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook) + if err = o.doBeforeInsertHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeInsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook) + if err = o.doAfterInsertHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterInsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook) + if err = 
o.doAfterSelectHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterSelectHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook) + if err = o.doBeforeUpdateHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook) + if err = o.doAfterUpdateHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterUpdateHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook) + if err = o.doBeforeDeleteHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook) + if err = o.doAfterDeleteHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterDeleteHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) + if err = o.doBeforeUpsertHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook) + if err = o.doAfterUpsertHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterUpsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} } {{- end}} diff --git a/templates_test/insert.tpl b/templates_test/insert.tpl index 63898dac1..d14a0c827 100644 --- a/templates_test/insert.tpl +++ b/templates_test/insert.tpl @@ -4,53 +4,53 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $parent := . 
-}} func test{{$tableNamePlural}}Insert(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 1 { - t.Error("want one record, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 1 { + t.Error("want one record, got:", count) + } } func test{{$tableNamePlural}}InsertWhitelist(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 1 { - t.Error("want one record, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 1 { + t.Error("want one record, got:", count) + } } diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl new file mode 100644 index 000000000..fc43d3d8f --- /dev/null +++ b/templates_test/main_test/mysql_main.tpl @@ -0,0 +1,166 @@ +type mysqlTester struct { + dbConn *sql.DB + + dbName string + host string + user string + pass string + sslmode string + port int + + optionFile string + + testDBName string +} + +func init() { + dbMain = &mysqlTester{} +} + +func (m *mysqlTester) setup() error { + var err error + + m.dbName = viper.GetString("mysql.dbname") + m.host = viper.GetString("mysql.host") + m.user = viper.GetString("mysql.user") + m.pass = viper.GetString("mysql.pass") + m.port = viper.GetInt("mysql.port") + m.sslmode = viper.GetString("mysql.sslmode") + // Create a randomized db name. 
+ m.testDBName = randomize.StableDBName(m.dbName) + + if err = m.makeOptionFile(); err != nil { + return errors.Wrap(err, "couldn't make option file") + } + + if err = m.dropTestDB(); err != nil { + return err + } + if err = m.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("mysqldump", m.defaultsFile(), "--no-data", m.dbName) + createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName) + + r, w := io.Pipe() + dumpCmd.Stdout = w + createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r) + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysqldump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysqldump command") + } + + w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysql command") + } + + return nil +} + +func (m *mysqlTester) sslMode(mode string) string { + switch mode { + case "true": + return "REQUIRED" + case "false": + return "DISABLED" + default: + return "PREFERRED" + } +} + +func (m *mysqlTester) defaultsFile() string { + return fmt.Sprintf("--defaults-file=%s", m.optionFile) +} + +func (m *mysqlTester) makeOptionFile() error { + tmp, err := ioutil.TempFile("", "optionfile") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintln(tmp, "[client]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode)) + + fmt.Fprintln(tmp, "[mysqldump]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode)) + + m.optionFile = tmp.Name() + + return tmp.Close() +} + +func (m *mysqlTester) createTestDB() error { + sql := fmt.Sprintf("create database %s;", m.testDBName) + return m.runCmd(sql, "mysql") +} + +func (m *mysqlTester) dropTestDB() error { + sql := fmt.Sprintf("drop database if exists %s;", m.testDBName) + return m.runCmd(sql, "mysql") +} + +func (m *mysqlTester) teardown() error { + if m.dbConn != nil { + m.dbConn.Close() + } + + if err := m.dropTestDB(); err != nil { + return err + } + + return os.Remove(m.optionFile) +} + +func (m *mysqlTester) runCmd(stdin, command string, args ...string) error { + args = append([]string{m.defaultsFile()}, args...) + + cmd := exec.Command(command, args...) 
+ cmd.Stdin = strings.NewReader(stdin) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + + return nil +} + +func (m *mysqlTester) conn() (*sql.DB, error) { + if m.dbConn != nil { + return m.dbConn, nil + } + + var err error + m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) + if err != nil { + return nil, err + } + + return m.dbConn, nil +} diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index ee3ce6985..d1951438a 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -1,275 +1,166 @@ -type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` - SSLMode string `toml:"sslmode"` -} - -type Config struct { - Postgres PostgresCfg `toml:"postgres"` -} - -var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") - -func TestMain(m *testing.M) { - rand.Seed(time.Now().UnixNano()) +type pgTester struct { + dbConn *sql.DB - // Set DebugMode so we can see generated sql statements - flag.Parse() - boil.DebugMode = *flagDebugMode + dbName string + host string + user string + pass string + sslmode string + port int - var err error - if err = setup(); err != nil { - fmt.Println("Unable to execute setup:", err) - os.Exit(-2) - } + pgPassFile string - var code int - if err = disableTriggers(); err != nil { - fmt.Println("Unable to disable triggers:", err) - } else { - boil.SetDB(dbConn) - code = m.Run() - } + testDBName string +} - if err = teardown(); err != nil { - fmt.Println("Unable to execute teardown:", err) - os.Exit(-3) - } +func init() { + dbMain = &pgTester{} +} - os.Exit(code) +// setup dumps the database schema and imports it into a temporary randomly +// generated test database so that tests can be run against it using the +// generated sqlboiler ORM package. +func (p *pgTester) setup() error { + var err error + + p.dbName = viper.GetString("postgres.dbname") + p.host = viper.GetString("postgres.host") + p.user = viper.GetString("postgres.user") + p.pass = viper.GetString("postgres.pass") + p.port = viper.GetInt("postgres.port") + p.sslmode = viper.GetString("postgres.sslmode") + // Create a randomized db name. + p.testDBName = randomize.StableDBName(p.dbName) + + if err = p.makePGPassFile(); err != nil { + return err + } + + if err = p.dropTestDB(); err != nil { + return err + } + if err = p.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName) + dumpCmd.Env = append(os.Environ(), p.pgEnv()...) + createCmd := exec.Command("psql", p.testDBName) + createCmd.Env = append(os.Environ(), p.pgEnv()...) 
+ + r, w := io.Pipe() + dumpCmd.Stdout = w + createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r) + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start pg_dump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start psql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for pg_dump command") + } + + w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for psql command") + } + + return nil } -// disableTriggers is used to disable foreign key constraints for every table. -// If this is not used we cannot test inserts due to foreign key constraint errors. -func disableTriggers() error { - var stmts []string +func (p *pgTester) runCmd(stdin, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Env = append(os.Environ(), p.pgEnv()...) + + if len(stdin) != 0 { + cmd.Stdin = strings.NewReader(stdin) + } + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + + return nil +} - {{range .Tables}} - stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`) - {{- end}} +func (p *pgTester) pgEnv() []string { + return []string{ + fmt.Sprintf("PGHOST=%s", p.host), + fmt.Sprintf("PGPORT=%d", p.port), + fmt.Sprintf("PGUSER=%s", p.user), + fmt.Sprintf("PGPASS=%s", p.pgPassFile), + } +} - if len(stmts) == 0 { - return nil - } +func (p *pgTester) makePGPassFile() error { + tmp, err := ioutil.TempFile("", "pgpass") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + p.pgPassFile = tmp.Name() + return tmp.Close() +} - var err error - for _, s := range stmts { - _, err = dbConn.Exec(s) - if err != nil { - return err - } - } +func (p *pgTester) createTestDB() error { + return p.runCmd("", "createdb", p.testDBName) +} - return nil +func (p *pgTester) dropTestDB() error { + return p.runCmd("", "dropdb", "--if-exists", p.testDBName) } // teardown executes cleanup tasks when the tests finish running -func teardown() error { - err := dropTestDB() - return err +func (p *pgTester) teardown() error { + var err error + if err = p.dbConn.Close(); err != nil { + return err + } + p.dbConn = nil + + if err = p.dropTestDB(); err != nil { + return err + } + + return os.Remove(p.pgPassFile) } -// dropTestDB switches its connection to the template1 database temporarily -// so that it can drop the test database without causing "in use" conflicts. -// The template1 database should be present on all default postgres installations. 
-func dropTestDB() error { - var err error - if dbConn != nil { - if err = dbConn.Close(); err != nil { - return err - } - } - - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) - if err != nil { - return err - } - - _, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) - if err != nil { - return err - } - - return dbConn.Close() -} +func (p *pgTester) conn() (*sql.DB, error) { + if p.dbConn != nil { + return p.dbConn, nil + } -// DBConnect connects to a database and returns the handle. -func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) { - connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode) + var err error + p.dbConn, err = sql.Open("postgres", drivers.PostgresBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)) + if err != nil { + return nil, err + } - return sql.Open("postgres", connStr) + return p.dbConn, nil } -// setup dumps the database schema and imports it into a temporary randomly -// generated test database so that tests can be run against it using the -// generated sqlboiler ORM package. -func setup() error { - var err error - - // Initialize Viper and load the config file - err = InitViper() - if err != nil { - return errors.Wrap(err, "Unable to load config file") - } - - viper.SetDefault("postgres.sslmode", "require") - viper.SetDefault("postgres.port", "5432") - - // Create a randomized test configuration object. - testCfg.Postgres.Host = viper.GetString("postgres.host") - testCfg.Postgres.Port = viper.GetInt("postgres.port") - testCfg.Postgres.User = viper.GetString("postgres.user") - testCfg.Postgres.Pass = viper.GetString("postgres.pass") - testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - - // Set the default SSLMode value - if testCfg.Postgres.SSLMode == "" { - viper.Set("postgres.sslmode", "require") - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - } - - err = vala.BeginValidation().Validate( - vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), - vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), - vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), - vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), - vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"), - ).Check() - - if err != nil { - return errors.Wrap(err, "Unable to load testCfg") - } - - err = dropTestDB() - if err != nil { - fmt.Printf("%#v\n", err) - return err - } - - fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema") - if err != nil { - return errors.Wrap(err, "Unable to create sqlboiler schema tmp file") - } - defer os.Remove(fhSchema.Name()) - - passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler") - if err != nil { - return errors.Wrap(err, "Unable to create sqlboiler tmp dir for postgres pw file") - } - defer os.RemoveAll(passDir) - - // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.user"), - )) - - if pw := viper.GetString("postgres.pass"); len(pw) > 0 { - pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw)) - } - - passFilePath := filepath.Join(passDir, "pwfile") - - err = 
ioutil.WriteFile(passFilePath, pwBytes, 0600) - if err != nil { - return errors.Wrap(err, "Unable to create pwfile in passDir") - } - - // The params for the pg_dump command to dump the database schema - params := []string{ - fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")), - fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")), - fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")), - "--schema-only", - viper.GetString("postgres.dbname"), - } - - // Dump the database schema into the sqlboilerschema tmp file - errBuf := bytes.Buffer{} - cmd := exec.Command("pg_dump", params...) - cmd.Stderr = &errBuf - cmd.Stdout = fhSchema - cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath)) - - if err := cmd.Run(); err != nil { - fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String()) - return err - } - - dbConn, err = DBConnect( - viper.GetString("postgres.user"), - viper.GetString("postgres.pass"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.sslmode"), - ) - if err != nil { - return err - } - - // Create the randomly generated database - _, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName)) - if err != nil { - return err - } - - // Close the old connection so we can reconnect to the test database - if err = dbConn.Close(); err != nil { - return err - } - - // Connect to the generated test db - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) - if err != nil { - return err - } - - // Write the test config credentials to a tmp file for pg_dump - testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - testCfg.Postgres.Host, - testCfg.Postgres.Port, - testCfg.Postgres.DBName, - testCfg.Postgres.User, - )) - - if len(testCfg.Postgres.Pass) > 0 { - testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass)) - } - - testPassFilePath := passDir + "/testpwfile" - - err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600) - if err != nil { - return errors.Wrapf(err, "Unable to create testpwfile in passDir") - } - - // The params for the psql schema import command - params = []string{ - fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName), - fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host), - fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port), - fmt.Sprintf(`--username=%s`, testCfg.Postgres.User), - fmt.Sprintf(`--file=%s`, fhSchema.Name()), - } - - // Import the database schema into the generated database. - // It is now ready to be used by the generated ORM package for testing. - outBuf := bytes.Buffer{} - cmd = exec.Command("psql", params...) - cmd.Stderr = &errBuf - cmd.Stdout = &outBuf - cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath)) - - if err = cmd.Run(); err != nil { - fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String()) - } - - return nil -} diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index cb9ee364a..cade6cd0a 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -1,98 +1,98 @@ {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . }} - {{- $table := .Table }} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- $dot := . 
}} + {{- $table := .Table }} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- template "relationship_to_one_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} - {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- else -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { - var err error - tx := MustTx(boil.Begin()) - defer tx.Rollback() + var err error + tx := MustTx(boil.Begin()) + defer tx.Rollback() - var a {{$rel.LocalTable.NameGo}} - var b, c {{$rel.ForeignTable.NameGo}} + var a {{$rel.LocalTable.NameGo}} + var b, c {{$rel.ForeignTable.NameGo}} - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } - seed := randomize.NewSeed() - randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") - randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") - {{if .Nullable -}} - a.{{.Column | titleCase}}.Valid = true - {{- end}} - {{- if .ForeignColumnNullable -}} - b.{{.ForeignColumn | titleCase}}.Valid = true - c.{{.ForeignColumn | titleCase}}.Valid = true - {{- end}} - {{if not .ToJoinTable -}} - b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} - c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} - {{- end}} - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } + seed := randomize.NewSeed() + randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") + randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") + {{if .Nullable -}} + a.{{.Column | titleCase}}.Valid = true + {{- end}} + {{- if .ForeignColumnNullable -}} + b.{{.ForeignColumn | titleCase}}.Valid = true + c.{{.ForeignColumn | titleCase}}.Valid = true + {{- end}} + {{if not .ToJoinTable -}} + b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} + c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} + {{- end}} + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } - {{if .ToJoinTable -}} - _, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) - if err != nil { - t.Fatal(err) - } - _, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) - if err != nil { - t.Fatal(err) - } - {{end}} + {{if .ToJoinTable -}} + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + if err != nil { + t.Fatal(err) + } + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, 
c.{{$rel.ForeignTable.ColumnNameGo}}) + if err != nil { + t.Fatal(err) + } + {{end}} - {{$varname := .ForeignTable | singular | camelCase -}} - {{$varname}}, err := a.{{$rel.Function.Name}}(tx).All() - if err != nil { - t.Fatal(err) - } + {{$varname := .ForeignTable | singular | camelCase -}} + {{$varname}}, err := a.{{$rel.Function.Name}}(tx).All() + if err != nil { + t.Fatal(err) + } - bFound, cFound := false, false - for _, v := range {{$varname}} { - if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} { - bFound = true - } - if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} { - cFound = true - } - } + bFound, cFound := false, false + for _, v := range {{$varname}} { + if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} { + bFound = true + } + if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} { + cFound = true + } + } - if !bFound { - t.Error("expected to find b") - } - if !cFound { - t.Error("expected to find c") - } + if !bFound { + t.Error("expected to find b") + } + if !cFound { + t.Error("expected to find c") + } - slice := {{$rel.LocalTable.NameGo}}Slice{&a} - if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil { - t.Fatal(err) - } - if got := len(a.R.{{$rel.Function.Name}}); got != 2 { - t.Error("number of eager loaded records wrong, got:", got) - } + slice := {{$rel.LocalTable.NameGo}}Slice{&a} + if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil { + t.Fatal(err) + } + if got := len(a.R.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } - a.R.{{$rel.Function.Name}} = nil - if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { - t.Fatal(err) - } - if got := len(a.R.{{$rel.Function.Name}}); got != 2 { - t.Error("number of eager loaded records wrong, got:", got) - } + a.R.{{$rel.Function.Name}} = nil + if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { + t.Fatal(err) + } + if got := len(a.R.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } - if t.Failed() { - t.Logf("%#v", {{$varname}}) - } + if t.Failed() { + t.Logf("%#v", {{$varname}}) + } } {{end -}}{{- /* if unique */ -}} diff --git a/templates_test/relationship_to_many_setops.tpl b/templates_test/relationship_to_many_setops.tpl index 22e1ea792..e653d3672 100644 --- a/templates_test/relationship_to_many_setops.tpl +++ b/templates_test/relationship_to_many_setops.tpl @@ -1,306 +1,306 @@ {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- template "relationship_to_one_setops_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) 
-}} - {{- else -}} - {{- $varNameSingular := .Table | singular | camelCase -}} - {{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}} - {{- $rel := textsFromRelationship $dot.Tables $table .}} + {{- else -}} + {{- $varNameSingular := .Table | singular | camelCase -}} + {{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}} + {{- $rel := textsFromRelationship $dot.Tables $table .}} func test{{$rel.LocalTable.NameGo}}ToManyAddOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } - - foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{ - {&b, &c}, - {&d, &e}, - } - - for i, x := range foreignersSplitByInsertion { - err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...) - if err != nil { - t.Fatal(err) - } - - first := x[0] - second := x[1] - {{- if .ToJoinTable}} - - if first.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - if second.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - {{- else}} - - if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}}) - } - if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}}) - } - - if first.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign slice") - } - if second.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign slice") - } - {{- end}} - - if a.R.{{$rel.Function.Name}}[i*2] != first { - t.Error("relationship struct slice not set to correct value") - } - if a.R.{{$rel.Function.Name}}[i*2+1] != second { - t.Error("relationship struct slice not set to correct value") - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if want := int64((i+1)*2); count != want { - t.Error("want", want, "got", count) - } - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err := 
a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } + + foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{ + {&b, &c}, + {&d, &e}, + } + + for i, x := range foreignersSplitByInsertion { + err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...) + if err != nil { + t.Fatal(err) + } + + first := x[0] + second := x[1] + {{- if .ToJoinTable}} + + if first.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + if second.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + {{- else}} + + if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}}) + } + if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}}) + } + + if first.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign slice") + } + if second.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign slice") + } + {{- end}} + + if a.R.{{$rel.Function.Name}}[i*2] != first { + t.Error("relationship struct slice not set to correct value") + } + if a.R.{{$rel.Function.Name}}[i*2+1] != second { + t.Error("relationship struct slice not set to correct value") + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if want := int64((i+1)*2); count != want { + t.Error("want", want, "got", count) + } + } } {{- if (or .ForeignColumnNullable .ToJoinTable)}} func test{{$rel.LocalTable.NameGo}}ToManySetOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err = a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } - - err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c) - if err != nil { - t.Fatal(err) - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e) - if err != nil { - t.Fatal(err) - } - - count, err = a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - {{- if .ToJoinTable}} - - if len(b.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if len(c.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if d.R.{{$rel.Function.ForeignName}}[0] != &a { - 
t.Error("relationship was not added properly to the slice") - } - if e.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - {{- else}} - - if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want b's foreign key value to be nil") - } - if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want c's foreign key value to be nil") - } - if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}}) - } - if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}}) - } - - if b.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if c.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if d.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign struct") - } - if e.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign struct") - } - {{- end}} - - if a.R.{{$rel.Function.Name}}[0] != &d { - t.Error("relationship struct slice not set to correct value") - } - if a.R.{{$rel.Function.Name}}[1] != &e { - t.Error("relationship struct slice not set to correct value") - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err = a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } + + err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c) + if err != nil { + t.Fatal(err) + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e) + if err != nil { + t.Fatal(err) + } + + count, err = a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + {{- if .ToJoinTable}} + + if len(b.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if len(c.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if d.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + if e.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + {{- else}} + + if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want b's foreign key value to be nil") + } + if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want c's foreign key value to be nil") + } + if 
a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}}) + } + if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}}) + } + + if b.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if c.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if d.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign struct") + } + if e.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign struct") + } + {{- end}} + + if a.R.{{$rel.Function.Name}}[0] != &d { + t.Error("relationship struct slice not set to correct value") + } + if a.R.{{$rel.Function.Name}}[1] != &e { + t.Error("relationship struct slice not set to correct value") + } } func test{{$rel.LocalTable.NameGo}}ToManyRemoveOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - - err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...) - if err != nil { - t.Fatal(err) - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 4 { - t.Error("count was wrong:", count) - } - - err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...) 
- if err != nil { - t.Fatal(err) - } - - count, err = a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - {{- if .ToJoinTable}} - - if len(b.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if len(c.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if d.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the foreign struct") - } - if e.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the foreign struct") - } - {{- else}} - - if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want b's foreign key value to be nil") - } - if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want c's foreign key value to be nil") - } - - if b.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if c.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if d.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship to a should have been preserved") - } - if e.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship to a should have been preserved") - } - {{- end}} - - if len(a.R.{{$rel.Function.Name}}) != 2 { - t.Error("should have preserved two relationships") - } - - // Removal doesn't do a stable deletion for performance so we have to flip the order - if a.R.{{$rel.Function.Name}}[1] != &d { - t.Error("relationship to d should have been preserved") - } - if a.R.{{$rel.Function.Name}}[0] != &e { - t.Error("relationship to e should have been preserved") - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } + + err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...) + if err != nil { + t.Fatal(err) + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 4 { + t.Error("count was wrong:", count) + } + + err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...) 
+ if err != nil { + t.Fatal(err) + } + + count, err = a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + {{- if .ToJoinTable}} + + if len(b.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if len(c.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if d.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the foreign struct") + } + if e.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the foreign struct") + } + {{- else}} + + if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want b's foreign key value to be nil") + } + if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want c's foreign key value to be nil") + } + + if b.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if c.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if d.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship to a should have been preserved") + } + if e.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship to a should have been preserved") + } + {{- end}} + + if len(a.R.{{$rel.Function.Name}}) != 2 { + t.Error("should have preserved two relationships") + } + + // Removal doesn't do a stable deletion for performance so we have to flip the order + if a.R.{{$rel.Function.Name}}[1] != &d { + t.Error("relationship to d should have been preserved") + } + if a.R.{{$rel.Function.Name}}[0] != &e { + t.Error("relationship to e should have been preserved") + } } {{end -}} {{- end -}}{{- /* if unique foreign key */ -}} diff --git a/templates_test/relationship_to_one.tpl b/templates_test/relationship_to_one.tpl index 602f9b953..f3a60b4c9 100644 --- a/templates_test/relationship_to_one.tpl +++ b/templates_test/relationship_to_one.tpl @@ -1,69 +1,69 @@ {{- define "relationship_to_one_test_helper"}} func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - tx := MustTx(boil.Begin()) - defer tx.Rollback() + tx := MustTx(boil.Begin()) + defer tx.Rollback() - var foreign {{.ForeignTable.NameGo}} - var local {{.LocalTable.NameGo}} - {{if .ForeignKey.Nullable -}} - local.{{.ForeignKey.Column | titleCase}}.Valid = true - {{end}} - {{- if .ForeignKey.ForeignColumnNullable -}} - foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true - {{end}} + var foreign {{.ForeignTable.NameGo}} + var local {{.LocalTable.NameGo}} + {{if .ForeignKey.Nullable -}} + local.{{.ForeignKey.Column | titleCase}}.Valid = true + {{end}} + {{- if .ForeignKey.ForeignColumnNullable -}} + foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true + {{end}} - {{if not .Function.OneToOne -}} - if err := foreign.Insert(tx); err != nil { - t.Fatal(err) - } + {{if not .Function.OneToOne -}} + if err := foreign.Insert(tx); err != nil { + t.Fatal(err) + } - local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}} - if err := local.Insert(tx); err != nil { - t.Fatal(err) - } - {{else -}} - if err := local.Insert(tx); err != nil { - t.Fatal(err) - } + local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}} + if err := local.Insert(tx); err != nil { + t.Fatal(err) + } + {{else -}} + if err := local.Insert(tx); err != nil 
{ + t.Fatal(err) + } - foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}} - if err := foreign.Insert(tx); err != nil { - t.Fatal(err) - } - {{end -}} + foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}} + if err := foreign.Insert(tx); err != nil { + t.Fatal(err) + } + {{end -}} - check, err := local.{{.Function.Name}}(tx).One() - if err != nil { - t.Fatal(err) - } + check, err := local.{{.Function.Name}}(tx).One() + if err != nil { + t.Fatal(err) + } - if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { - t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) - } + if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { + t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) + } - slice := {{.LocalTable.NameGo}}Slice{&local} - if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil { - t.Fatal(err) - } - if local.R.{{.Function.Name}} == nil { - t.Error("struct should have been eager loaded") - } + slice := {{.LocalTable.NameGo}}Slice{&local} + if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil { + t.Fatal(err) + } + if local.R.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } - local.R.{{.Function.Name}} = nil - if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil { - t.Fatal(err) - } - if local.R.{{.Function.Name}} == nil { - t.Error("struct should have been eager loaded") - } + local.R.{{.Function.Name}} = nil + if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil { + t.Fatal(err) + } + if local.R.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } } {{end -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . 
-}} {{- template "relationship_to_one_test_helper" $rel -}} {{end -}} {{- end -}} diff --git a/templates_test/relationship_to_one_setops.tpl b/templates_test/relationship_to_one_setops.tpl index ab27db590..2ff2ca76f 100644 --- a/templates_test/relationship_to_one_setops.tpl +++ b/templates_test/relationship_to_one_setops.tpl @@ -2,131 +2,131 @@ {{- $varNameSingular := .ForeignKey.Table | singular | camelCase -}} {{- $foreignVarNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} func test{{.LocalTable.NameGo}}ToOneSetOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{.LocalTable.NameGo}} - var b, c {{.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - - for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} { - err = a.Set{{.Function.Name}}(tx, i != 0, x) - if err != nil { - t.Fatal(err) - } - - if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}) - } - if a.R.{{.Function.Name}} != x { - t.Error("relationship struct not set to correct value") - } - - zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}})) - reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero) - - if err = a.Reload(tx); err != nil { - t.Fatal("failed to reload", err) - } - - if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}}) - } - - {{if .ForeignKey.Unique -}} - if x.R.{{.Function.ForeignName}} != &a { - t.Error("failed to append to foreign relationship struct") - } - {{else -}} - if x.R.{{.Function.ForeignName}}[0] != &a { - t.Error("failed to append to foreign relationship struct") - } - {{end -}} - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{.LocalTable.NameGo}} + var b, c {{.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + + for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} { + err = a.Set{{.Function.Name}}(tx, i != 0, x) + if err != nil { + t.Fatal(err) + } + + if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}) + } + if a.R.{{.Function.Name}} 
!= x { + t.Error("relationship struct not set to correct value") + } + + zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}})) + reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero) + + if err = a.Reload(tx); err != nil { + t.Fatal("failed to reload", err) + } + + if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}}) + } + + {{if .ForeignKey.Unique -}} + if x.R.{{.Function.ForeignName}} != &a { + t.Error("failed to append to foreign relationship struct") + } + {{else -}} + if x.R.{{.Function.ForeignName}}[0] != &a { + t.Error("failed to append to foreign relationship struct") + } + {{end -}} + } } {{- if .ForeignKey.Nullable}} func test{{.LocalTable.NameGo}}ToOneRemoveOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{.LocalTable.NameGo}} - var b {{.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - - if err = a.Insert(tx); err != nil { - t.Fatal(err) - } - - if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil { - t.Fatal(err) - } - - if err = a.Remove{{.Function.Name}}(tx, &b); err != nil { - t.Error("failed to remove relationship") - } - - count, err := a.{{.Function.Name}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 0 { - t.Error("want no relationships remaining") - } - - if a.R.{{.Function.Name}} != nil { - t.Error("R struct entry should be nil") - } - - if a.{{.LocalTable.ColumnNameGo}}.Valid { - t.Error("R struct entry should be nil") - } - - {{if .ForeignKey.Unique -}} - if b.R.{{.Function.ForeignName}} != nil { - t.Error("failed to remove a from b's relationships") - } - {{else -}} - if len(b.R.{{.Function.ForeignName}}) != 0 { - t.Error("failed to remove a from b's relationships") - } - {{end -}} + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{.LocalTable.NameGo}} + var b {{.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + + if err = a.Insert(tx); err != nil { + t.Fatal(err) + } + + if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil { + t.Fatal(err) + } + + if err = a.Remove{{.Function.Name}}(tx, &b); err != nil { + t.Error("failed to remove relationship") + } + + count, err := a.{{.Function.Name}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 0 { + t.Error("want no relationships remaining") + } + + if a.R.{{.Function.Name}} != nil { + t.Error("R struct entry should be nil") + } + + if a.{{.LocalTable.ColumnNameGo}}.Valid { + t.Error("R struct entry should be nil") + } + + {{if .ForeignKey.Unique -}} + if b.R.{{.Function.ForeignName}} != nil { + t.Error("failed to remove a from b's relationships") + } + {{else -}} + if len(b.R.{{.Function.ForeignName}}) != 0 { + t.Error("failed to remove a from b's relationships") + } 
+ {{end -}} } {{end -}} {{- end -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}} {{template "relationship_to_one_setops_test_helper" $rel -}} {{- end -}} diff --git a/templates_test/reload.tpl b/templates_test/reload.tpl index da236bdb8..7879b40b9 100644 --- a/templates_test/reload.tpl +++ b/templates_test/reload.tpl @@ -3,45 +3,45 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Reload(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$varNameSingular}}.Reload(tx); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$varNameSingular}}.Reload(tx); err != nil { + t.Error(err) + } } func test{{$tableNamePlural}}ReloadAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - - if err = slice.ReloadAll(tx); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + + if err = slice.ReloadAll(tx); err != nil { + t.Error(err) + } } diff --git a/templates_test/select.tpl b/templates_test/select.tpl index 1b3aefbdf..0e5692fc0 100644 --- a/templates_test/select.tpl +++ b/templates_test/select.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Select(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err 
error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - slice, err := {{$tableNamePlural}}(tx).All() - if err != nil { - t.Error(err) - } + slice, err := {{$tableNamePlural}}(tx).All() + if err != nil { + t.Error(err) + } - if len(slice) != 1 { - t.Error("want one record, got:", len(slice)) - } + if len(slice) != 1 { + t.Error("want one record, got:", len(slice)) + } } diff --git a/templates_test/singleton/boil_main_test.tpl b/templates_test/singleton/boil_main_test.tpl new file mode 100644 index 000000000..0014a1e4b --- /dev/null +++ b/templates_test/singleton/boil_main_test.tpl @@ -0,0 +1,131 @@ +var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") + +var ( + dbMain tester +) + +type tester interface { + setup() error + conn() (*sql.DB, error) + teardown() error +} + +func TestMain(m *testing.M) { + if dbMain == nil { + fmt.Println("no dbMain tester interface was ready") + os.Exit(-1) + } + + rand.Seed(time.Now().UnixNano()) + var err error + + // Load configuration + err = initViper() + if err != nil { + fmt.Println("unable to load config file") + os.Exit(-2) + } + + setConfigDefaults() + if err := validateConfig("{{.DriverName}}"); err != nil { + fmt.Println("failed to validate config", err) + os.Exit(-3) + } + + // Set DebugMode so we can see generated sql statements + flag.Parse() + boil.DebugMode = *flagDebugMode + + if err = dbMain.setup(); err != nil { + fmt.Println("Unable to execute setup:", err) + os.Exit(-4) + } + + conn, err := dbMain.conn() + if err != nil { + fmt.Println("failed to get connection:", err) + } + + var code int + boil.SetDB(conn) + code = m.Run() + + if err = dbMain.teardown(); err != nil { + fmt.Println("Unable to execute teardown:", err) + os.Exit(-5) + } + + os.Exit(code) +} + +func initViper() error { + var err error + + viper.SetConfigName("sqlboiler") + + configHome := os.Getenv("XDG_CONFIG_HOME") + homePath := os.Getenv("HOME") + wd, err := os.Getwd() + if err != nil { + wd = "../" + } else { + wd = wd + "/.." 
+ } + + configPaths := []string{wd} + if len(configHome) > 0 { + configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) + } else { + configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) + } + + for _, p := range configPaths { + viper.AddConfigPath(p) + } + + // Ignore errors here, fall back to defaults and validation to provide errs + _ = viper.ReadInConfig() + viper.AutomaticEnv() + + return nil +} + +// setConfigDefaults is only necessary because of bugs in viper, noted in main +func setConfigDefaults() { + if viper.GetString("postgres.sslmode") == "" { + viper.Set("postgres.sslmode", "require") + } + if viper.GetInt("postgres.port") == 0 { + viper.Set("postgres.port", 5432) + } + if viper.GetString("mysql.sslmode") == "" { + viper.Set("mysql.sslmode", "true") + } + if viper.GetInt("mysql.port") == 0 { + viper.Set("mysql.port", 3306) + } +} + +func validateConfig(driverName string) error { + if driverName == "postgres" { + return vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"), + vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"), + vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")), + vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"), + vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"), + ).Check() + } + + if driverName == "mysql" { + return vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"), + vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"), + vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")), + vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"), + vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"), + ).Check() + } + + return errors.New("not a valid driver name") +} diff --git a/templates_test/singleton/boil_queries_test.tpl b/templates_test/singleton/boil_queries_test.tpl index 26f09d3c5..90419be6c 100644 --- a/templates_test/singleton/boil_queries_test.tpl +++ b/templates_test/singleton/boil_queries_test.tpl @@ -7,53 +7,31 @@ func MustTx(transactor boil.Transactor, err error) boil.Transactor { return transactor } -func initDBNameRand(input string) { - sum := md5.Sum([]byte(input)) +var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE ONLY .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`) +var rgxMySQLkey = regexp.MustCompile(`(?m)((,\n)?\s+CONSTRAINT.*?FOREIGN KEY.*?\n)+`) - var sumInt string - for _, v := range sum { - sumInt = sumInt + strconv.Itoa(int(v)) - } - - // Cut integer to 18 digits to ensure no int64 overflow. 
- sumInt = sumInt[:18] - - sumTmp := sumInt - for i, v := range sumInt { - if v == '0' { - sumTmp = sumInt[i+1:] - continue - } - break - } - - sumInt = sumTmp - - randSeed, err := strconv.ParseInt(sumInt, 0, 64) - if err != nil { - fmt.Printf("Unable to parse sumInt: %s", err) - os.Exit(-1) +func newFKeyDestroyer(regex *regexp.Regexp, reader io.Reader) io.Reader { + return &fKeyDestroyer{ + reader: reader, + rgx: regex, } +} - dbNameRand = rand.New(rand.NewSource(randSeed)) +type fKeyDestroyer struct { + reader io.Reader + buf *bytes.Buffer + rgx *regexp.Regexp } -var alphabetChars = "abcdefghijklmnopqrstuvwxyz" -func randStr(length int) string { - c := len(alphabetChars) +func (f *fKeyDestroyer) Read(b []byte) (int, error) { + if f.buf == nil { + all, err := ioutil.ReadAll(f.reader) + if err != nil { + return 0, err + } - output := make([]rune, length) - for i := 0; i < length; i++ { - output[i] = rune(alphabetChars[dbNameRand.Intn(c)]) + f.buf = bytes.NewBuffer(f.rgx.ReplaceAll(all, []byte{})) } - return string(output) -} - -// getDBNameHash takes a database name in, and generates -// a random string using the database name as the rand Seed. -// getDBNameHash is used to generate unique test database names. -func getDBNameHash(input string) string { - initDBNameRand(input) - return randStr(40) + return f.buf.Read(b) } diff --git a/templates_test/singleton/boil_viper_test.tpl b/templates_test/singleton/boil_viper_test.tpl deleted file mode 100644 index d05a20a7b..000000000 --- a/templates_test/singleton/boil_viper_test.tpl +++ /dev/null @@ -1,37 +0,0 @@ -var ( - testCfg *Config - dbConn *sql.DB -) - -func InitViper() error { - var err error - testCfg = &Config{} - - viper.SetConfigName("sqlboiler") - - configHome := os.Getenv("XDG_CONFIG_HOME") - homePath := os.Getenv("HOME") - wd, err := os.Getwd() - if err != nil { - wd = "../" - } else { - wd = wd + "/.." 
- } - - configPaths := []string{wd} - if len(configHome) > 0 { - configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) - } else { - configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) - } - - for _, p := range configPaths { - viper.AddConfigPath(p) - } - - // Ignore errors here, fall back to defaults and validation to provide errs - _ = viper.ReadInConfig() - viper.AutomaticEnv() - - return nil -} diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 5cd15bedb..98421c9d7 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -3,97 +3,97 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Update(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } - if count != 1 { - t.Error("want one record, got:", count) - } + if count != 1 { + t.Error("want one record, got:", count) + } - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - // If table only contains primary key columns, we need to pass - // them into a whitelist to get a valid test result, - // otherwise the Update method will error because it will not be able to - // generate a whitelist (due to it excluding primary key columns). - if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { - if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Error(err) - } - } else { - if err = {{$varNameSingular}}.Update(tx); err != nil { - t.Error(err) - } - } + // If table only contains primary key columns, we need to pass + // them into a whitelist to get a valid test result, + // otherwise the Update method will error because it will not be able to + // generate a whitelist (due to it excluding primary key columns). 
+ if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Error(err) + } + } else { + if err = {{$varNameSingular}}.Update(tx); err != nil { + t.Error(err) + } + } } func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } - if count != 1 { - t.Error("want one record, got:", count) - } + if count != 1 { + t.Error("want one record, got:", count) + } - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - // Remove Primary keys and unique columns from what we plan to update - var fields []string - if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { - fields = {{$varNameSingular}}Columns - } else { - fields = strmangle.SetComplement( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - ) - } + // Remove Primary keys and unique columns from what we plan to update + var fields []string + if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + fields = {{$varNameSingular}}Columns + } else { + fields = strmangle.SetComplement( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + ) + } value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}})) - updateMap := M{} - for _, col := range fields { - updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() - } + updateMap := M{} + for _, col := range fields { + updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() + } - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - if err = slice.UpdateAll(tx, updateMap); err != nil { - t.Error(err) - } + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + if err = slice.UpdateAll(tx, updateMap); err != nil { + t.Error(err) + } } diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index cb48b87a9..de5623b0c 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -3,44 +3,47 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} 
{{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Upsert(t *testing.T) { - t.Parallel() + {{if not (eq .DriverName "postgres") -}} + t.Skip("not implemented for {{.DriverName}}") + {{end -}} + t.Parallel() - seed := randomize.NewSeed() - var err error - // Attempt the INSERT side of an UPSERT - {{$varNameSingular}} := {{$tableNameSingular}}{} - if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + // Attempt the INSERT side of an UPSERT + {{$varNameSingular}} := {{$tableNameSingular}}{} + if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Upsert(tx, false, nil, nil); err != nil { - t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}false, nil, {{end}}nil); err != nil { + t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 1 { - t.Error("want one record, got:", count) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 1 { + t.Error("want one record, got:", count) + } - // Attempt the UPDATE side of an UPSERT - if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + // Attempt the UPDATE side of an UPSERT + if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - if err = {{$varNameSingular}}.Upsert(tx, true, nil, nil); err != nil { - t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) - } + if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}true, nil, {{end}}nil); err != nil { + t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) + } - count, err = {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 1 { - t.Error("want one record, got:", count) - } + count, err = {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 1 { + t.Error("want one record, got:", count) + } } diff --git a/testdata/test_schema.sql b/testdata/test_schema.sql index 65b786d73..14beeca1d 100644 --- a/testdata/test_schema.sql +++ b/testdata/test_schema.sql @@ -93,7 +93,46 @@ CREATE TABLE magic ( strange_three timestamp without time zone default (now() at time zone 'utc'), strange_four timestamp with time zone default (now() at time zone 'utc'), strange_five interval NOT NULL DEFAULT '21 days', - strange_six interval NULL DEFAULT '23 hours' + strange_six interval NULL DEFAULT '23 hours', + + aa json NULL, + bb json NOT NULL, + cc jsonb NULL, + dd jsonb NOT NULL, + ee box NULL, + ff box NOT NULL, + gg cidr NULL, + hh cidr NOT NULL, + ii circle NULL, + jj circle NOT NULL, + kk double precision NULL, + ll double precision NOT NULL, + mm inet NULL, + nn inet NOT NULL, + oo 
line NULL, + pp line NOT NULL, + qq lseg NULL, + rr lseg NOT NULL, + ss macaddr NULL, + tt macaddr NOT NULL, + uu money NULL, + vv money NOT NULL, + ww path NULL, + xx path NOT NULL, + yy pg_lsn NULL, + zz pg_lsn NOT NULL, + aaa point NULL, + bbb point NOT NULL, + ccc polygon NULL, + ddd polygon NOT NULL, + eee tsquery NULL, + fff tsquery NOT NULL, + ggg tsvector NULL, + hhh tsvector NOT NULL, + iii txid_snapshot NULL, + jjj txid_snapshot NOT NULL, + kkk xml NULL, + lll xml NOT NULL ); create table owner ( @@ -136,12 +175,6 @@ create table spider_toys ( primary key (spider_id) ); -/* - Test: - * Variations of capitalization - * Single value columns - * Primary key as only value -*/ create table pals ( pal character varying, primary key (pal) @@ -161,3 +194,22 @@ create table enemies ( enemies character varying, primary key (enemies) ); + +create table fun_arrays ( + id serial, + fun_one integer[] null, + fun_two integer[] not null, + fun_three boolean[] null, + fun_four boolean[] not null, + fun_five varchar[] null, + fun_six varchar[] not null, + fun_seven decimal[] null, + fun_eight decimal[] not null, + fun_nine bytea[] null, + fun_ten bytea[] not null, + fun_eleven jsonb[] null, + fun_twelve jsonb[] not null, + fun_thirteen json[] null, + fun_fourteen json[] not null, + primary key (id) +) diff --git a/text_helpers_test.go b/text_helpers_test.go index 65ab7efc9..5a09e8a5f 100644 --- a/text_helpers_test.go +++ b/text_helpers_test.go @@ -12,7 +12,7 @@ import ( func TestTextsFromForeignKey(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } @@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) { func TestTextsFromOneToOneRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } @@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) { func TestTextsFromRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } diff --git a/types/array.go b/types/array.go new file mode 100644 index 000000000..2924fa20e --- /dev/null +++ b/types/array.go @@ -0,0 +1,719 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func encode(x interface{}) []byte { + switch v := x.(type) { + case int64: + return strconv.AppendInt(nil, v, 10) + case float64: + return strconv.AppendFloat(nil, v, 'f', -1, 64) + case []byte: + return encodeBytes(v) + case string: + return []byte(v) + case bool: + return strconv.AppendBool(nil, v) + case time.Time: + return formatTimestamp(v) + + default: + panic(fmt.Errorf("encode: unknown type for %T", v)) + } +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func formatTimestamp(t time.Time) []byte { + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b := []byte(t.Format(time.RFC3339Nano)) + + _, offset := t.Zone() + offset = offset % 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} + +func encodeBytes(v []byte) (result []byte) { + for _, b := range v { + if b == '\\' { + result = append(result, '\\', '\\') + } else if b < 0x20 || b > 0x7e { + result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) + } else { + result = append(result, b) + } + } + + return result +} + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytes(s []byte) (result []byte, err error) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + } else { + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result, nil +} + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. 
+// +// For example: +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +func Array(a interface{}) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []string: + return (*StringArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]string: + return (*StringArray)(a) + + default: + panic(fmt.Sprintf("boil: invalid type received %T", a)) + } +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. +func (a *BoolArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("boil: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// BytesArray represents a one-dimensional array of the PostgreSQL bytea type. +type BytesArray [][]byte + +// Scan implements the sql.Scanner interface. +func (a *BytesArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("boil: cannot convert %T to BytesArray", src) +} + +func (a *BytesArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BytesArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BytesArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytes(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. 
It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a BytesArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. +func (a *Float64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("boil: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { + return fmt.Errorf("boil: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. +func (a *Int64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("boil: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { + return fmt.Errorf("boil: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. 
+func (a *StringArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("boil: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("boil: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and +// the delimiter used between elements. +// +// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) + if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del = "," + var err error + var iv interface{} = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + return append(b, encode(v)...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. 
Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("boil: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("boil: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff --git a/types/array_test.go b/types/array_test.go new file mode 100644 index 000000000..27e68cf1a --- /dev/null +++ b/types/array_test.go @@ -0,0 +1,800 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "database/sql" + "database/sql/driver" + "math/rand" + "reflect" + "strings" + "testing" +) + +func TestParseArray(t *testing.T) { + for _, tt := range []struct { + input string + delim string + dims []int + elems [][]byte + }{ + {`{}`, `,`, nil, [][]byte{}}, + {`{NULL}`, `,`, []int{1}, [][]byte{nil}}, + {`{a}`, `,`, []int{1}, [][]byte{{'a'}}}, + {`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}}, + {`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}}, + {`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}}, + {`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + }}, + {`{""}`, `,`, []int{1}, [][]byte{{}}}, + {`{","}`, `,`, []int{1}, [][]byte{{','}}}, + {`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}}, + {`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}}, + {`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}}, + {`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {','}, {','}, {','}, {','}, {','}, {','}, + }}, + {`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}}, + {`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}}, + {`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, + }}, + {`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}}, + } { + dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim)) + + if err != nil { + t.Fatalf("Expected no error for %q, got %q", tt.input, err) + } + if !reflect.DeepEqual(dims, tt.dims) { + t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims) + } + if !reflect.DeepEqual(elems, tt.elems) { + t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems) + } + } +} + +func TestParseArrayError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "expected '{' at offset 0"}, + {`x`, "expected '{' at offset 0"}, + {`}`, "expected '{' at offset 0"}, + {`{`, "expected '}' at offset 1"}, + {`{{}`, "expected '}' at offset 3"}, + {`{}}`, "unexpected '}' at offset 2"}, + {`{,}`, "unexpected ',' at offset 1"}, + {`{,x}`, "unexpected ',' at offset 1"}, + {`{x,}`, "unexpected '}' at offset 3"}, + {`{""x}`, "unexpected 'x' at offset 3"}, + {`{{a},{b,c}}`, "multidimensional arrays must have elements with matching dimensions"}, + } { + _, _, err := parseArray([]byte(tt.input), []byte{','}) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + } +} + +func TestArrayScanner(t *testing.T) { + var s sql.Scanner + + s = Array(&[]bool{}) + if _, ok := s.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", s) + } + + s = Array(&[]float64{}) + if _, ok := s.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", s) + } + + s = Array(&[]int64{}) + if 
_, ok := s.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", s) + } + + s = Array(&[]string{}) + if _, ok := s.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", s) + } +} + +func TestArrayValuer(t *testing.T) { + var v driver.Valuer + + v = Array([]bool{}) + if _, ok := v.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", v) + } + + v = Array([]float64{}) + if _, ok := v.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", v) + } + + v = Array([]int64{}) + if _, ok := v.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", v) + } + + v = Array([]string{}) + if _, ok := v.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", v) + } +} + +func TestBoolArrayScanUnsupported(t *testing.T) { + var arr BoolArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BoolArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var BoolArrayStringTests = []struct { + str string + arr BoolArray +}{ + {`{}`, BoolArray{}}, + {`{t}`, BoolArray{true}}, + {`{f,t}`, BoolArray{false, true}}, +} + +func TestBoolArrayScanBytes(t *testing.T) { + for _, tt := range BoolArrayStringTests { + bytes := []byte(tt.str) + arr := BoolArray{true, true, true} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBoolArrayScanBytes(b *testing.B) { + var a BoolArray + var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`) + + for i := 0; i < b.N; i++ { + a = BoolArray{} + a.Scan(x) + } +} + +func TestBoolArrayScanString(t *testing.T) { + for _, tt := range BoolArrayStringTests { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBoolArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"}, + {`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`}, + {`{a}`, `could not parse boolean array index 0: invalid boolean "a"`}, + {`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`}, + {`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`}, + } { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BoolArray{true, true, true}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBoolArrayValue(t *testing.T) { + result, err := BoolArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BoolArray([]bool{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, 
err = BoolArray([]bool{false, true, false}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkBoolArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := BoolArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestBytesArrayScanUnsupported(t *testing.T) { + var arr BytesArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BytesArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var BytesArrayStringTests = []struct { + str string + arr BytesArray +}{ + {`{}`, BytesArray{}}, + {`{NULL}`, BytesArray{nil}}, + {`{"\\xfeff"}`, BytesArray{{'\xFE', '\xFF'}}}, + {`{"\\xdead","\\xbeef"}`, BytesArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, +} + +func TestBytesArrayScanBytes(t *testing.T) { + for _, tt := range BytesArrayStringTests { + bytes := []byte(tt.str) + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBytesArrayScanBytes(b *testing.B) { + var a BytesArray + var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`) + + for i := 0; i < b.N; i++ { + a = BytesArray{} + a.Scan(x) + } +} + +func TestBytesArrayScanString(t *testing.T) { + for _, tt := range BytesArrayStringTests { + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBytesArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to BytesArray"}, + {`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"}, + } { + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BytesArray{{2}, {6}, {0, 0}}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBytesArrayValue(t *testing.T) { + result, err := BytesArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BytesArray([][]byte{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = BytesArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected 
%q, got %q", expected, result) + } +} + +func BenchmarkBytesArrayValue(b *testing.B) { + rand.Seed(1) + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = make([]byte, len(x)) + for j := 0; j < len(x); j++ { + x[i][j] = byte(rand.Int()) + } + } + a := BytesArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestFloat64ArrayScanUnsupported(t *testing.T) { + var arr Float64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var Float64ArrayStringTests = []struct { + str string + arr Float64Array +}{ + {`{}`, Float64Array{}}, + {`{1.2}`, Float64Array{1.2}}, + {`{3.456,7.89}`, Float64Array{3.456, 7.89}}, + {`{3,1,2}`, Float64Array{3, 1, 2}}, +} + +func TestFloat64ArrayScanBytes(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + bytes := []byte(tt.str) + arr := Float64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat64ArrayScanBytes(b *testing.B) { + var a Float64Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float64Array{} + a.Scan(x) + } +} + +func TestFloat64ArrayScanString(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat64ArrayValue(t *testing.T) { + result, err := Float64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float64Array([]float64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ 
{ + x[i] = rand.NormFloat64() + } + a := Float64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt64ArrayScanUnsupported(t *testing.T) { + var arr Int64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var Int64ArrayStringTests = []struct { + str string + arr Int64Array +}{ + {`{}`, Int64Array{}}, + {`{12}`, Int64Array{12}}, + {`{345,678}`, Int64Array{345, 678}}, +} + +func TestInt64ArrayScanBytes(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + bytes := []byte(tt.str) + arr := Int64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt64ArrayScanBytes(b *testing.B) { + var a Int64Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int64Array{} + a.Scan(x) + } +} + +func TestInt64ArrayScanString(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt64ArrayValue(t *testing.T) { + result, err := Int64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int64Array([]int64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int64Array([]int64{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := Int64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestStringArrayScanUnsupported(t *testing.T) { + var arr StringArray + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to StringArray") { + t.Errorf("Expected type to be 
mentioned when scanning, got %q", err) + } +} + +var StringArrayStringTests = []struct { + str string + arr StringArray +}{ + {`{}`, StringArray{}}, + {`{t}`, StringArray{"t"}}, + {`{f,1}`, StringArray{"f", "1"}}, + {`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}}, +} + +func TestStringArrayScanBytes(t *testing.T) { + for _, tt := range StringArrayStringTests { + bytes := []byte(tt.str) + arr := StringArray{"x", "x", "x"} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkStringArrayScanBytes(b *testing.B) { + var a StringArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = StringArray{} + a.Scan(x) + a = StringArray{} + a.Scan(y) + } +} + +func TestStringArrayScanString(t *testing.T) { + for _, tt := range StringArrayStringTests { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestStringArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"}, + {`{NULL}`, "parsing array element index 0: cannot convert nil to string"}, + {`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"}, + {`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"}, + } { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestStringArrayValue(t *testing.T) { + result, err := StringArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = StringArray([]string{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkStringArrayValue(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := StringArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} diff --git a/types/hstore.go b/types/hstore.go new file mode 100644 index 000000000..101a4f111 --- /dev/null +++ b/types/hstore.go @@ -0,0 +1,135 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. 
+// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "database/sql" + "database/sql/driver" + "strings" +) + +// HStore is a wrapper for transferring HStore values back and forth easily. +type HStore map[string]sql.NullString + +// escapes and quotes hstore keys/values +// s should be a sql.NullString or string +func hQuote(s interface{}) string { + var str string + switch v := s.(type) { + case sql.NullString: + if !v.Valid { + return "NULL" + } + str = v.String + case string: + str = v + default: + panic("not a string or sql.NullString") + } + + str = strings.Replace(str, "\\", "\\\\", -1) + return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` +} + +// Scan implements the Scanner interface. +// +// Note h is reallocated before the scan to clear existing values. If the +// hstore column's database value is NULL, then h is set to nil instead. +func (h *HStore) Scan(value interface{}) error { + if value == nil { + h = nil + return nil + } + *h = make(map[string]sql.NullString) + var b byte + pair := [][]byte{{}, {}} + pi := 0 + inQuote := false + didQuote := false + sawSlash := false + bindex := 0 + for bindex, b = range value.([]byte) { + if sawSlash { + pair[pi] = append(pair[pi], b) + sawSlash = false + continue + } + + switch b { + case '\\': + sawSlash = true + continue + case '"': + inQuote = !inQuote + if !didQuote { + didQuote = true + } + continue + default: + if !inQuote { + switch b { + case ' ', '\t', '\n', '\r': + continue + case '=': + continue + case '>': + pi = 1 + didQuote = false + continue + case ',': + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + (*h)[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + (*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + pair[0] = []byte{} + pair[1] = []byte{} + pi = 0 + continue + } + } + } + pair[pi] = append(pair[pi], b) + } + if bindex > 0 { + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + (*h)[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + (*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + } + return nil +} + +// Value implements the driver Valuer interface. Note if h is nil, the +// database column value will be set to NULL. 
+func (h HStore) Value() (driver.Value, error) { + if h == nil { + return nil, nil + } + parts := []string{} + for key, val := range h { + thispart := hQuote(key) + "=>" + hQuote(val) + parts = append(parts, thispart) + } + return []byte(strings.Join(parts, ",")), nil +} diff --git a/types/json.go b/types/json.go new file mode 100644 index 000000000..b42e694b3 --- /dev/null +++ b/types/json.go @@ -0,0 +1,77 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + +// JSON is an alias for json.RawMessage, which is +// a []byte underneath. +// JSON implements Marshal and Unmarshal. +type JSON json.RawMessage + +// String output your JSON. +func (j JSON) String() string { + return string(j) +} + +// Unmarshal your JSON variable into dest. +func (j JSON) Unmarshal(dest interface{}) error { + return json.Unmarshal(j, dest) +} + +// Marshal obj into your JSON variable. +func (j *JSON) Marshal(obj interface{}) error { + res, err := json.Marshal(obj) + if err != nil { + return err + } + + *j = res + return nil +} + +// UnmarshalJSON sets *j to a copy of data. +func (j *JSON) UnmarshalJSON(data []byte) error { + if j == nil { + return errors.New("JSON: UnmarshalJSON on nil pointer") + } + + *j = append((*j)[0:0], data...) + return nil +} + +// MarshalJSON returns j as the JSON encoding of j. +func (j JSON) MarshalJSON() ([]byte, error) { + return j, nil +} + +// Value returns j as a value. +// Unmarshal into RawMessage for validation. +func (j JSON) Value() (driver.Value, error) { + var r json.RawMessage + if err := j.Unmarshal(&r); err != nil { + return nil, err + } + + return []byte(r), nil +} + +// Scan stores the src in *j. +func (j *JSON) Scan(src interface{}) error { + var source []byte + + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("Incompatible type for JSON") + } + + *j = JSON(append((*j)[0:0], source...)) + + return nil +} diff --git a/types/json_test.go b/types/json_test.go new file mode 100644 index 000000000..9ba232711 --- /dev/null +++ b/types/json_test.go @@ -0,0 +1,119 @@ +package types + +import ( + "bytes" + "testing" +) + +func TestJSONString(t *testing.T) { + t.Parallel() + + j := JSON("hello") + if j.String() != "hello" { + t.Errorf("Expected %q, got %s", "hello", j.String()) + } +} + +func TestJSONUnmarshal(t *testing.T) { + t.Parallel() + + type JSONTest struct { + Name string + Age int + } + var jt JSONTest + + j := JSON(`{"Name":"hi","Age":15}`) + err := j.Unmarshal(&jt) + if err != nil { + t.Error(err) + } + + if jt.Name != "hi" { + t.Errorf("Expected %q, got %s", "hi", jt.Name) + } + if jt.Age != 15 { + t.Errorf("Expected %v, got %v", 15, jt.Age) + } +} + +func TestJSONMarshal(t *testing.T) { + t.Parallel() + + type JSONTest struct { + Name string + Age int + } + jt := JSONTest{ + Name: "hi", + Age: 15, + } + + var j JSON + err := j.Marshal(jt) + if err != nil { + t.Error(err) + } + + if j.String() != `{"Name":"hi","Age":15}` { + t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String()) + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + t.Parallel() + + j := JSON(nil) + + err := j.UnmarshalJSON(JSON(`"hi"`)) + if err != nil { + t.Error(err) + } + + if j.String() != `"hi"` { + t.Errorf("Expected %q, got %s", "hi", j.String()) + } +} + +func TestJSONMarshalJSON(t *testing.T) { + t.Parallel() + + j := JSON(`"hi"`) + res, err := j.MarshalJSON() + if err != nil { + t.Error(err) + } + + if !bytes.Equal(res, []byte(`"hi"`)) { + 
t.Errorf("Expected %q, got %v", `"hi"`, res) + } +} + +func TestJSONValue(t *testing.T) { + t.Parallel() + + j := JSON(`{"Name":"hi","Age":15}`) + v, err := j.Value() + if err != nil { + t.Error(err) + } + + if !bytes.Equal(j, v.([]byte)) { + t.Errorf("byte mismatch, %v %v", j, v) + } +} + +func TestJSONScan(t *testing.T) { + t.Parallel() + + j := JSON{} + + err := j.Scan(`"hello"`) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(j, []byte(`"hello"`)) { + t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`))) + } +}
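
Reviewer note (not part of the patch): a minimal, hedged sketch of how the new types package introduced above can be exercised by hand through database/sql. The DSN, the table-free SELECT statements, and the "github.com/vattle/sqlboiler/types" import path are assumptions for illustration; only the types.Array, Int64Array, and JSON APIs themselves come from the diff above.

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq" // postgres driver; any driver that returns arrays/json as []byte should work

	"github.com/vattle/sqlboiler/types" // assumed import path; adjust to your fork or module
)

func main() {
	// Assumed DSN; point it at any reachable Postgres instance.
	db, err := sql.Open("postgres", "dbname=postgres sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Int64Array implements sql.Scanner, so it can receive a Postgres integer[] directly.
	var ids types.Int64Array
	if err := db.QueryRow(`SELECT ARRAY[235, 401]::integer[]`).Scan(&ids); err != nil {
		log.Fatal(err)
	}
	fmt.Println(ids) // [235 401]

	// The array types also implement driver.Valuer, so they can be passed as parameters;
	// types.Array wraps a plain Go slice in the matching array type.
	var n int
	if err := db.QueryRow(`SELECT count(*) FROM unnest($1::integer[])`,
		types.Array([]int64{1, 2, 3})).Scan(&n); err != nil {
		log.Fatal(err)
	}
	fmt.Println(n) // 3

	// JSON is a json.RawMessage alias with Scan/Value plus Marshal/Unmarshal helpers.
	var doc types.JSON
	if err := db.QueryRow(`SELECT '{"name":"hi","age":15}'::json`).Scan(&doc); err != nil {
		log.Fatal(err)
	}
	var dest struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}
	if err := doc.Unmarshal(&dest); err != nil {
		log.Fatal(err)
	}
	fmt.Println(dest.Name, dest.Age) // hi 15
}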