Skip to content

Commit dbfb8f2

Browse files
authored
contrib/database/sql: Close DB Stats goroutine on db.Close() (#3025)
1 parent 0a41ffd commit dbfb8f2

File tree

5 files changed

+61
-15
lines changed

5 files changed

+61
-15
lines changed

contrib/database/sql/metrics.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,27 @@ var interval = 10 * time.Second
3333

3434
// pollDBStats calls (*DB).Stats on the db at a predetermined interval. It pushes the DBStats off to the statsd client.
3535
// the caller should always ensure that db & statsd are non-nil
36-
func pollDBStats(statsd internal.StatsdClient, db *sql.DB) {
36+
func pollDBStats(statsd internal.StatsdClient, db *sql.DB, stop chan struct{}) {
3737
log.Debug("DB stats will be gathered and sent every %v.", interval)
38-
for range time.NewTicker(interval).C {
39-
log.Debug("Reporting DB.Stats metrics...")
40-
stat := db.Stats()
41-
statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1)
42-
statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1)
43-
statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1)
44-
statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1)
45-
statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1)
46-
statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1)
47-
statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1)
48-
statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1)
49-
statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1)
38+
ticker := time.NewTicker(interval)
39+
defer ticker.Stop()
40+
for {
41+
select {
42+
case <-ticker.C:
43+
log.Debug("Reporting DB.Stats metrics...")
44+
stat := db.Stats()
45+
statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1)
46+
statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1)
47+
statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1)
48+
statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1)
49+
statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1)
50+
statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1)
51+
statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1)
52+
statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1)
53+
statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1)
54+
case <-stop:
55+
return
56+
}
5057
}
5158
}
5259

contrib/database/sql/metrics_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
package sql
77

88
import (
9+
"sync"
910
"testing"
1011

12+
"github.com/DataDog/datadog-go/v5/statsd"
13+
"github.com/lib/pq"
1114
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
1216
"gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig"
1317
)
1418

@@ -64,3 +68,22 @@ func TestStatsTags(t *testing.T) {
6468
})
6569
resetGlobalConfig()
6670
}
71+
72+
func TestPollDBStatsStop(t *testing.T) {
73+
driverName := "postgres"
74+
Register(driverName, &pq.Driver{}, WithServiceName("postgres-test"), WithAnalyticsRate(0.2))
75+
defer unregister(driverName)
76+
db, err := Open(driverName, "postgres://postgres:[email protected]:5432/postgres?sslmode=disable")
77+
require.NoError(t, err)
78+
defer db.Close()
79+
80+
var wg sync.WaitGroup
81+
stop := make(chan struct{})
82+
wg.Add(1)
83+
go func() {
84+
defer wg.Done()
85+
pollDBStats(&statsd.NoOpClientDirect{}, db, stop)
86+
}()
87+
close(stop)
88+
wg.Wait()
89+
}

contrib/database/sql/sql.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ type tracedConnector struct {
139139
connector driver.Connector
140140
driverName string
141141
cfg *config
142+
dbClose chan struct{}
142143
}
143144

144145
func (t *tracedConnector) Connect(ctx context.Context) (driver.Conn, error) {
@@ -171,6 +172,13 @@ func (t *tracedConnector) Driver() driver.Driver {
171172
return t.connector.Driver()
172173
}
173174

175+
// Close closes the dbClose channel
176+
// This method will be invoked when DB.Close() is called, which we expect to occur only once: https://cs.opensource.google/go/go/+/refs/tags/go1.23.4:src/database/sql/sql.go;l=918-950
177+
func (t *tracedConnector) Close() error {
178+
close(t.dbClose)
179+
return nil
180+
}
181+
174182
// from Go stdlib implementation of sql.Open
175183
type dsnConnector struct {
176184
dsn string
@@ -208,10 +216,11 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB {
208216
connector: c,
209217
driverName: driverName,
210218
cfg: cfg,
219+
dbClose: make(chan struct{}),
211220
}
212221
db := sql.OpenDB(tc)
213222
if cfg.dbStats && cfg.statsdClient != nil {
214-
go pollDBStats(cfg.statsdClient, db)
223+
go pollDBStats(cfg.statsdClient, db, tc.dbClose)
215224
}
216225
return db
217226
}

contrib/database/sql/sql_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ func TestOpenOptions(t *testing.T) {
281281
var tg statsdtest.TestStatsdClient
282282
Register(driverName, &pq.Driver{})
283283
defer unregister(driverName)
284-
_, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats())
284+
db, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats())
285285
require.NoError(t, err)
286286

287287
// The polling interval has been reduced to 500ms for the sake of this test, so at least one round of `pollDBStats` should be complete in 1s
288288
deadline := time.Now().Add(1 * time.Second)
289289
wantStats := []string{MaxOpenConnections, OpenConnections, InUse, Idle, WaitCount, WaitDuration, MaxIdleClosed, MaxIdleTimeClosed, MaxLifetimeClosed}
290+
var calls1 []string
290291
for {
291292
if time.Now().After(deadline) {
292293
t.Fatalf("Stats not collected in expected interval of %v", interval)
@@ -300,11 +301,16 @@ func TestOpenOptions(t *testing.T) {
300301
}
301302
}
302303
// all expected stats have been collected; exit out of loop, test should pass
304+
calls1 = calls
303305
break
304306
}
305307
// not all stats have been collected yet, try again in 50ms
306308
time.Sleep(50 * time.Millisecond)
307309
}
310+
// Close DB and assert the no further stats have been collected; db.Close should stop the pollDBStats goroutine.
311+
db.Close()
312+
time.Sleep(50 * time.Millisecond)
313+
assert.Equal(t, calls1, tg.CallNames())
308314
})
309315
}
310316

contrib/jackc/pgx.v5/metrics.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ var interval = 10 * time.Second
3535

3636
// pollPoolStats calls (*pgxpool).Stats on the pool at a predetermined interval. It pushes the pool Stats off to the statsd client.
3737
func pollPoolStats(statsd internal.StatsdClient, pool *pgxpool.Pool) {
38+
// TODO: Create stop condition for pgx on db.Close
3839
log.Debug("contrib/jackc/pgx.v5: Traced pool connection found: Pool stats will be gathered and sent every %v.", interval)
3940
for range time.NewTicker(interval).C {
4041
log.Debug("contrib/jackc/pgx.v5: Reporting pgxpool.Stat metrics...")

0 commit comments

Comments
 (0)