@@ -13,6 +13,7 @@ import (
1313 "io"
1414 "net"
1515 "net/http"
16+ "sync/atomic"
1617 "testing"
1718 "time"
1819)
@@ -1280,39 +1281,58 @@ func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
12801281 t .Fatalf ("error creating listener: %v" , err )
12811282 }
12821283
1283- times := 0
1284+ var connectionCounter atomic. Int32
12841285
12851286 newLn := & Listener {
12861287 Listener : l ,
12871288 ConnPolicy : func (_ ConnPolicyOptions ) (Policy , error ) {
12881289 // Return the invalid upstream error on the first call, the listener
12891290 // should remain open and accepting.
1291+ times := connectionCounter .Load ()
12901292 if times == 0 {
1291- times ++
1293+ connectionCounter . Store ( times + 1 )
12921294 return REJECT , ErrInvalidUpstream
12931295 }
12941296
12951297 return REJECT , ErrNoProxyProtocol
12961298 },
12971299 }
12981300
1299- // Kick off the listener and capture any error.
1300- var listenerErr error
1301+ // Kick off the listener and return any error via the chanel.
1302+ errCh := make (chan error )
1303+ defer close (errCh )
13011304 go func (t * testing.T ) {
1302- _ , listenerErr = newLn .Accept ()
1305+ _ , err := newLn .Accept ()
1306+ errCh <- err
13031307 }(t )
13041308
13051309 // Make two calls to trigger the listener's accept, the first should experience
13061310 // the ErrInvalidUpstream and keep the listener open, the second should experience
13071311 // a different error which will cause the listener to close.
13081312 _ , _ = http .Get ("http://localhost:8080" )
1309- if listenerErr != nil {
1310- t .Fatalf ("invalid upstream shouldn't return an error: %v" , listenerErr )
1313+ // Wait a few seconds to ensure we didn't get anything back on our channel.
1314+ select {
1315+ case err := <- errCh :
1316+ if err != nil {
1317+ t .Fatalf ("invalid upstream shouldn't return an error: %v" , err )
1318+ }
1319+ case <- time .After (2 * time .Second ):
1320+ // No error returned (as expected, we're still listening though)
13111321 }
13121322
13131323 _ , _ = http .Get ("http://localhost:8080" )
1314- if listenerErr == nil {
1315- t .Fatalf ("errors other than invalid upstream should error" )
1324+ // Wait a few seconds before we fail the test as we should have received an
1325+ // error that was not invalid upstream.
1326+ select {
1327+ case err := <- errCh :
1328+ if err == nil {
1329+ t .Fatalf ("errors other than invalid upstream should error" )
1330+ }
1331+ if ! errors .Is (ErrNoProxyProtocol , err ) {
1332+ t .Fatalf ("unexpected error type: %v" , err )
1333+ }
1334+ case <- time .After (2 * time .Second ):
1335+ t .Fatalf ("timed out waiting for listener" )
13161336 }
13171337}
13181338
0 commit comments