Skip to content

Commit

Permalink
allowed to pass concurrency to driver over configuration file
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Krush committed Jun 25, 2022
1 parent 58f83d2 commit a5fd3f0
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
26 changes: 12 additions & 14 deletions drivers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@ import (

// These constants are used in the config map passed into the driver
const (
ConfigBlacklist = "blacklist"
ConfigWhitelist = "whitelist"
ConfigSchema = "schema"
ConfigAddEnumTypes = "add-enum-types"
ConfigEnumNullPrefix = "enum-null-prefix"
ConfigBlacklist = "blacklist"
ConfigWhitelist = "whitelist"
ConfigSchema = "schema"
ConfigAddEnumTypes = "add-enum-types"
ConfigEnumNullPrefix = "enum-null-prefix"
ConfigConcurrency = "concurrency"
ConfigDefaultConcurrency = 10

ConfigUser = "user"
ConfigPass = "pass"
ConfigHost = "host"
ConfigPort = "port"
ConfigDBName = "dbname"
ConfigSSLMode = "sslmode"

// number of threads while getting tables and views info
// TODO: allow override from config and cmdline
concurrency = 10
)

// Interface abstracts either a side-effect imported driver or a binary
Expand Down Expand Up @@ -106,17 +104,17 @@ type TableColumnTypeTranslator interface {

// Tables returns the metadata for all tables, minus the tables
// specified in the blacklist.
func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Table, error) {
func Tables(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
var err error
var ret []Table

ret, err = tables(c, schema, whitelist, blacklist)
ret, err = tables(c, schema, whitelist, blacklist, concurrency)
if err != nil {
return nil, errors.Wrap(err, "unable to load tables")
}

if vc, ok := c.(ViewConstructor); ok {
v, err := views(vc, schema, whitelist, blacklist)
v, err := views(vc, schema, whitelist, blacklist, concurrency)
if err != nil {
return nil, errors.Wrap(err, "unable to load views")
}
Expand All @@ -126,7 +124,7 @@ func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Tabl
return ret, nil
}

func tables(c Constructor, schema string, whitelist, blacklist []string) ([]Table, error) {
func tables(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
var err error

names, err := c.TableNames(schema, whitelist, blacklist)
Expand Down Expand Up @@ -216,7 +214,7 @@ func table(c Constructor, schema string, name string, whitelist, blacklist []str

// views returns the metadata for all views, minus the views
// specified in the blacklist.
func views(c ViewConstructor, schema string, whitelist, blacklist []string) ([]Table, error) {
func views(c ViewConstructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
var err error

names, err := c.ViewNames(schema, whitelist, blacklist)
Expand Down
2 changes: 1 addition & 1 deletion drivers/mocks/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (m *MockDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, er
whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)

dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist)
dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist, 1)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion drivers/sqlboiler-mssql/driver/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
schema := config.DefaultString(drivers.ConfigSchema, "dbo")
whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.ConfigDefaultConcurrency)

m.connStr = MSSQLBuildQueryString(user, pass, dbname, host, port, sslmode)
m.conn, err = sql.Open("mssql", m.connStr)
Expand Down Expand Up @@ -110,7 +111,7 @@ func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
UseCaseWhenExistsClause: true,
},
}
dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist)
dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist, concurrency)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion drivers/sqlboiler-mysql/driver/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
schema := dbname
whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.ConfigDefaultConcurrency)

tinyIntAsIntIntf, ok := config["tinyint_as_int"]
if ok {
Expand Down Expand Up @@ -116,7 +117,7 @@ func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
},
}

dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist)
dbinfo.Tables, err = drivers.Tables(m, schema, whitelist, blacklist, concurrency)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion drivers/sqlboiler-psql/driver/psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (p *PostgresDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo
schema := config.DefaultString(drivers.ConfigSchema, "public")
whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.ConfigDefaultConcurrency)

useSchema := schema != "public"

Expand Down Expand Up @@ -122,7 +123,7 @@ func (p *PostgresDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo
UseDefaultKeyword: true,
},
}
dbinfo.Tables, err = drivers.Tables(p, schema, whitelist, blacklist)
dbinfo.Tables, err = drivers.Tables(p, schema, whitelist, blacklist, concurrency)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion drivers/sqlboiler-sqlite3/driver/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (s SQLiteDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
dbname := config.MustString(drivers.ConfigDBName)
whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.ConfigDefaultConcurrency)

s.connStr = SQLiteBuildQueryString(dbname)
s.dbConn, err = sql.Open("sqlite", s.connStr)
Expand All @@ -95,7 +96,7 @@ func (s SQLiteDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
},
}

dbinfo.Tables, err = drivers.Tables(s, "", whitelist, blacklist)
dbinfo.Tables, err = drivers.Tables(s, "", whitelist, blacklist, concurrency)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit a5fd3f0

Please sign in to comment.