diff --git a/bin/memos/main.go b/bin/memos/main.go index 6450a73ad9855..6c0fdb2584fdc 100644 --- a/bin/memos/main.go +++ b/bin/memos/main.go @@ -36,19 +36,21 @@ var ( Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`, Run: func(_ *cobra.Command, _ []string) { instanceProfile := &profile.Profile{ - Mode: viper.GetString("mode"), - Addr: viper.GetString("addr"), - Port: viper.GetInt("port"), - Data: viper.GetString("data"), - Driver: viper.GetString("driver"), - DSN: viper.GetString("dsn"), - InstanceURL: viper.GetString("instance-url"), - Version: version.GetCurrentVersion(viper.GetString("mode")), + Mode: viper.GetString("mode"), + Addr: viper.GetString("addr"), + Port: viper.GetInt("port"), + Data: viper.GetString("data"), + Driver: viper.GetString("driver"), + DSN: viper.GetString("dsn"), + InstanceURL: viper.GetString("instance-url"), + Version: version.GetCurrentVersion(viper.GetString("mode")), + DBMaxOpenConns: viper.GetInt("max-open-conns"), + DBMaxIdleConns: viper.GetInt("max-idle-conns"), + DBConnMaxLifetime: viper.GetDuration("conn-max-lifetime"), } if err := instanceProfile.Validate(); err != nil { panic(err) } - ctx, cancel := context.WithCancel(context.Background()) dbDriver, err := db.NewDBDriver(instanceProfile) if err != nil { @@ -110,6 +112,9 @@ func init() { rootCmd.PersistentFlags().String("driver", "sqlite", "database driver") rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)") rootCmd.PersistentFlags().String("instance-url", "", "the url of your memos instance") + rootCmd.PersistentFlags().Int("max-open-conns", 0, "maximum number of open database connections") + rootCmd.PersistentFlags().Int("max-idle-conns", 2, "maximum number of connections in the idle connection pool") + rootCmd.PersistentFlags().Duration("conn-max-lifetime", 0, "maximum amount of time a connection may be reused") if err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")); err != nil { panic(err) @@ -132,12 +137,30 @@ func init() { if err := viper.BindPFlag("instance-url", rootCmd.PersistentFlags().Lookup("instance-url")); err != nil { panic(err) } + if err := viper.BindPFlag("max-open-conns", rootCmd.PersistentFlags().Lookup("max-open-conns")); err != nil { + panic(err) + } + if err := viper.BindPFlag("max-idle-conns", rootCmd.PersistentFlags().Lookup("max-idle-conns")); err != nil { + panic(err) + } + if err := viper.BindPFlag("conn-max-lifetime", rootCmd.PersistentFlags().Lookup("conn-max-lifetime")); err != nil { + panic(err) + } viper.SetEnvPrefix("memos") viper.AutomaticEnv() if err := viper.BindEnv("instance-url", "MEMOS_INSTANCE_URL"); err != nil { panic(err) } + if err := viper.BindEnv("max-open-conns", "MEMOS_MAX_OPEN_CONNS"); err != nil { + panic(err) + } + if err := viper.BindEnv("max-idle-conns", "MEMOS_MAX_IDLE_CONNS"); err != nil { + panic(err) + } + if err := viper.BindEnv("conn-max-lifetime", "MEMOS_CONN_MAX_LIFETIME"); err != nil { + panic(err) + } } func printGreetings(profile *profile.Profile) { @@ -153,8 +176,11 @@ addr: %s port: %d mode: %s driver: %s +max-open-conns: %d +max-idle-conns: %d +conn-max-lifetime: %s --- -`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.Mode, profile.Driver) +`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.Mode, profile.Driver, profile.DBMaxOpenConns, profile.DBMaxIdleConns, profile.DBConnMaxLifetime) print(greetingBanner) if len(profile.Addr) == 0 { diff --git a/server/profile/profile.go b/server/profile/profile.go index e6bbaf28b7457..13af60c8b6c20 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strings" + "time" "github.com/pkg/errors" ) @@ -30,6 +31,12 @@ type Profile struct { Version string // InstanceURL is the url of your memos instance. InstanceURL string + // DBMaxOpenConns is the maximum number of open connections to the database. + DBMaxOpenConns int + // DBMaxIdleConns is the maximum number of idle connections to the database. + DBMaxIdleConns int + // DBConnMaxLifetime is the maximum amount of time a connection may be reused. + DBConnMaxLifetime time.Duration } func (p *Profile) IsDev() bool { @@ -76,7 +83,7 @@ func (p *Profile) Validate() error { dataDir, err := checkDataDir(p.Data) if err != nil { - slog.Error("failed to check dsn", slog.String("data", dataDir), slog.String("error", err.Error())) + slog.Error("failed to check data directory", slog.String("data", dataDir), slog.String("error", err.Error())) return err } diff --git a/store/db/db.go b/store/db/db.go index 47a369385e542..a9b099c114c29 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -28,5 +28,9 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) { if err != nil { return nil, errors.Wrap(err, "failed to create db driver") } + + driver.GetDB().SetMaxOpenConns(profile.DBMaxOpenConns) + driver.GetDB().SetMaxIdleConns(profile.DBMaxIdleConns) + driver.GetDB().SetConnMaxLifetime(profile.DBConnMaxLifetime) return driver, nil } diff --git a/store/test/store_bench_test.go b/store/test/store_bench_test.go new file mode 100644 index 0000000000000..7de40138b642f --- /dev/null +++ b/store/test/store_bench_test.go @@ -0,0 +1,134 @@ +package teststore + +import ( + "context" + "runtime" + "testing" + "time" + + "github.com/lithammer/shortuuid/v4" + + "github.com/usememos/memos/store" +) + +// BenchmarkDB groups all database benchmarks. +func BenchmarkDB(b *testing.B) { + b.Run("BenchmarkDBConnPool", BenchmarkDBConnPool) +} + +// benchmarkConfig defines the configuration for benchmark testing. +type benchmarkConfig struct { + maxOpenConns int + maxIdleConns int + connMaxLifetime *time.Duration +} + +// benchmarkConnectionPool tests the performance of sql.DB connection pooling. +func BenchmarkDBConnPool(b *testing.B) { + cores := runtime.NumCPU() + lifeTime := time.Hour + cases := []struct { + name string + config benchmarkConfig + }{ + { + name: "default_unlimited", + config: benchmarkConfig{ + maxOpenConns: 0, // Use default value 0 (unlimited) + maxIdleConns: 2, // Use default value 2 + connMaxLifetime: nil, // Use default value 0 (unlimited) + }, + }, + { + name: "max_conns_equals_cores", + config: benchmarkConfig{ + maxOpenConns: cores, + maxIdleConns: cores / 2, + connMaxLifetime: &lifeTime, + }, + }, + { + name: "max_conns_double_cores", + config: benchmarkConfig{ + maxOpenConns: cores * 2, + maxIdleConns: cores, + connMaxLifetime: &lifeTime, + }, + }, + { + name: "max_conns_25", + config: benchmarkConfig{ + maxOpenConns: 25, + maxIdleConns: 10, + connMaxLifetime: &lifeTime, + }, + }, + { + name: "max_conns_50", + config: benchmarkConfig{ + maxOpenConns: 50, + maxIdleConns: 25, + connMaxLifetime: &lifeTime, + }, + }, + { + name: "max_conns_100", + config: benchmarkConfig{ + maxOpenConns: 100, + maxIdleConns: 50, + connMaxLifetime: &lifeTime, + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + ctx := context.Background() + ts := NewTestingStore(ctx, &testing.T{}) + db := ts.GetDriver().GetDB() + + db.SetMaxOpenConns(tc.config.maxOpenConns) + db.SetMaxIdleConns(tc.config.maxIdleConns) + + if tc.config.connMaxLifetime != nil { + db.SetConnMaxLifetime(*tc.config.connMaxLifetime) + } + + user, err := createTestingHostUser(ctx, ts) + if err != nil { + b.Logf("failed to create testing host user: %v", err) + } + + // Set concurrency level + b.SetParallelism(100) + b.ResetTimer() + + // Record initial stats + startStats := db.Stats() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Execute database operation + memoCreate := &store.Memo{ + UID: shortuuid.New(), + CreatorID: user.ID, + Content: "test_content", + Visibility: store.Public, + } + _, err := ts.CreateMemo(ctx, memoCreate) + if err != nil { + b.Fatal("failed to create memo:", err) + } + } + }) + + // Collect and report connection pool statistics + endStats := db.Stats() + // b.ReportMetric(float64(endStats.MaxOpenConnections), "max_open_conns") + // b.ReportMetric(float64(endStats.InUse), "conns_in_use") + // b.ReportMetric(float64(endStats.Idle), "idle_conns") + b.ReportMetric(float64(endStats.WaitCount-startStats.WaitCount), "wait_count") + b.ReportMetric(float64((endStats.WaitDuration - startStats.WaitDuration).Milliseconds()), "wait_duration_ms") + }) + } +}