From 1fbc1c15180db4046d2f35a2384ff6fc62db8fa7 Mon Sep 17 00:00:00 2001 From: ChienNM3 Date: Sat, 22 Jul 2023 11:07:48 +0700 Subject: [PATCH] Allow drivers read and combine config and db foreign keys --- boilingcore/config.go | 9 +- drivers/config.go | 62 +++++++ drivers/config_test.go | 171 ++++++++++++++++++++ drivers/interface.go | 4 +- drivers/sqlboiler-mssql/driver/mssql.go | 14 +- drivers/sqlboiler-mysql/driver/mysql.go | 23 ++- drivers/sqlboiler-psql/driver/psql.go | 15 +- drivers/sqlboiler-sqlite3/driver/sqlite3.go | 17 +- main.go | 3 + 9 files changed, 301 insertions(+), 17 deletions(-) diff --git a/boilingcore/config.go b/boilingcore/config.go index dc942df3f..54f78b4b4 100644 --- a/boilingcore/config.go +++ b/boilingcore/config.go @@ -48,10 +48,11 @@ type Config struct { DefaultTemplates fs.FS `toml:"-" json:"-"` CustomTemplateFuncs template.FuncMap `toml:"-" json:"-"` - Aliases Aliases `toml:"aliases,omitempty" json:"aliases,omitempty"` - TypeReplaces []TypeReplace `toml:"type_replaces,omitempty" json:"type_replaces,omitempty"` - AutoColumns AutoColumns `toml:"auto_columns,omitempty" json:"auto_columns,omitempty"` - Inflections Inflections `toml:"inflections,omitempty" json:"inflections,omitempty"` + Aliases Aliases `toml:"aliases,omitempty" json:"aliases,omitempty"` + TypeReplaces []TypeReplace `toml:"type_replaces,omitempty" json:"type_replaces,omitempty"` + AutoColumns AutoColumns `toml:"auto_columns,omitempty" json:"auto_columns,omitempty"` + Inflections Inflections `toml:"inflections,omitempty" json:"inflections,omitempty"` + ForeignKeys []drivers.ForeignKey `toml:"foreign_keys,omitempty" json:"foreign_keys,omitempty" ` Version string `toml:"version" json:"version"` } diff --git a/drivers/config.go b/drivers/config.go index 7751cbfab..2b4ab8a3c 100644 --- a/drivers/config.go +++ b/drivers/config.go @@ -160,6 +160,41 @@ func (c Config) StringSlice(key string) ([]string, bool) { return slice, true } +func (c Config) MustForeignKeys(key string) []ForeignKey { + rawValue, ok := c[key] + if !ok { + return nil + } + + switch v := rawValue.(type) { + case nil: + return nil + case []ForeignKey: + return v + case []interface{}: // in case binary, config is pass to driver in json format, so this key will be []interface{} + fks := make([]ForeignKey, 0, len(v)) + for _, item := range v { + fk, ok := item.(map[string]interface{}) + if !ok { + panic(errors.Errorf("found item of foreign keys, but it was not a map[string]interface{} (%T)", v)) + } + + configFK := Config(fk) + + fks = append(fks, ForeignKey{ + Name: configFK.MustString("name"), + Table: configFK.MustString("table"), + Column: configFK.MustString("column"), + ForeignTable: configFK.MustString("foreign_table"), + ForeignColumn: configFK.MustString("foreign_column"), + }) + } + return fks + default: + panic(errors.Errorf("found key %s in config, but it was invalid (%T)", key, v)) + } +} + // DefaultEnv grabs a value from the environment or a default. // This is shared by drivers to get config for testing. func DefaultEnv(key, def string) string { @@ -211,3 +246,30 @@ func ColumnsFromList(list []string, tablename string) []string { return columns } + +// CombineConfigAndDBForeignKeys takes foreign keys from both config and db, filter by tableName and +// deduplicate by column name. If a foreign key is found in both config and db, the one in config will be used. +func CombineConfigAndDBForeignKeys(configForeignKeys []ForeignKey, tableName string, dbForeignKeys []ForeignKey) []ForeignKey { + combinedForeignKeys := make([]ForeignKey, 0, len(configForeignKeys)+len(dbForeignKeys)) + appearedColumns := make(map[string]bool) + + for _, fk := range configForeignKeys { + // need check table name here cause configForeignKeys contains all foreign keys of all tables + if fk.Table != tableName { + continue + } + + combinedForeignKeys = append(combinedForeignKeys, fk) + appearedColumns[fk.Column] = true + } + + for _, fk := range dbForeignKeys { + // no need check table here, because dbForeignKeys are already filtered by table name + if appearedColumns[fk.Column] { + continue + } + combinedForeignKeys = append(combinedForeignKeys, fk) + } + + return combinedForeignKeys +} diff --git a/drivers/config_test.go b/drivers/config_test.go index cceda9c7b..5a0608d3e 100644 --- a/drivers/config_test.go +++ b/drivers/config_test.go @@ -274,3 +274,174 @@ func TestColumnsFromList(t *testing.T) { t.Error("list was wrong:", got) } } + +func TestConfig_MustForeignKeys(t *testing.T) { + tests := []struct { + name string + c Config + want []ForeignKey + panic bool + }{ + { + name: "no foreign keys", + c: Config{}, + want: nil, + panic: false, + }, + { + name: "nil foreign keys", + c: Config{ + "foreign-keys": nil, + }, + want: nil, + panic: false, + }, + { + name: "have foreign keys", + c: Config{ + "foreign-keys": []ForeignKey{ + { + Name: "test_fk", + Table: "table_name", + Column: "column_name", + ForeignColumn: "foreign_column_name", + ForeignTable: "foreign_table_name", + }, + }, + }, + want: []ForeignKey{ + { + Name: "test_fk", + Table: "table_name", + Column: "column_name", + ForeignColumn: "foreign_column_name", + ForeignTable: "foreign_table_name", + }, + }, + panic: false, + }, + { + name: "invalid foreign keys", + c: Config{ + "foreign-keys": 1, + }, + panic: true, + }, + { + name: "foreign keys in []interface{} format", + c: Config{ + "foreign-keys": []interface{}{ + map[string]interface{}{ + "name": "test_fk", + "table": "table_name", + "column": "column_name", + "foreign_column": "foreign_column_name", + "foreign_table": "foreign_table_name", + }, + }, + }, + want: []ForeignKey{ + { + Name: "test_fk", + Table: "table_name", + Column: "column_name", + ForeignColumn: "foreign_column_name", + ForeignTable: "foreign_table_name", + }, + }, + panic: false, + }, + { + name: "invalid foreign keys in []interface{} format", + c: Config{ + "foreign-keys": []interface{}{ + "123", + }, + }, + panic: true, + }, + { + name: "foreign keys in []map[string]string format but missing fields", + c: Config{ + "foreign-keys": []interface{}{ + map[string]interface{}{ + "name": "test_fk", + }, + }, + }, + want: nil, + panic: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got []ForeignKey + var paniced interface{} + func() { + defer func() { + if r := recover(); r != nil { + paniced = r + } + }() + got = tt.c.MustForeignKeys(ConfigForeignKeys) + }() + + if tt.panic && paniced == nil { + t.Errorf("MustForeignKeys() should have panicked") + } + if !tt.panic && paniced != nil { + t.Errorf("MustForeignKeys() should not have panicked") + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MustForeignKeys() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCombineConfigAndDBForeignKeys(t *testing.T) { + configForeignKeys := []ForeignKey{ + { + Name: "config_fk1", + Table: "table_A", + Column: "column_A1", + ForeignColumn: "column_B1", + ForeignTable: "table_B", + }, + { + Name: "config_fk2", + Table: "table_C", + Column: "column_C1", + ForeignColumn: "column_B1", + ForeignTable: "table_B", + }, + { + Name: "config_fk3", + Table: "table_A", + Column: "column_A2", + ForeignColumn: "column_D2", + ForeignTable: "table_D", + }, + } + tableName := "table_A" + dbForeignKeys := []ForeignKey{ + { + Name: "db_fk1", + Table: "table_A", + Column: "column_A1", + ForeignColumn: "column_E1", + ForeignTable: "table_E", + }, + } + + expected := []ForeignKey{ + configForeignKeys[0], + configForeignKeys[2], + } + + got := CombineConfigAndDBForeignKeys(configForeignKeys, tableName, dbForeignKeys) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("CombineConfigAndDBForeignKeys() = %v, want %v", got, expected) + } +} diff --git a/drivers/interface.go b/drivers/interface.go index 8a5d57e98..7cc397168 100644 --- a/drivers/interface.go +++ b/drivers/interface.go @@ -7,8 +7,9 @@ import ( "sync" "github.com/friendsofgo/errors" - "github.com/volatiletech/sqlboiler/v4/importers" "github.com/volatiletech/strmangle" + + "github.com/volatiletech/sqlboiler/v4/importers" ) // These constants are used in the config map passed into the driver @@ -19,6 +20,7 @@ const ( ConfigAddEnumTypes = "add-enum-types" ConfigEnumNullPrefix = "enum-null-prefix" ConfigConcurrency = "concurrency" + ConfigForeignKeys = "foreign-keys" ConfigUser = "user" ConfigPass = "pass" diff --git a/drivers/sqlboiler-mssql/driver/mssql.go b/drivers/sqlboiler-mssql/driver/mssql.go index da4b1bb6d..1156b7c87 100644 --- a/drivers/sqlboiler-mssql/driver/mssql.go +++ b/drivers/sqlboiler-mssql/driver/mssql.go @@ -12,9 +12,10 @@ import ( // Side effect import go-mssqldb "github.com/friendsofgo/errors" _ "github.com/microsoft/go-mssqldb" + "github.com/volatiletech/strmangle" + "github.com/volatiletech/sqlboiler/v4/drivers" "github.com/volatiletech/sqlboiler/v4/importers" - "github.com/volatiletech/strmangle" ) //go:embed override @@ -36,6 +37,8 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { type MSSQLDriver struct { connStr string conn *sql.DB + + configForeignKeys []drivers.ForeignKey } // Templates that should be added/overridden @@ -84,6 +87,7 @@ func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) m.connStr = MSSQLBuildQueryString(user, pass, dbname, host, port, sslmode) + m.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys) m.conn, err = sql.Open("mssql", m.connStr) if err != nil { return nil, errors.Wrap(err, "sqlboiler-mssql failed to connect to database") @@ -409,6 +413,14 @@ func (m *MSSQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Primary // ForeignKeyInfo retrieves the foreign keys for a given table name. func (m *MSSQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { + dbForeignKeys, err := m.foreignKeyInfoFromDB(schema, tableName) + if err != nil { + return nil, errors.Wrap(err, "read foreign keys info from db") + } + + return drivers.CombineConfigAndDBForeignKeys(m.configForeignKeys, tableName, dbForeignKeys), nil +} +func (m *MSSQLDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) { var fkeys []drivers.ForeignKey query := ` diff --git a/drivers/sqlboiler-mysql/driver/mysql.go b/drivers/sqlboiler-mysql/driver/mysql.go index 7103c96e3..76298210f 100644 --- a/drivers/sqlboiler-mysql/driver/mysql.go +++ b/drivers/sqlboiler-mysql/driver/mysql.go @@ -11,9 +11,10 @@ import ( "github.com/friendsofgo/errors" "github.com/go-sql-driver/mysql" + "github.com/volatiletech/strmangle" + "github.com/volatiletech/sqlboiler/v4/drivers" "github.com/volatiletech/sqlboiler/v4/importers" - "github.com/volatiletech/strmangle" ) //go:embed override @@ -33,11 +34,12 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { // MySQLDriver holds the database connection string and a handle // to the database connection. type MySQLDriver struct { - connStr string - conn *sql.DB - addEnumTypes bool - enumNullPrefix string - tinyIntAsInt bool + connStr string + conn *sql.DB + addEnumTypes bool + enumNullPrefix string + tinyIntAsInt bool + configForeignKeys []drivers.ForeignKey } // Templates that should be added/overridden @@ -95,6 +97,7 @@ func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e m.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool) m.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null")) m.connStr = MySQLBuildQueryString(user, pass, dbname, host, port, sslmode) + m.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys) m.conn, err = sql.Open("mysql", m.connStr) if err != nil { return nil, errors.Wrap(err, "sqlboiler-mysql failed to connect to database") @@ -405,6 +408,14 @@ func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Primary // ForeignKeyInfo retrieves the foreign keys for a given table name. func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { + dbForeignKeys, err := m.foreignKeyInfoFromDB(schema, tableName) + if err != nil { + return nil, errors.Wrap(err, "read foreign keys info from db") + } + + return drivers.CombineConfigAndDBForeignKeys(m.configForeignKeys, tableName, dbForeignKeys), nil +} +func (m *MySQLDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) { var fkeys []drivers.ForeignKey query := ` diff --git a/drivers/sqlboiler-psql/driver/psql.go b/drivers/sqlboiler-psql/driver/psql.go index 9c326efdc..a41bfb7c4 100644 --- a/drivers/sqlboiler-psql/driver/psql.go +++ b/drivers/sqlboiler-psql/driver/psql.go @@ -15,9 +15,10 @@ import ( "github.com/volatiletech/sqlboiler/v4/importers" "github.com/friendsofgo/errors" - "github.com/volatiletech/sqlboiler/v4/drivers" "github.com/volatiletech/strmangle" + "github.com/volatiletech/sqlboiler/v4/drivers" + // Side-effect import sql driver _ "github.com/lib/pq" ) @@ -45,7 +46,8 @@ type PostgresDriver struct { addEnumTypes bool enumNullPrefix string - uniqueColumns map[columnIdentifier]struct{} + uniqueColumns map[columnIdentifier]struct{} + configForeignKeys []drivers.ForeignKey } type columnIdentifier struct { @@ -103,6 +105,7 @@ func (p *PostgresDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo p.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool) p.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null")) p.connStr = PSQLBuildQueryString(user, pass, dbname, host, port, sslmode) + p.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys) p.conn, err = sql.Open("postgres", p.connStr) if err != nil { return nil, errors.Wrap(err, "sqlboiler-psql failed to connect to database") @@ -703,6 +706,14 @@ func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Prim // ForeignKeyInfo retrieves the foreign keys for a given table name. func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { + dbForeignKeys, err := p.foreignKeyInfoFromDB(schema, tableName) + if err != nil { + return nil, errors.Wrap(err, "read foreign keys info from db") + } + + return drivers.CombineConfigAndDBForeignKeys(p.configForeignKeys, tableName, dbForeignKeys), nil +} +func (p *PostgresDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) { var fkeys []drivers.ForeignKey whereConditions := []string{"pgn.nspname = $2", "pgc.relname = $1", "pgcon.contype = 'f'"} diff --git a/drivers/sqlboiler-sqlite3/driver/sqlite3.go b/drivers/sqlboiler-sqlite3/driver/sqlite3.go index 0b6b67059..f124a2207 100644 --- a/drivers/sqlboiler-sqlite3/driver/sqlite3.go +++ b/drivers/sqlboiler-sqlite3/driver/sqlite3.go @@ -9,9 +9,10 @@ import ( "io/fs" "strings" + _ "modernc.org/sqlite" + "github.com/volatiletech/sqlboiler/v4/drivers" "github.com/volatiletech/sqlboiler/v4/importers" - _ "modernc.org/sqlite" ) //go:embed override @@ -30,8 +31,9 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { // SQLiteDriver holds the database connection string and a handle // to the database connection. type SQLiteDriver struct { - connStr string - dbConn *sql.DB + connStr string + dbConn *sql.DB + configForeignKeys []drivers.ForeignKey } // Templates that should be added/overridden @@ -73,6 +75,7 @@ func (s SQLiteDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) s.connStr = SQLiteBuildQueryString(dbname) + s.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys) s.dbConn, err = sql.Open("sqlite", s.connStr) if err != nil { return nil, fmt.Errorf("sqlboiler-sqlite failed to connect to database: %w", err) @@ -451,6 +454,14 @@ func (s SQLiteDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Primary // ForeignKeyInfo retrieves the foreign keys for a given table name. func (s SQLiteDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { + dbForeignKeys, err := s.foreignKeyInfoFromDB(schema, tableName) + if err != nil { + return nil, fmt.Errorf("read foreign keys info from db: %w", err) + } + + return drivers.CombineConfigAndDBForeignKeys(s.configForeignKeys, tableName, dbForeignKeys), nil +} +func (s SQLiteDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) { var fkeys []drivers.ForeignKey query := fmt.Sprintf("PRAGMA foreign_key_list('%s')", tableName) diff --git a/main.go b/main.go index b123545be..db674f44a 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "github.com/friendsofgo/errors" "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/volatiletech/sqlboiler/v4/boilingcore" "github.com/volatiletech/sqlboiler/v4/drivers" "github.com/volatiletech/sqlboiler/v4/importers" @@ -190,6 +191,7 @@ func preRun(cmd *cobra.Command, args []string) error { SingularExact: viper.GetStringMapString("inflections.singular_exact"), Irregular: viper.GetStringMapString("inflections.irregular"), }, + ForeignKeys: boilingcore.ConvertForeignKeys(viper.Get("foreign_keys")), Version: sqlBoilerVersion, } @@ -204,6 +206,7 @@ func preRun(cmd *cobra.Command, args []string) error { "blacklist": viper.GetStringSlice(driverName + ".blacklist"), "add-enum-types": cmdConfig.AddEnumTypes, "enum-null-prefix": cmdConfig.EnumNullPrefix, + "foreign-keys": cmdConfig.ForeignKeys, } keys := allKeys(driverName)