diff --git a/database/config.go b/database/config.go index dbe833b..6d9cfcb 100644 --- a/database/config.go +++ b/database/config.go @@ -2,9 +2,12 @@ package database import ( "fmt" + "sync/atomic" "time" + "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" + "gorm.io/plugin/dbresolver" ) type LogLevel string @@ -55,6 +58,11 @@ type DBConfig struct { // ConnPool is the connection pool settings for the database connection. // This is optional and can be set to nil if the default connection pool settings are sufficient. ConnPool *DBConnPool `mapstructure:"conn_pool"` + + // ResolverPolicy is the policy for selecting database connection (ConnPool) from Write sources or Read replicas in gorm. + // this policy can be rewritten by custom policy with WithResolverPolicy option function, by default random policy is used + // available values are "default", "random", "round_robin" + ResolverPolicy PolicyName `mapstructure:"resolver_policy"` } var ( @@ -64,7 +72,53 @@ var ( defaultConnMaxLifetime = time.Duration(0) ) -func (cfg *DBConfig) applyDefaultValue() { +type options struct { + ResolverPolicy dbresolver.Policy +} +type PolicyName string + +const ( + DefaultPolicyName PolicyName = "default" + RandomPolicyName PolicyName = "random" + RoundRobinPolicyName PolicyName = "round_robin" + StrictRoundRobinPolicyName PolicyName = "strict_round_robin" +) + +var ( + RandomPolicy = dbresolver.RandomPolicy{} + RoundRobinPolicy = NewRoundRobinPolicy() + StrictRoundRobinPolicy = NewStrictRoundRobinPolicy() +) + +var resolverPolicies = map[PolicyName]dbresolver.Policy{ + DefaultPolicyName: RandomPolicy, + RandomPolicyName: RandomPolicy, + RoundRobinPolicyName: RoundRobinPolicy, + StrictRoundRobinPolicyName: StrictRoundRobinPolicy, +} + +type PolicyFunc func([]gorm.ConnPool) gorm.ConnPool + +func (f PolicyFunc) Resolve(connPools []gorm.ConnPool) gorm.ConnPool { + return f(connPools) +} + +func NewRoundRobinPolicy() dbresolver.Policy { + var i int + return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool { + i = (i + 1) % len(connPools) + return connPools[i] + }) +} + +func NewStrictRoundRobinPolicy() dbresolver.Policy { + var i int64 + return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool { + return connPools[int(atomic.AddInt64(&i, 1))%len(connPools)] + }) +} + +func (cfg *DBConfig) applyDefaultValue() options { if cfg.LogLevel == "" { cfg.LogLevel = LogLevelError } @@ -78,4 +132,17 @@ func (cfg *DBConfig) applyDefaultValue() { ConnMaxLifetime: defaultConnMaxLifetime, } } + if cfg.ResolverPolicy == "" { + cfg.ResolverPolicy = DefaultPolicyName + } + var ( + opt options + ok bool + ) + opt.ResolverPolicy, ok = resolverPolicies[cfg.ResolverPolicy] + if !ok { + opt.ResolverPolicy = resolverPolicies[DefaultPolicyName] + } + + return opt } diff --git a/database/config_test.go b/database/config_test.go new file mode 100644 index 0000000..b85bbfe --- /dev/null +++ b/database/config_test.go @@ -0,0 +1,54 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestApplyDefaultValue(t *testing.T) { + defaultConnPool := &DBConnPool{ + MaxIdleConns: defaultMaxIdleConns, + ConnMaxIdleTime: defaultConnMaxIdleTime, + MaxOpenConns: defaultMaxOpenConns, + ConnMaxLifetime: defaultConnMaxLifetime, + } + cases := map[string]struct { + cfg DBConfig + expected options + expectedCfg DBConfig + }{ + "1. Default": { + cfg: DBConfig{}, + expected: options{ + ResolverPolicy: RandomPolicy, + }, + expectedCfg: DBConfig{ + LogLevel: LogLevelError, + ResolverPolicy: DefaultPolicyName, + ConnPool: defaultConnPool, + }, + }, + "2. Configured": { + cfg: DBConfig{ + LogLevel: LogLevelInfo, + ResolverPolicy: RoundRobinPolicyName, + }, + expectedCfg: DBConfig{ + LogLevel: LogLevelInfo, + ResolverPolicy: RoundRobinPolicyName, + ConnPool: defaultConnPool, + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + cfg := tc.cfg + opt := cfg.applyDefaultValue() + + assert.NotNil(t, opt.ResolverPolicy) + assert.EqualValues(t, tc.expectedCfg, cfg) + }) + } +} diff --git a/database/db.go b/database/db.go index a0685b6..dad2ca7 100644 --- a/database/db.go +++ b/database/db.go @@ -48,7 +48,7 @@ type DBGetter struct { // following a write operation may not see the updated data if it is executed on a different read-only replica // that has not yet been updated with the new data. func NewDBGetter(cfg DBConfig) (*DBGetter, error) { - cfg.applyDefaultValue() + opt := cfg.applyDefaultValue() logLevel, err := newLogLevelFromString(cfg.LogLevel) if err != nil { @@ -80,6 +80,7 @@ func NewDBGetter(cfg DBConfig) (*DBGetter, error) { Sources: []gorm.Dialector{postgres.Open(cfg.Url)}, Replicas: replicas, TraceResolverMode: logLevel == gormLogger.Info, + Policy: opt.ResolverPolicy, }).SetConnMaxIdleTime(cfg.ConnPool.ConnMaxIdleTime). SetConnMaxLifetime(cfg.ConnPool.ConnMaxLifetime). SetMaxIdleConns(cfg.ConnPool.MaxIdleConns).