Skip to content

Commit

Permalink
contrib/database/sql: Close DB Stats goroutine on db.Close() (#3025)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtoffl01 authored Feb 3, 2025
1 parent 0a41ffd commit dbfb8f2
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 15 deletions.
33 changes: 20 additions & 13 deletions contrib/database/sql/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,27 @@ var interval = 10 * time.Second

// pollDBStats calls (*DB).Stats on the db at a predetermined interval. It pushes the DBStats off to the statsd client.
// the caller should always ensure that db & statsd are non-nil
func pollDBStats(statsd internal.StatsdClient, db *sql.DB) {
func pollDBStats(statsd internal.StatsdClient, db *sql.DB, stop chan struct{}) {
log.Debug("DB stats will be gathered and sent every %v.", interval)
for range time.NewTicker(interval).C {
log.Debug("Reporting DB.Stats metrics...")
stat := db.Stats()
statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1)
statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1)
statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1)
statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1)
statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1)
statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1)
statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1)
statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1)
statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
log.Debug("Reporting DB.Stats metrics...")
stat := db.Stats()
statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1)
statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1)
statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1)
statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1)
statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1)
statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1)
statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1)
statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1)
statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1)
case <-stop:
return
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions contrib/database/sql/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
package sql

import (
"sync"
"testing"

"github.com/DataDog/datadog-go/v5/statsd"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig"
)

Expand Down Expand Up @@ -64,3 +68,22 @@ func TestStatsTags(t *testing.T) {
})
resetGlobalConfig()
}

func TestPollDBStatsStop(t *testing.T) {
driverName := "postgres"
Register(driverName, &pq.Driver{}, WithServiceName("postgres-test"), WithAnalyticsRate(0.2))
defer unregister(driverName)
db, err := Open(driverName, "postgres://postgres:[email protected]:5432/postgres?sslmode=disable")
require.NoError(t, err)
defer db.Close()

var wg sync.WaitGroup
stop := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
pollDBStats(&statsd.NoOpClientDirect{}, db, stop)
}()
close(stop)
wg.Wait()
}
11 changes: 10 additions & 1 deletion contrib/database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ type tracedConnector struct {
connector driver.Connector
driverName string
cfg *config
dbClose chan struct{}
}

func (t *tracedConnector) Connect(ctx context.Context) (driver.Conn, error) {
Expand Down Expand Up @@ -171,6 +172,13 @@ func (t *tracedConnector) Driver() driver.Driver {
return t.connector.Driver()
}

// Close closes the dbClose channel
// This method will be invoked when DB.Close() is called, which we expect to occur only once: https://cs.opensource.google/go/go/+/refs/tags/go1.23.4:src/database/sql/sql.go;l=918-950
func (t *tracedConnector) Close() error {
close(t.dbClose)
return nil
}

// from Go stdlib implementation of sql.Open
type dsnConnector struct {
dsn string
Expand Down Expand Up @@ -208,10 +216,11 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB {
connector: c,
driverName: driverName,
cfg: cfg,
dbClose: make(chan struct{}),
}
db := sql.OpenDB(tc)
if cfg.dbStats && cfg.statsdClient != nil {
go pollDBStats(cfg.statsdClient, db)
go pollDBStats(cfg.statsdClient, db, tc.dbClose)
}
return db
}
Expand Down
8 changes: 7 additions & 1 deletion contrib/database/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,13 @@ func TestOpenOptions(t *testing.T) {
var tg statsdtest.TestStatsdClient
Register(driverName, &pq.Driver{})
defer unregister(driverName)
_, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats())
db, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats())
require.NoError(t, err)

// The polling interval has been reduced to 500ms for the sake of this test, so at least one round of `pollDBStats` should be complete in 1s
deadline := time.Now().Add(1 * time.Second)
wantStats := []string{MaxOpenConnections, OpenConnections, InUse, Idle, WaitCount, WaitDuration, MaxIdleClosed, MaxIdleTimeClosed, MaxLifetimeClosed}
var calls1 []string
for {
if time.Now().After(deadline) {
t.Fatalf("Stats not collected in expected interval of %v", interval)
Expand All @@ -300,11 +301,16 @@ func TestOpenOptions(t *testing.T) {
}
}
// all expected stats have been collected; exit out of loop, test should pass
calls1 = calls
break
}
// not all stats have been collected yet, try again in 50ms
time.Sleep(50 * time.Millisecond)
}
// Close DB and assert the no further stats have been collected; db.Close should stop the pollDBStats goroutine.
db.Close()
time.Sleep(50 * time.Millisecond)
assert.Equal(t, calls1, tg.CallNames())
})
}

Expand Down
1 change: 1 addition & 0 deletions contrib/jackc/pgx.v5/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var interval = 10 * time.Second

// pollPoolStats calls (*pgxpool).Stats on the pool at a predetermined interval. It pushes the pool Stats off to the statsd client.
func pollPoolStats(statsd internal.StatsdClient, pool *pgxpool.Pool) {
// TODO: Create stop condition for pgx on db.Close
log.Debug("contrib/jackc/pgx.v5: Traced pool connection found: Pool stats will be gathered and sent every %v.", interval)
for range time.NewTicker(interval).C {
log.Debug("contrib/jackc/pgx.v5: Reporting pgxpool.Stat metrics...")
Expand Down

0 comments on commit dbfb8f2

Please sign in to comment.