@@ -11,8 +11,12 @@ package topology
11
11
12
12
import (
13
13
"context"
14
+ "crypto/tls"
15
+ "crypto/x509"
14
16
"errors"
17
+ "io/ioutil"
15
18
"net"
19
+ "os"
16
20
"runtime"
17
21
"sync"
18
22
"sync/atomic"
@@ -49,6 +53,144 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
49
53
return cnc , nil
50
54
}
51
55
56
+ type errorQueue struct {
57
+ errors []error
58
+ mutex sync.Mutex
59
+ }
60
+
61
+ func (eq * errorQueue ) head () error {
62
+ eq .mutex .Lock ()
63
+ defer eq .mutex .Unlock ()
64
+ if len (eq .errors ) > 0 {
65
+ return eq .errors [0 ]
66
+ }
67
+ return nil
68
+ }
69
+
70
+ func (eq * errorQueue ) dequeue () bool {
71
+ eq .mutex .Lock ()
72
+ defer eq .mutex .Unlock ()
73
+ if len (eq .errors ) > 0 {
74
+ eq .errors = eq .errors [1 :]
75
+ return true
76
+ }
77
+ return false
78
+ }
79
+
80
+ type timeoutConn struct {
81
+ net.Conn
82
+ errors * errorQueue
83
+ }
84
+
85
+ func (c * timeoutConn ) Read (b []byte ) (int , error ) {
86
+ n , err := 0 , c .errors .head ()
87
+ if err == nil {
88
+ n , err = c .Conn .Read (b )
89
+ }
90
+ return n , err
91
+ }
92
+
93
+ func (c * timeoutConn ) Write (b []byte ) (int , error ) {
94
+ n , err := 0 , c .errors .head ()
95
+ if err == nil {
96
+ n , err = c .Conn .Write (b )
97
+ }
98
+ return n , err
99
+ }
100
+
101
+ type timeoutDialer struct {
102
+ Dialer
103
+ errors * errorQueue
104
+ }
105
+
106
+ func (d * timeoutDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
107
+ c , e := d .Dialer .DialContext (ctx , network , address )
108
+
109
+ if caFile := os .Getenv ("MONGO_GO_DRIVER_CA_FILE" ); len (caFile ) > 0 {
110
+ pem , err := ioutil .ReadFile (caFile )
111
+ if err != nil {
112
+ return nil , err
113
+ }
114
+
115
+ ca := x509 .NewCertPool ()
116
+ if ! ca .AppendCertsFromPEM (pem ) {
117
+ return nil , errors .New ("unable to load CA file" )
118
+ }
119
+
120
+ config := & tls.Config {
121
+ InsecureSkipVerify : true ,
122
+ RootCAs : ca ,
123
+ }
124
+ c = tls .Client (c , config )
125
+ }
126
+ return & timeoutConn {c , d .errors }, e
127
+ }
128
+
129
+ // TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
130
+ func TestServerHeartbeatTimeout (t * testing.T ) {
131
+ networkTimeoutError := & net.DNSError {
132
+ IsTimeout : true ,
133
+ }
134
+
135
+ testCases := []struct {
136
+ desc string
137
+ ioErrors []error
138
+ expectPoolCleared bool
139
+ }{
140
+ {
141
+ desc : "one single timeout should not clear the pool" ,
142
+ ioErrors : []error {nil , networkTimeoutError , nil , networkTimeoutError , nil },
143
+ expectPoolCleared : false ,
144
+ },
145
+ {
146
+ desc : "continuous timeouts should clear the pool" ,
147
+ ioErrors : []error {nil , networkTimeoutError , networkTimeoutError , nil },
148
+ expectPoolCleared : true ,
149
+ },
150
+ }
151
+ for _ , tc := range testCases {
152
+ tc := tc
153
+ t .Run (tc .desc , func (t * testing.T ) {
154
+ t .Parallel ()
155
+
156
+ var wg sync.WaitGroup
157
+ wg .Add (1 )
158
+
159
+ errors := & errorQueue {errors : tc .ioErrors }
160
+ tpm := monitor .NewTestPoolMonitor ()
161
+ server := NewServer (
162
+ address .Address ("localhost:27017" ),
163
+ primitive .NewObjectID (),
164
+ WithConnectionPoolMonitor (func (* event.PoolMonitor ) * event.PoolMonitor {
165
+ return tpm .PoolMonitor
166
+ }),
167
+ WithConnectionOptions (func (opts ... ConnectionOption ) []ConnectionOption {
168
+ return append (opts ,
169
+ WithDialer (func (d Dialer ) Dialer {
170
+ var dialer net.Dialer
171
+ return & timeoutDialer {& dialer , errors }
172
+ }))
173
+ }),
174
+ WithServerMonitor (func (* event.ServerMonitor ) * event.ServerMonitor {
175
+ return & event.ServerMonitor {
176
+ ServerHeartbeatStarted : func (e * event.ServerHeartbeatStartedEvent ) {
177
+ if ! errors .dequeue () {
178
+ wg .Done ()
179
+ }
180
+ },
181
+ }
182
+ }),
183
+ WithHeartbeatInterval (func (time.Duration ) time.Duration {
184
+ return 200 * time .Millisecond
185
+ }),
186
+ )
187
+ require .NoError (t , server .Connect (nil ))
188
+ wg .Wait ()
189
+ assert .Equal (t , tc .expectPoolCleared , tpm .IsPoolCleared (), "expected pool cleared to be %v but was %v" , tc .expectPoolCleared , tpm .IsPoolCleared ())
190
+ })
191
+ }
192
+ }
193
+
52
194
// TestServerConnectionTimeout tests how different timeout errors are handled during connection
53
195
// creation and server handshake.
54
196
func TestServerConnectionTimeout (t * testing.T ) {
0 commit comments