Skip to content

Commit 88e698b

Browse files
benjirewisprestonvasquez
authored andcommitted
GODRIVER-2658 Better guard against nil pinned connections. (#1153)
1 parent dba8e42 commit 88e698b

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,10 @@ func (c initConnection) SupportsStreaming() bool {
577577
}
578578

579579
// Connection implements the driver.Connection interface to allow reading and writing wire
580-
// messages and the driver.Expirable interface to allow expiring.
580+
// messages and the driver.Expirable interface to allow expiring. It wraps an underlying
581+
// topology.connection to make it more goroutine-safe and nil-safe.
581582
type Connection struct {
582-
*connection
583+
connection *connection
583584
refCount int
584585
cleanupPoolFn func()
585586

@@ -601,7 +602,7 @@ func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
601602
if c.connection == nil {
602603
return ErrConnectionClosed
603604
}
604-
return c.writeWireMessage(ctx, wm)
605+
return c.connection.writeWireMessage(ctx, wm)
605606
}
606607

607608
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
@@ -612,7 +613,7 @@ func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, e
612613
if c.connection == nil {
613614
return dst, ErrConnectionClosed
614615
}
615-
return c.readWireMessage(ctx, dst)
616+
return c.connection.readWireMessage(ctx, dst)
616617
}
617618

618619
// CompressWireMessage handles compressing the provided wire message using the underlying
@@ -658,7 +659,7 @@ func (c *Connection) Description() description.Server {
658659
if c.connection == nil {
659660
return description.Server{}
660661
}
661-
return c.desc
662+
return c.connection.desc
662663
}
663664

664665
// Close returns this connection to the connection pool. This method may not closeConnection the underlying
@@ -681,12 +682,12 @@ func (c *Connection) Expire() error {
681682
return nil
682683
}
683684

684-
_ = c.close()
685+
_ = c.connection.close()
685686
return c.cleanupReferences()
686687
}
687688

688689
func (c *Connection) cleanupReferences() error {
689-
err := c.pool.checkIn(c.connection)
690+
err := c.connection.pool.checkIn(c.connection)
690691
if c.cleanupPoolFn != nil {
691692
c.cleanupPoolFn()
692693
c.cleanupPoolFn = nil
@@ -711,14 +712,22 @@ func (c *Connection) ID() string {
711712
if c.connection == nil {
712713
return "<closed>"
713714
}
714-
return c.id
715+
return c.connection.id
716+
}
717+
718+
// ServerConnectionID returns the server connection ID of this connection.
719+
func (c *Connection) ServerConnectionID() *int32 {
720+
if c.connection == nil {
721+
return nil
722+
}
723+
return c.connection.serverConnectionID
715724
}
716725

717726
// Stale returns if the connection is stale.
718727
func (c *Connection) Stale() bool {
719728
c.mu.RLock()
720729
defer c.mu.RUnlock()
721-
return c.pool.stale(c.connection)
730+
return c.connection.pool.stale(c.connection)
722731
}
723732

724733
// Address returns the address of this connection.
@@ -728,27 +737,27 @@ func (c *Connection) Address() address.Address {
728737
if c.connection == nil {
729738
return address.Address("0.0.0.0")
730739
}
731-
return c.addr
740+
return c.connection.addr
732741
}
733742

734743
// LocalAddress returns the local address of the connection
735744
func (c *Connection) LocalAddress() address.Address {
736745
c.mu.RLock()
737746
defer c.mu.RUnlock()
738-
if c.connection == nil || c.nc == nil {
747+
if c.connection == nil || c.connection.nc == nil {
739748
return address.Address("0.0.0.0")
740749
}
741-
return address.Address(c.nc.LocalAddr().String())
750+
return address.Address(c.connection.nc.LocalAddr().String())
742751
}
743752

744753
// PinToCursor updates this connection to reflect that it is pinned to a cursor.
745754
func (c *Connection) PinToCursor() error {
746-
return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor)
755+
return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
747756
}
748757

749758
// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
750759
func (c *Connection) PinToTransaction() error {
751-
return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction)
760+
return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
752761
}
753762

754763
func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,12 @@ func TestConnection(t *testing.T) {
762762
if !cmp.Equal(got, want) {
763763
t.Errorf("LocalAddresses do not match. got %v; want %v", got, want)
764764
}
765+
766+
want = (*int32)(nil)
767+
got = conn.ServerConnectionID()
768+
if !cmp.Equal(got, want) {
769+
t.Errorf("ServerConnectionIDs do not match. got %v; want %v", got, want)
770+
}
765771
})
766772

767773
t.Run("pinning", func(t *testing.T) {

0 commit comments

Comments
 (0)