Skip to content

Commit

Permalink
Allow drivers read and combine config and db foreign keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ChienNM3 committed Jul 22, 2023
1 parent bf6e572 commit 1fbc1c1
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 17 deletions.
9 changes: 5 additions & 4 deletions boilingcore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ type Config struct {
DefaultTemplates fs.FS `toml:"-" json:"-"`
CustomTemplateFuncs template.FuncMap `toml:"-" json:"-"`

Aliases Aliases `toml:"aliases,omitempty" json:"aliases,omitempty"`
TypeReplaces []TypeReplace `toml:"type_replaces,omitempty" json:"type_replaces,omitempty"`
AutoColumns AutoColumns `toml:"auto_columns,omitempty" json:"auto_columns,omitempty"`
Inflections Inflections `toml:"inflections,omitempty" json:"inflections,omitempty"`
Aliases Aliases `toml:"aliases,omitempty" json:"aliases,omitempty"`
TypeReplaces []TypeReplace `toml:"type_replaces,omitempty" json:"type_replaces,omitempty"`
AutoColumns AutoColumns `toml:"auto_columns,omitempty" json:"auto_columns,omitempty"`
Inflections Inflections `toml:"inflections,omitempty" json:"inflections,omitempty"`
ForeignKeys []drivers.ForeignKey `toml:"foreign_keys,omitempty" json:"foreign_keys,omitempty" `

Version string `toml:"version" json:"version"`
}
Expand Down
62 changes: 62 additions & 0 deletions drivers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,41 @@ func (c Config) StringSlice(key string) ([]string, bool) {
return slice, true
}

func (c Config) MustForeignKeys(key string) []ForeignKey {
rawValue, ok := c[key]
if !ok {
return nil
}

switch v := rawValue.(type) {
case nil:
return nil
case []ForeignKey:
return v
case []interface{}: // in case binary, config is pass to driver in json format, so this key will be []interface{}
fks := make([]ForeignKey, 0, len(v))
for _, item := range v {
fk, ok := item.(map[string]interface{})
if !ok {
panic(errors.Errorf("found item of foreign keys, but it was not a map[string]interface{} (%T)", v))
}

configFK := Config(fk)

fks = append(fks, ForeignKey{
Name: configFK.MustString("name"),
Table: configFK.MustString("table"),
Column: configFK.MustString("column"),
ForeignTable: configFK.MustString("foreign_table"),
ForeignColumn: configFK.MustString("foreign_column"),
})
}
return fks
default:
panic(errors.Errorf("found key %s in config, but it was invalid (%T)", key, v))
}
}

// DefaultEnv grabs a value from the environment or a default.
// This is shared by drivers to get config for testing.
func DefaultEnv(key, def string) string {
Expand Down Expand Up @@ -211,3 +246,30 @@ func ColumnsFromList(list []string, tablename string) []string {

return columns
}

// CombineConfigAndDBForeignKeys takes foreign keys from both config and db, filter by tableName and
// deduplicate by column name. If a foreign key is found in both config and db, the one in config will be used.
func CombineConfigAndDBForeignKeys(configForeignKeys []ForeignKey, tableName string, dbForeignKeys []ForeignKey) []ForeignKey {
combinedForeignKeys := make([]ForeignKey, 0, len(configForeignKeys)+len(dbForeignKeys))
appearedColumns := make(map[string]bool)

for _, fk := range configForeignKeys {
// need check table name here cause configForeignKeys contains all foreign keys of all tables
if fk.Table != tableName {
continue
}

combinedForeignKeys = append(combinedForeignKeys, fk)
appearedColumns[fk.Column] = true
}

for _, fk := range dbForeignKeys {
// no need check table here, because dbForeignKeys are already filtered by table name
if appearedColumns[fk.Column] {
continue
}
combinedForeignKeys = append(combinedForeignKeys, fk)
}

return combinedForeignKeys
}
171 changes: 171 additions & 0 deletions drivers/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,174 @@ func TestColumnsFromList(t *testing.T) {
t.Error("list was wrong:", got)
}
}

