Skip to content

Commit 6fd0f86

Browse files
committed
Allow separate specification of http/dial Hosts
Enable the Host: header to be customized in a Dialer.Dial call. This is needed to implement websocket proxies.
1 parent b2fa8f6 commit 6fd0f86

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

client.go

+16
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,22 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
228228
requestHeader = h
229229
}
230230

231+
if len(requestHeader["Host"]) > 0 {
232+
// This can be used to supply a Host: header which is different from
233+
// the dial address.
234+
u.Host = requestHeader.Get("Host")
235+
236+
// Drop "Host" header
237+
h := http.Header{}
238+
for k, v := range requestHeader {
239+
if k == "Host" {
240+
continue
241+
}
242+
h[k] = v
243+
}
244+
requestHeader = h
245+
}
246+
231247
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
232248

233249
if err != nil {

client_server_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,36 @@ func TestRespOnBadHandshake(t *testing.T) {
288288
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
289289
}
290290
}
291+
292+
// If the Host header is specified in `Dial()`, the server must receive it as
293+
// the `Host:` header.
294+
func TestHostHeader(t *testing.T) {
295+
s := newServer(t)
296+
defer s.Close()
297+
298+
specifiedHost := make(chan string, 1)
299+
origHandler := s.Server.Config.Handler
300+
301+
// Capture the request Host header.
302+
s.Server.Config.Handler = http.HandlerFunc(
303+
func(w http.ResponseWriter, r *http.Request) {
304+
specifiedHost <- r.Host
305+
origHandler.ServeHTTP(w, r)
306+
})
307+
308+
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
309+
if err != nil {
310+
t.Fatalf("Dial: %v", err)
311+
}
312+
defer ws.Close()
313+
314+
if resp.StatusCode != http.StatusSwitchingProtocols {
315+
t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
316+
}
317+
318+
if gotHost := <-specifiedHost; gotHost != "testhost" {
319+
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
320+
}
321+
322+
sendRecv(t, ws)
323+
}

0 commit comments

Comments
 (0)