@@ -50,6 +50,7 @@ type cap struct {
50
50
// much simpler.
51
51
type Client struct {
52
52
* Conn
53
+ rwc io.ReadWriteCloser
53
54
config ClientConfig
54
55
55
56
// Internal state
@@ -63,9 +64,10 @@ type Client struct {
63
64
}
64
65
65
66
// NewClient creates a client given an io stream and a client config.
66
- func NewClient (rw io.ReadWriter , config ClientConfig ) * Client {
67
+ func NewClient (rwc io.ReadWriteCloser , config ClientConfig ) * Client {
67
68
c := & Client {
68
- Conn : NewConn (rw ),
69
+ Conn : NewConn (rwc ),
70
+ rwc : rwc ,
69
71
config : config ,
70
72
errChan : make (chan error , 1 ),
71
73
caps : make (map [string ]cap ),
@@ -238,25 +240,30 @@ func (c *Client) sendError(err error) {
238
240
}
239
241
}
240
242
241
- func (c * Client ) startReadLoop (wg * sync.WaitGroup ) {
243
+ func (c * Client ) startReadLoop (wg * sync.WaitGroup , exiting chan struct {} ) {
242
244
wg .Add (1 )
243
245
244
246
go func () {
245
247
defer wg .Done ()
246
248
247
249
for {
248
- m , err := c .ReadMessage ()
249
- if err != nil {
250
- c .sendError (err )
251
- break
252
- }
250
+ select {
251
+ case <- exiting :
252
+ return
253
+ default :
254
+ m , err := c .ReadMessage ()
255
+ if err != nil {
256
+ c .sendError (err )
257
+ break
258
+ }
253
259
254
- if f , ok := clientFilters [m .Command ]; ok {
255
- f (c , m )
256
- }
260
+ if f , ok := clientFilters [m .Command ]; ok {
261
+ f (c , m )
262
+ }
257
263
258
- if c .config .Handler != nil {
259
- c .config .Handler .Handle (c , m )
264
+ if c .config .Handler != nil {
265
+ c .config .Handler .Handle (c , m )
266
+ }
260
267
}
261
268
}
262
269
@@ -296,7 +303,7 @@ func (c *Client) RunContext(ctx context.Context) error {
296
303
297
304
// Now that the handshake is pretty much done, we can start listening for
298
305
// messages.
299
- c .startReadLoop (& wg )
306
+ c .startReadLoop (& wg , exiting )
300
307
301
308
// Wait for an error from any goroutine or for the context to time out, then
302
309
// signal we're exiting and wait for the goroutines to exit.
@@ -307,6 +314,7 @@ func (c *Client) RunContext(ctx context.Context) error {
307
314
}
308
315
309
316
close (exiting )
317
+ c .rwc .Close ()
310
318
wg .Wait ()
311
319
312
320
return err
0 commit comments