func TestConfig_MustForeignKeys(t *testing.T) {
tests := []struct {
name string
c Config
want []ForeignKey
panic bool
}{
{
name: "no foreign keys",
c: Config{},
want: nil,
panic: false,
},
{
name: "nil foreign keys",
c: Config{
"foreign-keys": nil,
},
want: nil,
panic: false,
},
{
name: "have foreign keys",
c: Config{
"foreign-keys": []ForeignKey{
{
Name: "test_fk",
Table: "table_name",
Column: "column_name",
ForeignColumn: "foreign_column_name",
ForeignTable: "foreign_table_name",
},
},
},
want: []ForeignKey{
{
Name: "test_fk",
Table: "table_name",
Column: "column_name",
ForeignColumn: "foreign_column_name",
ForeignTable: "foreign_table_name",
},
},
panic: false,
},
{
name: "invalid foreign keys",
c: Config{
"foreign-keys": 1,
},
panic: true,
},
{
name: "foreign keys in []interface{} format",
c: Config{
"foreign-keys": []interface{}{
map[string]interface{}{
"name": "test_fk",
"table": "table_name",
"column": "column_name",
"foreign_column": "foreign_column_name",
"foreign_table": "foreign_table_name",
},
},
},
want: []ForeignKey{
{
Name: "test_fk",
Table: "table_name",
Column: "column_name",
ForeignColumn: "foreign_column_name",
ForeignTable: "foreign_table_name",
},
},
panic: false,
},
{
name: "invalid foreign keys in []interface{} format",
c: Config{
"foreign-keys": []interface{}{
"123",
},
},
panic: true,
},
{
name: "foreign keys in []map[string]string format but missing fields",
c: Config{
"foreign-keys": []interface{}{
map[string]interface{}{
"name": "test_fk",
},
},
},
want: nil,
panic: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got []ForeignKey
var paniced interface{}
func() {
defer func() {
if r := recover(); r != nil {
paniced = r
}
}()
got = tt.c.MustForeignKeys(ConfigForeignKeys)
}()

if tt.panic && paniced == nil {
t.Errorf("MustForeignKeys() should have panicked")
}
if !tt.panic && paniced != nil {
t.Errorf("MustForeignKeys() should not have panicked")
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MustForeignKeys() = %v, want %v", got, tt.want)
}
})
}
}

func TestCombineConfigAndDBForeignKeys(t *testing.T) {
configForeignKeys := []ForeignKey{
{
Name: "config_fk1",
Table: "table_A",
Column: "column_A1",
ForeignColumn: "column_B1",
ForeignTable: "table_B",
},
{
Name: "config_fk2",
Table: "table_C",
Column: "column_C1",
ForeignColumn: "column_B1",
ForeignTable: "table_B",
},
{
Name: "config_fk3",
Table: "table_A",
Column: "column_A2",
ForeignColumn: "column_D2",
ForeignTable: "table_D",
},
}
tableName := "table_A"
dbForeignKeys := []ForeignKey{
{
Name: "db_fk1",
Table: "table_A",
Column: "column_A1",
ForeignColumn: "column_E1",
ForeignTable: "table_E",
},
}

expected := []ForeignKey{
configForeignKeys[0],
configForeignKeys[2],
}

got := CombineConfigAndDBForeignKeys(configForeignKeys, tableName, dbForeignKeys)

if !reflect.DeepEqual(got, expected) {
t.Errorf("CombineConfigAndDBForeignKeys() = %v, want %v", got, expected)
}
}
4 changes: 3 additions & 1 deletion drivers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
"sync"

"github.com/friendsofgo/errors"
"github.com/volatiletech/sqlboiler/v4/importers"
"github.com/volatiletech/strmangle"

"github.com/volatiletech/sqlboiler/v4/importers"
)

