diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 162bb9c1af..1dd236dfdc 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -503,7 +503,19 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { } return nil, ErrPoolClosed case poolPaused: - err := poolClearedError{err: p.lastClearErr, address: p.address} + // Wrap poolCleared in a driver.Error so we can add the + // "TransientTransactionError" label. This will add + // "TransientTransactionError" to all poolClearedError instances, not + // just those that happened during transactions. While that behavior is + // different than other places we add "TransientTransactionError", it is + // consistent with the Transactions specification and simplifies the + // code. + pcErr := poolClearedError{err: p.lastClearErr, address: p.address} + err := driver.Error{ + Message: pcErr.Error(), + Labels: []string{driver.TransientTransactionError}, + Wrapped: pcErr, + } p.stateMu.RUnlock() duration := time.Since(start) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index f58e1cf204..17e803ea49 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -21,6 +21,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/eventtest" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/address" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" ) @@ -1584,3 +1585,26 @@ func TestPool_PoolMonitor(t *testing.T) { "expected ConnectionCheckOutFailed Duration to be set") }) } + +func TestPool_Error(t *testing.T) { + t.Parallel() + + t.Run("should have TransientTransactionError", func(t *testing.T) { + t.Parallel() + + p := newPool(poolConfig{}) + assert.Equalf(t, poolPaused, p.getState(), "expected new pool to be paused") + + // Since new pool is paused, checkout should throw PoolClearedError. + _, err := p.checkOut(context.Background()) + var le driver.Error + if errors.As(err, &le) { + assert.ErrorIs(t, poolClearedError{}, le.Unwrap(), "expect error to be PoolClearedError") + assert.True(t, le.HasErrorLabel(driver.TransientTransactionError), `expected error to include the "TransientTransactionError" label`) + } else { + t.Errorf("expected labeled error, got %v", err) + } + + p.close(context.Background()) + }) +}