Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package database

import (
"fmt"
"sync/atomic"
"time"

"gorm.io/gorm"
gormLogger "gorm.io/gorm/logger"
"gorm.io/plugin/dbresolver"
)

type LogLevel string
Expand Down Expand Up @@ -55,6 +58,11 @@ type DBConfig struct {
// ConnPool is the connection pool settings for the database connection.
// This is optional and can be set to nil if the default connection pool settings are sufficient.
ConnPool *DBConnPool `mapstructure:"conn_pool"`

// ResolverPolicy is the policy for selecting database connection (ConnPool) from Write sources or Read replicas in gorm.
// this policy can be rewritten by custom policy with WithResolverPolicy option function, by default random policy is used
// available values are "default", "random", "round_robin"
ResolverPolicy PolicyName `mapstructure:"resolver_policy"`
}

var (
Expand All @@ -64,7 +72,53 @@ var (
defaultConnMaxLifetime = time.Duration(0)
)

func (cfg *DBConfig) applyDefaultValue() {
type options struct {
ResolverPolicy dbresolver.Policy
}
type PolicyName string

const (
DefaultPolicyName PolicyName = "default"
RandomPolicyName PolicyName = "random"
RoundRobinPolicyName PolicyName = "round_robin"
StrictRoundRobinPolicyName PolicyName = "strict_round_robin"
)

var (
RandomPolicy = dbresolver.RandomPolicy{}
RoundRobinPolicy = NewRoundRobinPolicy()
StrictRoundRobinPolicy = NewStrictRoundRobinPolicy()
)

var resolverPolicies = map[PolicyName]dbresolver.Policy{
DefaultPolicyName: RandomPolicy,
RandomPolicyName: RandomPolicy,
RoundRobinPolicyName: RoundRobinPolicy,
StrictRoundRobinPolicyName: StrictRoundRobinPolicy,
}

type PolicyFunc func([]gorm.ConnPool) gorm.ConnPool

func (f PolicyFunc) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
return f(connPools)
}

func NewRoundRobinPolicy() dbresolver.Policy {
var i int
return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool {
i = (i + 1) % len(connPools)
return connPools[i]
})
}

func NewStrictRoundRobinPolicy() dbresolver.Policy {
var i int64
return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool {
return connPools[int(atomic.AddInt64(&i, 1))%len(connPools)]
})
}

func (cfg *DBConfig) applyDefaultValue() options {
if cfg.LogLevel == "" {
cfg.LogLevel = LogLevelError
}
Expand All @@ -78,4 +132,17 @@ func (cfg *DBConfig) applyDefaultValue() {
ConnMaxLifetime: defaultConnMaxLifetime,
}
}
if cfg.ResolverPolicy == "" {
cfg.ResolverPolicy = DefaultPolicyName
}
var (
opt options
ok bool
)
opt.ResolverPolicy, ok = resolverPolicies[cfg.ResolverPolicy]
if !ok {
opt.ResolverPolicy = resolverPolicies[DefaultPolicyName]
}

return opt
}
54 changes: 54 additions & 0 deletions database/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package database

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestApplyDefaultValue(t *testing.T) {
defaultConnPool := &DBConnPool{
MaxIdleConns: defaultMaxIdleConns,
ConnMaxIdleTime: defaultConnMaxIdleTime,
MaxOpenConns: defaultMaxOpenConns,
ConnMaxLifetime: defaultConnMaxLifetime,
}
cases := map[string]struct {
cfg DBConfig
expected options
expectedCfg DBConfig
}{
"1. Default": {
cfg: DBConfig{},
expected: options{
ResolverPolicy: RandomPolicy,
},
expectedCfg: DBConfig{
LogLevel: LogLevelError,
ResolverPolicy: DefaultPolicyName,
ConnPool: defaultConnPool,
},
},
"2. Configured": {
cfg: DBConfig{
LogLevel: LogLevelInfo,
ResolverPolicy: RoundRobinPolicyName,
},
expectedCfg: DBConfig{
LogLevel: LogLevelInfo,
ResolverPolicy: RoundRobinPolicyName,
ConnPool: defaultConnPool,
},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
cfg := tc.cfg
opt := cfg.applyDefaultValue()

assert.NotNil(t, opt.ResolverPolicy)
assert.EqualValues(t, tc.expectedCfg, cfg)
})
}
}
3 changes: 2 additions & 1 deletion database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type DBGetter struct {
// following a write operation may not see the updated data if it is executed on a different read-only replica
// that has not yet been updated with the new data.
func NewDBGetter(cfg DBConfig) (*DBGetter, error) {
cfg.applyDefaultValue()
opt := cfg.applyDefaultValue()

logLevel, err := newLogLevelFromString(cfg.LogLevel)
if err != nil {
Expand Down Expand Up @@ -80,6 +80,7 @@ func NewDBGetter(cfg DBConfig) (*DBGetter, error) {
Sources: []gorm.Dialector{postgres.Open(cfg.Url)},
Replicas: replicas,
TraceResolverMode: logLevel == gormLogger.Info,
Policy: opt.ResolverPolicy,
}).SetConnMaxIdleTime(cfg.ConnPool.ConnMaxIdleTime).
SetConnMaxLifetime(cfg.ConnPool.ConnMaxLifetime).
SetMaxIdleConns(cfg.ConnPool.MaxIdleConns).
Expand Down
Loading