// These constants are used in the config map passed into the driver
Expand All @@ -19,6 +20,7 @@ const (
ConfigAddEnumTypes = "add-enum-types"
ConfigEnumNullPrefix = "enum-null-prefix"
ConfigConcurrency = "concurrency"
ConfigForeignKeys = "foreign-keys"

ConfigUser = "user"
ConfigPass = "pass"
Expand Down
14 changes: 13 additions & 1 deletion drivers/sqlboiler-mssql/driver/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
// Side effect import go-mssqldb
"github.com/friendsofgo/errors"
_ "github.com/microsoft/go-mssqldb"
"github.com/volatiletech/strmangle"

"github.com/volatiletech/sqlboiler/v4/drivers"
"github.com/volatiletech/sqlboiler/v4/importers"
"github.com/volatiletech/strmangle"
)

//go:embed override
Expand All @@ -36,6 +37,8 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
type MSSQLDriver struct {
connStr string
conn *sql.DB

configForeignKeys []drivers.ForeignKey
}

// Templates that should be added/overridden
Expand Down Expand Up @@ -84,6 +87,7 @@ func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency)

m.connStr = MSSQLBuildQueryString(user, pass, dbname, host, port, sslmode)
m.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys)
m.conn, err = sql.Open("mssql", m.connStr)
if err != nil {
return nil, errors.Wrap(err, "sqlboiler-mssql failed to connect to database")
Expand Down Expand Up @@ -409,6 +413,14 @@ func (m *MSSQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Primary

// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (m *MSSQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) {
dbForeignKeys, err := m.foreignKeyInfoFromDB(schema, tableName)
if err != nil {
return nil, errors.Wrap(err, "read foreign keys info from db")
}

return drivers.CombineConfigAndDBForeignKeys(m.configForeignKeys, tableName, dbForeignKeys), nil
}
func (m *MSSQLDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) {
var fkeys []drivers.ForeignKey

query := `
Expand Down
23 changes: 17 additions & 6 deletions drivers/sqlboiler-mysql/driver/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (

"github.com/friendsofgo/errors"
"github.com/go-sql-driver/mysql"
"github.com/volatiletech/strmangle"

"github.com/volatiletech/sqlboiler/v4/drivers"
"github.com/volatiletech/sqlboiler/v4/importers"
"github.com/volatiletech/strmangle"
)

//go:embed override
Expand All @@ -33,11 +34,12 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
// MySQLDriver holds the database connection string and a handle
// to the database connection.
type MySQLDriver struct {
connStr string
conn *sql.DB
addEnumTypes bool
enumNullPrefix string
tinyIntAsInt bool
connStr string
conn *sql.DB
addEnumTypes bool
enumNullPrefix string
tinyIntAsInt bool
configForeignKeys []drivers.ForeignKey
}

// Templates that should be added/overridden
Expand Down Expand Up @@ -95,6 +97,7 @@ func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
m.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool)
m.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null"))
m.connStr = MySQLBuildQueryString(user, pass, dbname, host, port, sslmode)
m.configForeignKeys = config.MustForeignKeys(drivers.ConfigForeignKeys)
m.conn, err = sql.Open("mysql", m.connStr)
if err != nil {
return nil, errors.Wrap(err, "sqlboiler-mysql failed to connect to database")
Expand Down Expand Up @@ -405,6 +408,14 @@ func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.Primary

// ForeignKeyInfo retrieves the foreign keys for a given table name.
func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) {
dbForeignKeys, err := m.foreignKeyInfoFromDB(schema, tableName)
if err != nil {
return nil, errors.Wrap(err, "read foreign keys info from db")
}

return drivers.CombineConfigAndDBForeignKeys(m.configForeignKeys, tableName, dbForeignKeys), nil
}
func (m *MySQLDriver) foreignKeyInfoFromDB(schema, tableName string) ([]drivers.ForeignKey, error) {
var fkeys []drivers.ForeignKey

query := `
Expand Down
Loading

0 comments on commit 1fbc1c1

Please sign in to comment.