Skip to content

feat: configure database connection pool settings #4633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions bin/memos/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,21 @@ var (
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
Run: func(_ *cobra.Command, _ []string) {
instanceProfile := &profile.Profile{
Mode: viper.GetString("mode"),
Addr: viper.GetString("addr"),
Port: viper.GetInt("port"),
Data: viper.GetString("data"),
Driver: viper.GetString("driver"),
DSN: viper.GetString("dsn"),
InstanceURL: viper.GetString("instance-url"),
Version: version.GetCurrentVersion(viper.GetString("mode")),
Mode: viper.GetString("mode"),
Addr: viper.GetString("addr"),
Port: viper.GetInt("port"),
Data: viper.GetString("data"),
Driver: viper.GetString("driver"),
DSN: viper.GetString("dsn"),
InstanceURL: viper.GetString("instance-url"),
Version: version.GetCurrentVersion(viper.GetString("mode")),
DBMaxOpenConns: viper.GetInt("max-open-conns"),
DBMaxIdleConns: viper.GetInt("max-idle-conns"),
DBConnMaxLifetime: viper.GetDuration("conn-max-lifetime"),
}
if err := instanceProfile.Validate(); err != nil {
panic(err)
}

ctx, cancel := context.WithCancel(context.Background())
dbDriver, err := db.NewDBDriver(instanceProfile)
if err != nil {
Expand Down Expand Up @@ -110,6 +112,9 @@ func init() {
rootCmd.PersistentFlags().String("driver", "sqlite", "database driver")
rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)")
rootCmd.PersistentFlags().String("instance-url", "", "the url of your memos instance")
rootCmd.PersistentFlags().Int("max-open-conns", 0, "maximum number of open database connections")
rootCmd.PersistentFlags().Int("max-idle-conns", 2, "maximum number of connections in the idle connection pool")
rootCmd.PersistentFlags().Duration("conn-max-lifetime", 0, "maximum amount of time a connection may be reused")

if err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")); err != nil {
panic(err)
Expand All @@ -132,12 +137,30 @@ func init() {
if err := viper.BindPFlag("instance-url", rootCmd.PersistentFlags().Lookup("instance-url")); err != nil {
panic(err)
}
if err := viper.BindPFlag("max-open-conns", rootCmd.PersistentFlags().Lookup("max-open-conns")); err != nil {
panic(err)
}
if err := viper.BindPFlag("max-idle-conns", rootCmd.PersistentFlags().Lookup("max-idle-conns")); err != nil {
panic(err)
}
if err := viper.BindPFlag("conn-max-lifetime", rootCmd.PersistentFlags().Lookup("conn-max-lifetime")); err != nil {
panic(err)
}

viper.SetEnvPrefix("memos")
viper.AutomaticEnv()
if err := viper.BindEnv("instance-url", "MEMOS_INSTANCE_URL"); err != nil {
panic(err)
}
if err := viper.BindEnv("max-open-conns", "MEMOS_MAX_OPEN_CONNS"); err != nil {
panic(err)
}
if err := viper.BindEnv("max-idle-conns", "MEMOS_MAX_IDLE_CONNS"); err != nil {
panic(err)
}
if err := viper.BindEnv("conn-max-lifetime", "MEMOS_CONN_MAX_LIFETIME"); err != nil {
panic(err)
}
}

func printGreetings(profile *profile.Profile) {
Expand All @@ -153,8 +176,11 @@ addr: %s
port: %d
mode: %s
driver: %s
max-open-conns: %d
max-idle-conns: %d
conn-max-lifetime: %s
---
`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.Mode, profile.Driver)
`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.Mode, profile.Driver, profile.DBMaxOpenConns, profile.DBMaxIdleConns, profile.DBConnMaxLifetime)

print(greetingBanner)
if len(profile.Addr) == 0 {
Expand Down
9 changes: 8 additions & 1 deletion server/profile/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"strings"
"time"

"github.com/pkg/errors"
)
Expand All @@ -30,6 +31,12 @@ type Profile struct {
Version string
// InstanceURL is the url of your memos instance.
InstanceURL string
// DBMaxOpenConns is the maximum number of open connections to the database.
DBMaxOpenConns int
// DBMaxIdleConns is the maximum number of idle connections to the database.
DBMaxIdleConns int
// DBConnMaxLifetime is the maximum amount of time a connection may be reused.
DBConnMaxLifetime time.Duration
}

func (p *Profile) IsDev() bool {
Expand Down Expand Up @@ -76,7 +83,7 @@ func (p *Profile) Validate() error {

dataDir, err := checkDataDir(p.Data)
if err != nil {
slog.Error("failed to check dsn", slog.String("data", dataDir), slog.String("error", err.Error()))
slog.Error("failed to check data directory", slog.String("data", dataDir), slog.String("error", err.Error()))
return err
}

Expand Down
4 changes: 4 additions & 0 deletions store/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
if err != nil {
return nil, errors.Wrap(err, "failed to create db driver")
}

driver.GetDB().SetMaxOpenConns(profile.DBMaxOpenConns)
driver.GetDB().SetMaxIdleConns(profile.DBMaxIdleConns)
driver.GetDB().SetConnMaxLifetime(profile.DBConnMaxLifetime)
return driver, nil
}
134 changes: 134 additions & 0 deletions store/test/store_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package teststore

import (
"context"
"runtime"
"testing"
"time"

"github.com/lithammer/shortuuid/v4"

"github.com/usememos/memos/store"
)

// BenchmarkDB groups all database benchmarks.
func BenchmarkDB(b *testing.B) {
b.Run("BenchmarkDBConnPool", BenchmarkDBConnPool)
}

// benchmarkConfig defines the configuration for benchmark testing.
type benchmarkConfig struct {
maxOpenConns int
maxIdleConns int
connMaxLifetime *time.Duration
}

// benchmarkConnectionPool tests the performance of sql.DB connection pooling.
func BenchmarkDBConnPool(b *testing.B) {
cores := runtime.NumCPU()
lifeTime := time.Hour
cases := []struct {
name string
config benchmarkConfig
}{
{
name: "default_unlimited",
config: benchmarkConfig{
maxOpenConns: 0, // Use default value 0 (unlimited)
maxIdleConns: 2, // Use default value 2
connMaxLifetime: nil, // Use default value 0 (unlimited)
},
},
{
name: "max_conns_equals_cores",
config: benchmarkConfig{
maxOpenConns: cores,
maxIdleConns: cores / 2,
connMaxLifetime: &lifeTime,
},
},
{
name: "max_conns_double_cores",
config: benchmarkConfig{
maxOpenConns: cores * 2,
maxIdleConns: cores,
connMaxLifetime: &lifeTime,
},
},
{
name: "max_conns_25",
config: benchmarkConfig{
maxOpenConns: 25,
maxIdleConns: 10,
connMaxLifetime: &lifeTime,
},
},
{
name: "max_conns_50",
config: benchmarkConfig{
maxOpenConns: 50,
maxIdleConns: 25,
connMaxLifetime: &lifeTime,
},
},
{
name: "max_conns_100",
config: benchmarkConfig{
maxOpenConns: 100,
maxIdleConns: 50,
connMaxLifetime: &lifeTime,
},
},
}

for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
ctx := context.Background()
ts := NewTestingStore(ctx, &testing.T{})
db := ts.GetDriver().GetDB()

db.SetMaxOpenConns(tc.config.maxOpenConns)
db.SetMaxIdleConns(tc.config.maxIdleConns)

if tc.config.connMaxLifetime != nil {
db.SetConnMaxLifetime(*tc.config.connMaxLifetime)
}

user, err := createTestingHostUser(ctx, ts)
if err != nil {
b.Logf("failed to create testing host user: %v", err)
}

// Set concurrency level
b.SetParallelism(100)
b.ResetTimer()

// Record initial stats
startStats := db.Stats()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Execute database operation
memoCreate := &store.Memo{
UID: shortuuid.New(),
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
}
_, err := ts.CreateMemo(ctx, memoCreate)
if err != nil {
b.Fatal("failed to create memo:", err)
}
}
})

// Collect and report connection pool statistics
endStats := db.Stats()
// b.ReportMetric(float64(endStats.MaxOpenConnections), "max_open_conns")
// b.ReportMetric(float64(endStats.InUse), "conns_in_use")
// b.ReportMetric(float64(endStats.Idle), "idle_conns")
b.ReportMetric(float64(endStats.WaitCount-startStats.WaitCount), "wait_count")
b.ReportMetric(float64((endStats.WaitDuration - startStats.WaitDuration).Milliseconds()), "wait_duration_ms")
})
}
}