diff --git a/boilingcore/config.go b/boilingcore/config.go index 1a2a908be..dc942df3f 100644 --- a/boilingcore/config.go +++ b/boilingcore/config.go @@ -1,11 +1,13 @@ package boilingcore import ( + "fmt" "io/fs" "path/filepath" "strings" "text/template" + "github.com/friendsofgo/errors" "github.com/spf13/cast" "github.com/volatiletech/sqlboiler/v4/drivers" @@ -91,27 +93,27 @@ func (c *Config) OutputDirDepth() int { // // It also supports two different syntaxes, because of viper: // -// [aliases.tables.table_name] -// fields... = "values" -// [aliases.tables.columns] -// colname = "alias" -// [aliases.tables.relationships.fkey_name] -// local = "x" -// foreign = "y" +// [aliases.tables.table_name] +// fields... = "values" +// [aliases.tables.columns] +// colname = "alias" +// [aliases.tables.relationships.fkey_name] +// local = "x" +// foreign = "y" // // Or alternatively (when toml key names or viper's // lowercasing of key names gets in the way): // -// [[aliases.tables]] -// name = "table_name" -// fields... = "values" -// [[aliases.tables.columns]] -// name = "colname" -// alias = "alias" -// [[aliases.tables.relationships]] -// name = "fkey_name" -// local = "x" -// foreign = "y" +// [[aliases.tables]] +// name = "table_name" +// fields... = "values" +// [[aliases.tables.columns]] +// name = "colname" +// alias = "alias" +// [[aliases.tables.relationships]] +// name = "fkey_name" +// local = "x" +// foreign = "y" func ConvertAliases(i interface{}) (a Aliases) { if i == nil { return a @@ -283,3 +285,81 @@ func columnFromInterface(i interface{}) (col drivers.Column) { return col } + +// ConvertForeignKeys is necessary because viper +// +// It also supports two different syntaxes, because of viper: +// +// [foreign_keys.fk_1] +// table = "table_name" +// column = "column_name" +// foreign_table = "foreign_table_name" +// foreign_column = "foreign_column_name" +// +// Or alternatively (when toml key names or viper's +// lowercasing of key names gets in the way): +// +// [[foreign_keys]] +// name = "fk_1" +// table = "table_name" +// column = "column_name" +// foreign_table = "foreign_table_name" +// foreign_column = "foreign_column_name" +func ConvertForeignKeys(i interface{}) (fks []drivers.ForeignKey) { + if i == nil { + return nil + } + + iterateMapOrSlice(i, func(name string, obj interface{}) { + t := cast.ToStringMap(obj) + + fk := drivers.ForeignKey{ + Table: cast.ToString(t["table"]), + Name: name, + Column: cast.ToString(t["column"]), + ForeignTable: cast.ToString(t["foreign_table"]), + ForeignColumn: cast.ToString(t["foreign_column"]), + } + if err := validateForeignKey(fk); err != nil { + panic(errors.Errorf("invalid foreign key %s: %s", name, err)) + } + fks = append(fks, fk) + }) + + if err := validateDuplicateForeignKeys(fks); err != nil { + panic(errors.Errorf("invalid foreign keys: %s", err)) + } + + return fks +} + +func validateForeignKey(fk drivers.ForeignKey) error { + if fk.Name == "" { + return errors.New("foreign key must have a name") + } + if fk.Table == "" { + return errors.New("foreign key must have a table") + } + if fk.Column == "" { + return errors.New("foreign key must have a column") + } + if fk.ForeignTable == "" { + return errors.New("foreign key must have a foreign table") + } + if fk.ForeignColumn == "" { + return errors.New("foreign key must have a foreign column") + } + return nil +} + +func validateDuplicateForeignKeys(fks []drivers.ForeignKey) error { + fkMap := make(map[string]drivers.ForeignKey) + for _, fk := range fks { + key := fmt.Sprintf("%s.%s", fk.Table, fk.Column) + if _, ok := fkMap[key]; ok { + return errors.Errorf("duplicate foreign key name: %s", fk.Name) + } + fkMap[key] = fk + } + return nil +} diff --git a/boilingcore/config_test.go b/boilingcore/config_test.go index 7d32a990c..01ebc15c4 100644 --- a/boilingcore/config_test.go +++ b/boilingcore/config_test.go @@ -254,3 +254,66 @@ func TestConvertTypeReplace(t *testing.T) { t.Error("tables in types.match wrong:", got) } } + +func TestConvertForeignKeys(t *testing.T) { + t.Parallel() + + var intf interface{} = map[string]interface{}{ + "fk_1": map[string]interface{}{ + "table": "table_name", + "column": "column_name", + "foreign_table": "foreign_table_name", + "foreign_column": "foreign_column_name", + }, + } + + fks := ConvertForeignKeys(intf) + if len(fks) != 1 { + t.Error("should have one entry") + } + + fk := fks[0] + expectedFK := drivers.ForeignKey{ + Name: "fk_1", + Table: "table_name", + Column: "column_name", + ForeignTable: "foreign_table_name", + ForeignColumn: "foreign_column_name", + } + + if fk != expectedFK { + t.Error("value was wrong:", fk) + } +} + +func TestConvertForeignKeysAltSyntax(t *testing.T) { + t.Parallel() + + var intf interface{} = []interface{}{ + map[string]interface{}{ + "name": "fk_1", + "table": "table_name", + "column": "column_name", + "foreign_table": "foreign_table_name", + "foreign_column": "foreign_column_name", + }, + } + + fks := ConvertForeignKeys(intf) + if len(fks) != 1 { + t.Error("should have one entry") + } + + fk := fks[0] + expectedFK := drivers.ForeignKey{ + Name: "fk_1", + Table: "table_name", + Column: "column_name", + ForeignTable: "foreign_table_name", + ForeignColumn: "foreign_column_name", + } + + if fk != expectedFK { + t.Error("value was wrong:", fk) + } +}