Skip to content

Commit 5bf07c6

Browse files
authored
Merge pull request #66 from go-irc/exit-on-write-error
* Ensure all write errors cause the client to exit * Replace queuedWriteError with writeErrorChan to fix a race
2 parents 9106b7e + 4e3e991 commit 5bf07c6

File tree

3 files changed

+67
-30
lines changed

3 files changed

+67
-30
lines changed

client.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ func (c *Client) writeCallback(w *Writer, line string) error {
173173
}
174174

175175
_, err := w.writer.Write([]byte(line + "\r\n"))
176+
if err != nil {
177+
c.sendError(err)
178+
}
176179
return err
177180
}
178181

@@ -260,11 +263,7 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) {
260263
func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) {
261264
defer wg.Done()
262265

263-
err := c.Writef("PING :%d", timestamp)
264-
if err != nil {
265-
c.sendError(err)
266-
return
267-
}
266+
c.Writef("PING :%d", timestamp)
268267

269268
timer := time.NewTimer(c.config.PingTimeout)
270269
defer timer.Stop()

client_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,4 +403,15 @@ func TestPingLoop(t *testing.T) {
403403
SendLine("PONG :hello 6\r\n"),
404404
SendLine("PONG :hello 7\r\n"),
405405
})
406+
407+
// Successful ping with write error
408+
runClientTest(t, config, errors.New("test error"), nil, []TestAction{
409+
ExpectLine("PASS :test_pass\r\n"),
410+
ExpectLine("NICK :test_nick\r\n"),
411+
ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"),
412+
// We queue this up a line early because the next write will happen after the delay.
413+
QueueWriteError(errors.New("test error")),
414+
SendLine("001 :hello_world\r\n"),
415+
Delay(25 * time.Millisecond),
416+
})
406417
}

stream_test.go

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,36 @@ func Delay(delay time.Duration) TestAction {
8787
}
8888
}
8989

90+
func QueueReadError(err error) TestAction {
91+
return func(t *testing.T, rw *testReadWriter) {
92+
select {
93+
case rw.readErrorChan <- err:
94+
default:
95+
assert.Fail(t, "Tried to queue a second read error")
96+
}
97+
}
98+
}
99+
100+
func QueueWriteError(err error) TestAction {
101+
return func(t *testing.T, rw *testReadWriter) {
102+
select {
103+
case rw.writeErrorChan <- err:
104+
default:
105+
assert.Fail(t, "Tried to queue a second write error")
106+
}
107+
}
108+
}
109+
90110
type testReadWriter struct {
91-
actions []TestAction
92-
queuedWriteError error
93-
writeChan chan string
94-
queuedReadError error
95-
readChan chan string
96-
readEmptyChan chan struct{}
97-
exiting chan struct{}
98-
clientDone chan struct{}
99-
serverBuffer bytes.Buffer
111+
actions []TestAction
112+
writeErrorChan chan error
113+
writeChan chan string
114+
readErrorChan chan error
115+
readChan chan string
116+
readEmptyChan chan struct{}
117+
exiting chan struct{}
118+
clientDone chan struct{}
119+
serverBuffer bytes.Buffer
100120
}
101121

102122
func (rw *testReadWriter) maybeBroadcastEmpty() {
@@ -109,10 +129,11 @@ func (rw *testReadWriter) maybeBroadcastEmpty() {
109129
}
110130

111131
func (rw *testReadWriter) Read(buf []byte) (int, error) {
112-
if rw.queuedReadError != nil {
113-
err := rw.queuedReadError
114-
rw.queuedReadError = nil
132+
// Check for a read error first
133+
select {
134+
case err := <-rw.readErrorChan:
115135
return 0, err
136+
default:
116137
}
117138

118139
// If there's data left in the buffer, we want to use that first.
@@ -125,10 +146,12 @@ func (rw *testReadWriter) Read(buf []byte) (int, error) {
125146
return s, err
126147
}
127148

128-
// Read from server. We're either waiting for this whole test to
129-
// finish or for data to come in from the server buffer. We expect
130-
// only one read to be happening at once.
149+
// Read from server. We're waiting for this whole test to finish, data to
150+
// come in from the server buffer, or for an error. We expect only one read
151+
// to be happening at once.
131152
select {
153+
case err := <-rw.readErrorChan:
154+
return 0, err
132155
case data := <-rw.readChan:
133156
rw.serverBuffer.WriteString(data)
134157
s, err := rw.serverBuffer.Read(buf)
@@ -143,10 +166,10 @@ func (rw *testReadWriter) Read(buf []byte) (int, error) {
143166
}
144167

145168
func (rw *testReadWriter) Write(buf []byte) (int, error) {
146-
if rw.queuedWriteError != nil {
147-
err := rw.queuedWriteError
148-
rw.queuedWriteError = nil
169+
select {
170+
case err := <-rw.writeErrorChan:
149171
return 0, err
172+
default:
150173
}
151174

152175
// Write to server. We can cheat with this because we know things
@@ -161,12 +184,14 @@ func (rw *testReadWriter) Write(buf []byte) (int, error) {
161184

162185
func newTestReadWriter(actions []TestAction) *testReadWriter {
163186
return &testReadWriter{
164-
actions: actions,
165-
writeChan: make(chan string),
166-
readChan: make(chan string),
167-
readEmptyChan: make(chan struct{}, 1),
168-
exiting: make(chan struct{}),
169-
clientDone: make(chan struct{}),
187+
actions: actions,
188+
writeErrorChan: make(chan error, 1),
189+
writeChan: make(chan string),
190+
readErrorChan: make(chan error, 1),
191+
readChan: make(chan string),
192+
readEmptyChan: make(chan struct{}, 1),
193+
exiting: make(chan struct{}),
194+
clientDone: make(chan struct{}),
170195
}
171196
}
172197

@@ -197,8 +222,10 @@ func runTest(t *testing.T, rw *testReadWriter, actions []TestAction) {
197222

198223
// TODO: Make sure there are no more incoming messages
199224

200-
// Ask everything to shut down and wait for the client to stop.
225+
// Ask everything to shut down
201226
close(rw.exiting)
227+
228+
// Wait for the client to stop
202229
select {
203230
case <-rw.clientDone:
204231
case <-time.After(1 * time.Second):

0 commit comments

Comments
 (0)