Skip to content

Commit 61a733d

Browse files
authored
Merge pull request #159 from tooolbox/153-auth-refactor
Refactored buildHTTP to fix Auth with GetBody.
2 parents c909caa + 9593686 commit 61a733d

File tree

1 file changed

+106
-70
lines changed

1 file changed

+106
-70
lines changed

client/request.go

Lines changed: 106 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"bytes"
1919
"fmt"
2020
"io"
21-
"io/ioutil"
2221
"log"
2322
"mime/multipart"
2423
"net/http"
@@ -42,6 +41,7 @@ func newRequest(method, pathPattern string, writer runtime.ClientRequestWriter)
4241
header: make(http.Header),
4342
query: make(url.Values),
4443
timeout: DefaultTimeout,
44+
getBody: getRequestBuffer,
4545
}, nil
4646
}
4747

@@ -67,6 +67,8 @@ type request struct {
6767
payload interface{}
6868
timeout time.Duration
6969
buf *bytes.Buffer
70+
71+
getBody func(r *request) []byte
7072
}
7173

7274
var (
@@ -93,60 +95,34 @@ func (r *request) buildHTTP(mediaType, basePath string, producers map[string]run
9395
return nil, err
9496
}
9597

96-
if auth != nil {
97-
if err := auth.AuthenticateRequest(r, registry); err != nil {
98-
return nil, err
99-
}
100-
}
101-
102-
// create http request
103-
var reinstateSlash bool
104-
if r.pathPattern != "" && r.pathPattern != "/" && r.pathPattern[len(r.pathPattern)-1] == '/' {
105-
reinstateSlash = true
106-
}
107-
urlPath := path.Join(basePath, r.pathPattern)
108-
for k, v := range r.pathParams {
109-
urlPath = strings.Replace(urlPath, "{"+k+"}", url.PathEscape(v), -1)
110-
}
111-
if reinstateSlash {
112-
urlPath = urlPath + "/"
113-
}
114-
115-
var body io.ReadCloser
98+
// Our body must be an io.Reader.
99+
// When we create the http.Request, if we pass it a
100+
// bytes.Buffer then it will wrap it in an io.ReadCloser
101+
// and set the content length automatically.
102+
var body io.Reader
116103
var pr *io.PipeReader
117104
var pw *io.PipeWriter
118105

119106
r.buf = bytes.NewBuffer(nil)
120107
if r.payload != nil || len(r.formFields) > 0 || len(r.fileFields) > 0 {
121-
body = ioutil.NopCloser(r.buf)
108+
body = r.buf
122109
if r.isMultipart(mediaType) {
123110
pr, pw = io.Pipe()
124111
body = pr
125112
}
126113
}
127-
req, err := http.NewRequest(r.method, urlPath, body)
128-
129-
if err != nil {
130-
return nil, err
131-
}
132-
133-
req.URL.RawQuery = r.query.Encode()
134-
req.Header = r.header
135114

136115
// check if this is a form type request
137116
if len(r.formFields) > 0 || len(r.fileFields) > 0 {
138117
if !r.isMultipart(mediaType) {
139-
req.Header.Set(runtime.HeaderContentType, mediaType)
118+
r.header.Set(runtime.HeaderContentType, mediaType)
140119
formString := r.formFields.Encode()
141-
// set content length before writing to the buffer
142-
req.ContentLength = int64(len(formString))
143-
// write the form values as the body
144120
r.buf.WriteString(formString)
145-
return req, nil
121+
goto DoneChoosingBodySource
146122
}
147123

148124
mp := multipart.NewWriter(pw)
149-
req.Header.Set(runtime.HeaderContentType, mangleContentType(mediaType, mp.Boundary()))
125+
r.header.Set(runtime.HeaderContentType, mangleContentType(mediaType, mp.Boundary()))
150126

151127
go func() {
152128
defer func() {
@@ -184,65 +160,121 @@ func (r *request) buildHTTP(mediaType, basePath string, producers map[string]run
184160
}
185161

186162
}()
187-
return req, nil
188163

164+
goto DoneChoosingBodySource
189165
}
190166

191167
// if there is payload, use the producer to write the payload, and then
192168
// set the header to the content-type appropriate for the payload produced
193169
if r.payload != nil {
194170
// TODO: infer most appropriate content type based on the producer used,
195171
// and the `consumers` section of the spec/operation
196-
req.Header.Set(runtime.HeaderContentType, mediaType)
172+
r.header.Set(runtime.HeaderContentType, mediaType)
197173
if rdr, ok := r.payload.(io.ReadCloser); ok {
198-
req.Body = rdr
199-
200-
return req, nil
174+
body = rdr
175+
goto DoneChoosingBodySource
201176
}
202177

203178
if rdr, ok := r.payload.(io.Reader); ok {
204-
req.Body = ioutil.NopCloser(rdr)
179+
body = rdr
180+
goto DoneChoosingBodySource
181+
}
205182

206-
return req, nil
183+
producer := producers[mediaType]
184+
if err := producer.Produce(r.buf, r.payload); err != nil {
185+
return nil, err
207186
}
187+
}
208188

209-
req.GetBody = func() (io.ReadCloser, error) {
210-
var b bytes.Buffer
211-
producer := producers[mediaType]
212-
if err := producer.Produce(&b, r.payload); err != nil {
213-
return nil, err
214-
}
189+
DoneChoosingBodySource:
215190

216-
return ioutil.NopCloser(&b), nil
217-
}
191+
if runtime.CanHaveBody(r.method) && body == nil && r.header.Get(runtime.HeaderContentType) == "" {
192+
r.header.Set(runtime.HeaderContentType, mediaType)
193+
}
218194

219-
// set the content length of the request or else a chunked transfer is
220-
// declared, and this corrupts outgoing JSON payloads. the content's
221-
// length must be set prior to the body being written per the spec at
222-
// https://golang.org/pkg/net/http
195+
if auth != nil {
196+
197+
// If we're not using r.buf as our http.Request's body,
198+
// either the payload is an io.Reader or io.ReadCloser,
199+
// or we're doing a multipart form/file.
223200
//
224-
// If Body is present, Content-Length is <= 0 and TransferEncoding
225-
// hasn't been set to "identity", Write adds
226-
// "Transfer-Encoding: chunked" to the header. Body is closed
227-
// after it is sent.
201+
// In those cases, if the AuthenticateRequest call asks for the body,
202+
// we must read it into a buffer and provide that, then use that buffer
203+
// as the body of our http.Request.
228204
//
229-
// to that end a temporary buffer, b, is created to produce the payload
230-
// body, and then its size is used to set the request's content length
231-
var b bytes.Buffer
232-
producer := producers[mediaType]
233-
if err := producer.Produce(&b, r.payload); err != nil {
234-
return nil, err
205+
// This is done in-line with the GetBody() request rather than ahead
206+
// of time, because there's no way to know if the AuthenticateRequest
207+
// will even ask for the body of the request.
208+
//
209+
// If for some reason the copy fails, there's no way to return that
210+
// error to the GetBody() call, so return it afterwards.
211+
//
212+
// An error from the copy action is prioritized over any error
213+
// from the AuthenticateRequest call, because the mis-read
214+
// body may have interfered with the auth.
215+
//
216+
var copyErr error
217+
if buf, ok := body.(*bytes.Buffer); body != nil && (!ok || buf != r.buf) {
218+
219+
var copied bool
220+
r.getBody = func(r *request) []byte {
221+
222+
if copied {
223+
return getRequestBuffer(r)
224+
}
225+
226+
defer func() {
227+
copied = true
228+
}()
229+
230+
if _, copyErr = io.Copy(r.buf, body); copyErr != nil {
231+
return nil
232+
}
233+
234+
if closer, ok := body.(io.ReadCloser); ok {
235+
if copyErr = closer.Close(); copyErr != nil {
236+
return nil
237+
}
238+
}
239+
240+
body = r.buf
241+
return getRequestBuffer(r)
242+
}
235243
}
236-
req.ContentLength = int64(b.Len())
237-
if _, err := r.buf.Write(b.Bytes()); err != nil {
238-
return nil, err
244+
245+
authErr := auth.AuthenticateRequest(r, registry)
246+
247+
if copyErr != nil {
248+
return nil, fmt.Errorf("error retrieving the response body: %v", copyErr)
249+
}
250+
251+
if authErr != nil {
252+
return nil, authErr
239253
}
254+
255+
}
256+
257+
// create http request
258+
var reinstateSlash bool
259+
if r.pathPattern != "" && r.pathPattern != "/" && r.pathPattern[len(r.pathPattern)-1] == '/' {
260+
reinstateSlash = true
261+
}
262+
urlPath := path.Join(basePath, r.pathPattern)
263+
for k, v := range r.pathParams {
264+
urlPath = strings.Replace(urlPath, "{"+k+"}", url.PathEscape(v), -1)
265+
}
266+
if reinstateSlash {
267+
urlPath = urlPath + "/"
240268
}
241269

242-
if runtime.CanHaveBody(req.Method) && req.Body == nil && req.Header.Get(runtime.HeaderContentType) == "" {
243-
req.Header.Set(runtime.HeaderContentType, mediaType)
270+
req, err := http.NewRequest(r.method, urlPath, body)
271+
if err != nil {
272+
return nil, err
244273
}
245274

275+
req.URL.RawQuery = r.query.Encode()
276+
req.Header = r.header
277+
246278
return req, nil
247279
}
248280

@@ -266,6 +298,10 @@ func (r *request) GetPath() string {
266298
}
267299

268300
func (r *request) GetBody() []byte {
301+
return r.getBody(r)
302+
}
303+
304+
func getRequestBuffer(r *request) []byte {
269305
if r.buf == nil {
270306
return nil
271307
}

0 commit comments

Comments
 (0)