Skip to content

Commit c60eac6

Browse files
umegbewejuliusmh
authored andcommitted
fix: client prompt return on context cancellation (prometheus#1729)
* fix: client prompt return on context cancellation Signed-off-by: Umegbewe Nwebedu <[email protected]> * test: add context cancellation unit test Signed-off-by: Umegbewe Nwebedu <[email protected]> * fix/rid unused package Signed-off-by: Umegbewe Nwebedu <[email protected]> * fix/lint and formatting Signed-off-by: Umegbewe Nwebedu <[email protected]> --------- Signed-off-by: Umegbewe Nwebedu <[email protected]>
1 parent 88c5272 commit c60eac6

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

api/client.go

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"bytes"
1919
"context"
2020
"errors"
21-
"io"
2221
"net"
2322
"net/http"
2423
"net/url"
@@ -132,36 +131,26 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response,
132131
req = req.WithContext(ctx)
133132
}
134133
resp, err := c.client.Do(req)
135-
defer func() {
136-
if resp != nil {
137-
_, _ = io.Copy(io.Discard, resp.Body)
138-
_ = resp.Body.Close()
139-
}
140-
}()
141-
142134
if err != nil {
143135
return nil, nil, err
144136
}
145137

146138
var body []byte
147-
done := make(chan struct{})
139+
done := make(chan error, 1)
148140
go func() {
149141
var buf bytes.Buffer
150-
// TODO(bwplotka): Add LimitReader for too long err messages (e.g. limit by 1KB)
151-
_, err = buf.ReadFrom(resp.Body)
142+
_, err := buf.ReadFrom(resp.Body)
152143
body = buf.Bytes()
153-
close(done)
144+
done <- err
154145
}()
155146

156147
select {
157148
case <-ctx.Done():
149+
resp.Body.Close()
158150
<-done
159-
err = resp.Body.Close()
160-
if err == nil {
161-
err = ctx.Err()
162-
}
163-
case <-done:
151+
return resp, nil, ctx.Err()
152+
case err = <-done:
153+
resp.Body.Close()
154+
return resp, body, err
164155
}
165-
166-
return resp, body, err
167156
}

api/client_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ package api
1616
import (
1717
"bytes"
1818
"context"
19+
"errors"
1920
"fmt"
2021
"net/http"
2122
"net/http/httptest"
2223
"net/url"
2324
"testing"
25+
"time"
2426
)
2527

2628
func TestConfig(t *testing.T) {
@@ -116,6 +118,52 @@ func TestClientURL(t *testing.T) {
116118
}
117119
}
118120

121+
func TestDoContextCancellation(t *testing.T) {
122+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123+
_, _ = w.Write([]byte("partial"))
124+
if f, ok := w.(http.Flusher); ok {
125+
f.Flush()
126+
}
127+
128+
<-r.Context().Done()
129+
}))
130+
131+
defer ts.Close()
132+
133+
client, err := NewClient(Config{
134+
Address: ts.URL,
135+
})
136+
if err != nil {
137+
t.Fatalf("failed to create client: %v", err)
138+
}
139+
140+
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
141+
if err != nil {
142+
t.Fatalf("failed to create request: %v", err)
143+
}
144+
145+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
146+
defer cancel()
147+
148+
start := time.Now()
149+
resp, body, err := client.Do(ctx, req)
150+
elapsed := time.Since(start)
151+
152+
if !errors.Is(err, context.DeadlineExceeded) {
153+
t.Errorf("expected error %v, got: %v", context.DeadlineExceeded, err)
154+
}
155+
if body != nil {
156+
t.Errorf("expected no body due to cancellation, got: %q", string(body))
157+
}
158+
if elapsed > 200*time.Millisecond {
159+
t.Errorf("Do did not return promptly on cancellation: took %v", elapsed)
160+
}
161+
162+
if resp != nil && resp.Body != nil {
163+
resp.Body.Close()
164+
}
165+
}
166+
119167
// Serve any http request with a response of N KB of spaces.
120168
type serveSpaces struct {
121169
sizeKB int

0 commit comments

Comments
 (0)