@@ -8,12 +8,14 @@ package topology
88
99import (
1010 "context"
11+ "math"
1112 "sync"
1213 "sync/atomic"
1314 "time"
1415
1516 "go.mongodb.org/mongo-driver/event"
1617 "go.mongodb.org/mongo-driver/x/mongo/driver/address"
18+ "golang.org/x/sync/semaphore"
1719)
1820
1921// ErrPoolConnected is returned from an attempt to connect an already connected pool
@@ -67,6 +69,7 @@ type pool struct {
6769 connected int32 // Must be accessed using the sync/atomic package.
6870 nextid uint64
6971 opened map [uint64 ]* connection // opened holds all of the currently open connections.
72+ sem * semaphore.Weighted
7073 sync.Mutex
7174}
7275
@@ -116,6 +119,7 @@ func connectionCloseFunc(v interface{}) {
116119 return
117120 }
118121
122+ _ = c .pool .removeConnection (c )
119123 go func () {
120124 _ = c .pool .closeConnection (c )
121125 }()
@@ -141,16 +145,23 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
141145 opts = append (opts , WithIdleTimeout (func (_ time.Duration ) time.Duration { return config .MaxIdleTime }))
142146 }
143147
148+ var maxConns = config .MaxPoolSize
149+ if maxConns == 0 {
150+ maxConns = math .MaxInt64
151+ }
152+
144153 pool := & pool {
145154 address : config .Address ,
146155 monitor : config .PoolMonitor ,
147156 connected : disconnected ,
148157 opened : make (map [uint64 ]* connection ),
149158 opts : opts ,
159+ sem : semaphore .NewWeighted (int64 (maxConns )),
150160 }
151161
152162 // we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
153163 rpc := resourcePoolConfig {
164+ MaxSize : maxConns ,
154165 MinSize : config .MinPoolSize ,
155166 MaintainInterval : maintainInterval ,
156167 ExpiredFn : connectionExpiredFunc ,
@@ -162,7 +173,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
162173 pool .monitor .Event (& event.PoolEvent {
163174 Type : event .PoolCreated ,
164175 PoolOptions : & event.MonitorPoolOptions {
165- MaxPoolSize : config . MaxPoolSize ,
176+ MaxPoolSize : rpc . MaxSize ,
166177 MinPoolSize : rpc .MinSize ,
167178 WaitQueueTimeoutMS : uint64 (config .MaxIdleTime ) / uint64 (time .Millisecond ),
168179 },
@@ -249,9 +260,11 @@ func (p *pool) disconnect(ctx context.Context) error {
249260 Reason : event .ReasonPoolClosed ,
250261 })
251262 }
263+ _ = p .removeConnection (pc )
252264 _ = p .closeConnection (pc ) // We don't care about errors while closing the connection.
253265 }
254266 atomic .StoreInt32 (& p .connected , disconnected )
267+ p .conns .clearTotal ()
255268
256269 if p .monitor != nil {
257270 p .monitor .Event (& event.PoolEvent {
@@ -305,7 +318,6 @@ func (p *pool) makeNewConnection(ctx context.Context) (*connection, string, erro
305318
306319// Checkout returns a connection from the pool
307320func (p * pool ) get (ctx context.Context ) (* connection , error ) {
308-
309321 if ctx == nil {
310322 ctx = context .Background ()
311323 }
@@ -321,81 +333,120 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
321333 return nil , ErrPoolDisconnected
322334 }
323335
324- connVal := p .conns .Get ()
325- if c , ok := connVal .(* connection ); ok && connVal != nil {
326- // call connect if not connected
327- if atomic .LoadInt32 (& c .connected ) == initialized {
328- c .connect (ctx )
336+ err := p .sem .Acquire (ctx , 1 )
337+ if err != nil {
338+ if p .monitor != nil {
339+ p .monitor .Event (& event.PoolEvent {
340+ Type : event .GetFailed ,
341+ Address : p .address .String (),
342+ Reason : event .ReasonTimedOut ,
343+ })
329344 }
345+ return nil , ErrWaitQueueTimeout
346+ }
330347
331- err := c .wait ()
332- if err != nil {
348+ // This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between
349+ // calling p.conns.Get() and making the new connection
350+ for {
351+ if atomic .LoadInt32 (& p .connected ) != connected {
333352 if p .monitor != nil {
334353 p .monitor .Event (& event.PoolEvent {
335354 Type : event .GetFailed ,
336355 Address : p .address .String (),
337- Reason : event .ReasonConnectionErrored ,
356+ Reason : event .ReasonPoolClosed ,
338357 })
339358 }
340- return nil , err
359+ p .sem .Release (1 )
360+ return nil , ErrPoolDisconnected
341361 }
342362
343- if p .monitor != nil {
344- p .monitor .Event (& event.PoolEvent {
345- Type : event .GetSucceeded ,
346- Address : p .address .String (),
347- ConnectionID : c .poolID ,
348- })
349- }
350- return c , nil
351- }
363+ connVal := p .conns .Get ()
364+ if c , ok := connVal .(* connection ); ok && connVal != nil {
365+ // call connect if not connected
366+ if atomic .LoadInt32 (& c .connected ) == initialized {
367+ c .connect (ctx )
368+ }
352369
353- select {
354- case <- ctx .Done ():
355- if p .monitor != nil {
356- p .monitor .Event (& event.PoolEvent {
357- Type : event .GetFailed ,
358- Address : p .address .String (),
359- Reason : event .ReasonTimedOut ,
360- })
361- }
362- return nil , ctx .Err ()
363- default :
364- c , reason , err := p .makeNewConnection (ctx )
370+ err := c .wait ()
371+ if err != nil {
372+ if p .monitor != nil {
373+ p .monitor .Event (& event.PoolEvent {
374+ Type : event .GetFailed ,
375+ Address : p .address .String (),
376+ Reason : event .ReasonConnectionErrored ,
377+ })
378+ }
379+ p .conns .decrementTotal ()
380+ p .sem .Release (1 )
381+ return nil , err
382+ }
365383
366- if err != nil {
367384 if p .monitor != nil {
368385 p .monitor .Event (& event.PoolEvent {
369- Type : event .GetFailed ,
370- Address : p .address .String (),
371- Reason : reason ,
386+ Type : event .GetSucceeded ,
387+ Address : p .address .String (),
388+ ConnectionID : c . poolID ,
372389 })
373390 }
374- return nil , err
391+ return c , nil
375392 }
376393
377- c .connect (ctx )
378- // wait for conn to be connected
379- err = c .wait ()
380- if err != nil {
394+ select {
395+ case <- ctx .Done ():
381396 if p .monitor != nil {
382397 p .monitor .Event (& event.PoolEvent {
383398 Type : event .GetFailed ,
384399 Address : p .address .String (),
385- Reason : reason ,
400+ Reason : event . ReasonTimedOut ,
386401 })
387402 }
388- return nil , err
389- }
403+ p .sem .Release (1 )
404+ return nil , ctx .Err ()
405+ default :
406+ made := p .conns .incrementTotal ()
407+ if ! made {
408+ continue
409+ }
410+ c , reason , err := p .makeNewConnection (ctx )
411+
412+ if err != nil {
413+ if p .monitor != nil {
414+ p .monitor .Event (& event.PoolEvent {
415+ Type : event .GetFailed ,
416+ Address : p .address .String (),
417+ Reason : reason ,
418+ })
419+ }
420+ p .sem .Release (1 )
421+ p .conns .decrementTotal ()
422+ return nil , err
423+ }
390424
391- if p .monitor != nil {
392- p .monitor .Event (& event.PoolEvent {
393- Type : event .GetSucceeded ,
394- Address : p .address .String (),
395- ConnectionID : c .poolID ,
396- })
425+ c .connect (ctx )
426+ // wait for conn to be connected
427+ err = c .wait ()
428+ if err != nil {
429+ if p .monitor != nil {
430+ p .monitor .Event (& event.PoolEvent {
431+ Type : event .GetFailed ,
432+ Address : p .address .String (),
433+ Reason : reason ,
434+ })
435+ }
436+ p .sem .Release (1 )
437+ p .conns .decrementTotal ()
438+ return nil , err
439+ }
440+
441+ if p .monitor != nil {
442+ p .monitor .Event (& event.PoolEvent {
443+ Type : event .GetSucceeded ,
444+ Address : p .address .String (),
445+ ConnectionID : c .poolID ,
446+ })
447+ }
448+ return c , nil
397449 }
398- return c , nil
399450 }
400451}
401452
@@ -405,9 +456,6 @@ func (p *pool) closeConnection(c *connection) error {
405456 if c .pool != p {
406457 return ErrWrongPool
407458 }
408- p .Lock ()
409- delete (p .opened , c .poolID )
410- p .Unlock ()
411459
412460 if atomic .LoadInt32 (& c .connected ) == connected {
413461 c .closeConnectContext ()
@@ -441,8 +489,10 @@ func (p *pool) removeConnection(c *connection) error {
441489}
442490
443491// put returns a connection to this pool. If the pool is connected, the connection is not
444- // stale, and there is space in the cache, the connection is returned to the cache.
492+ // stale, and there is space in the cache, the connection is returned to the cache. This
493+ // assumes that the connection has already been counted in p.conns.totalSize.
445494func (p * pool ) put (c * connection ) error {
495+ defer p .sem .Release (1 )
446496 if p .monitor != nil {
447497 var cid uint64
448498 var addr string
0 commit comments