Skip to content

Commit 3aa7ea9

Browse files
authored
Merge pull request #74 from unmarshal/read_header_timeout
Add support for ReadHeaderTimeout
2 parents 7f48261 + cdc6386 commit 3aa7ea9

File tree

6 files changed

+201
-3
lines changed

6 files changed

+201
-3
lines changed

README.md

+32
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,38 @@ func main() {
119119
}
120120
```
121121

122+
### HTTP Server
123+
```go
124+
package main
125+
126+
import (
127+
"net"
128+
"net/http"
129+
"time"
130+
131+
"github.com/pires/go-proxyproto"
132+
)
133+
134+
func main() {
135+
server := http.Server{
136+
Addr: ":8080",
137+
}
138+
139+
ln, err := net.Listen("tcp", server.Addr)
140+
if err != nil {
141+
panic(err)
142+
}
143+
144+
proxyListener := &proxyproto.Listener{
145+
Listener: ln,
146+
ReadHeaderTimeout: 10 * time.Second,
147+
}
148+
defer proxyListener.Close()
149+
150+
server.Serve(proxyListener)
151+
}
152+
```
153+
122154
## Special notes
123155

124156
### AWS

examples/client/client.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package main
2+
3+
import (
4+
"io"
5+
"log"
6+
"net"
7+
8+
proxyproto "github.com/pires/go-proxyproto"
9+
)
10+
11+
func chkErr(err error) {
12+
if err != nil {
13+
log.Fatalf("Error: %s", err.Error())
14+
}
15+
}
16+
17+
func main() {
18+
// Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
19+
target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876")
20+
chkErr(err)
21+
22+
conn, err := net.DialTCP("tcp", nil, target)
23+
chkErr(err)
24+
25+
defer conn.Close()
26+
27+
// Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
28+
// have two conn's
29+
header := &proxyproto.Header{
30+
Version: 1,
31+
Command: proxyproto.PROXY,
32+
TransportProtocol: proxyproto.TCPv4,
33+
SourceAddr: &net.TCPAddr{
34+
IP: net.ParseIP("10.1.1.1"),
35+
Port: 1000,
36+
},
37+
DestinationAddr: &net.TCPAddr{
38+
IP: net.ParseIP("20.2.2.2"),
39+
Port: 2000,
40+
},
41+
}
42+
// After the connection was created write the proxy headers first
43+
_, err = header.WriteTo(conn)
44+
chkErr(err)
45+
// Then your data... e.g.:
46+
_, err = io.WriteString(conn, "HELO")
47+
chkErr(err)
48+
}

examples/httpserver/httpserver.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package main
2+
3+
import (
4+
"log"
5+
"net"
6+
"net/http"
7+
"time"
8+
9+
"github.com/pires/go-proxyproto"
10+
)
11+
12+
// TODO: add httpclient example
13+
14+
func main() {
15+
server := http.Server{
16+
Addr: ":8080",
17+
ConnState: func(c net.Conn, s http.ConnState) {
18+
if s == http.StateNew {
19+
log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String())
20+
}
21+
},
22+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
log.Printf("[Handler] remote ip %q", r.RemoteAddr)
24+
}),
25+
}
26+
27+
ln, err := net.Listen("tcp", server.Addr)
28+
if err != nil {
29+
panic(err)
30+
}
31+
32+
proxyListener := &proxyproto.Listener{
33+
Listener: ln,
34+
ReadHeaderTimeout: 10 * time.Second,
35+
}
36+
defer proxyListener.Close()
37+
38+
server.Serve(proxyListener)
39+
}

examples/server/server.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package main
2+
3+
import (
4+
"log"
5+
"net"
6+
7+
proxyproto "github.com/pires/go-proxyproto"
8+
)
9+
10+
func main() {
11+
// Create a listener
12+
addr := "localhost:9876"
13+
list, err := net.Listen("tcp", addr)
14+
if err != nil {
15+
log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
16+
}
17+
18+
// Wrap listener in a proxyproto listener
19+
proxyListener := &proxyproto.Listener{Listener: list}
20+
defer proxyListener.Close()
21+
22+
// Wait for a connection and accept it
23+
conn, err := proxyListener.Accept()
24+
defer conn.Close()
25+
26+
// Print connection details
27+
if conn.LocalAddr() == nil {
28+
log.Fatal("couldn't retrieve local address")
29+
}
30+
log.Printf("local address: %q", conn.LocalAddr().String())
31+
32+
if conn.RemoteAddr() == nil {
33+
log.Fatal("couldn't retrieve remote address")
34+
}
35+
log.Printf("remote address: %q", conn.RemoteAddr().String())
36+
}

protocol.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ import (
1313
// If the connection is using the protocol, the RemoteAddr() will return
1414
// the correct client address.
1515
type Listener struct {
16-
Listener net.Listener
17-
Policy PolicyFunc
18-
ValidateHeader Validator
16+
Listener net.Listener
17+
Policy PolicyFunc
18+
ValidateHeader Validator
19+
ReadHeaderTimeout time.Duration
1920
}
2021

2122
// Conn is used to wrap and underlying connection which
@@ -52,6 +53,10 @@ func (p *Listener) Accept() (net.Conn, error) {
5253
return nil, err
5354
}
5455

56+
if d := p.ReadHeaderTimeout; d != 0 {
57+
conn.SetReadDeadline(time.Now().Add(d))
58+
}
59+
5560
proxyHeaderPolicy := USE
5661
if p.Policy != nil {
5762
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())

protocol_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ package proxyproto
66

77
import (
88
"bytes"
9+
"context"
910
"crypto/tls"
1011
"crypto/x509"
1112
"fmt"
1213
"io"
1314
"io/ioutil"
1415
"net"
1516
"testing"
17+
"time"
1618
)
1719

1820
func TestPassthrough(t *testing.T) {
@@ -61,6 +63,42 @@ func TestPassthrough(t *testing.T) {
6163
}
6264
}
6365

66+
func TestReadHeaderTimeout(t *testing.T) {
67+
l, err := net.Listen("tcp", "127.0.0.1:0")
68+
if err != nil {
69+
t.Fatalf("err: %v", err)
70+
}
71+
72+
pl := &Listener{
73+
Listener: l,
74+
ReadHeaderTimeout: 1 * time.Millisecond,
75+
}
76+
77+
ctx, cancel := context.WithCancel(context.Background())
78+
defer cancel()
79+
80+
go func() {
81+
conn, err := net.Dial("tcp", pl.Addr().String())
82+
if err != nil {
83+
t.Fatalf("err: %v", err)
84+
}
85+
defer conn.Close()
86+
87+
<-ctx.Done()
88+
}()
89+
90+
conn, err := pl.Accept()
91+
if err != nil {
92+
t.Fatalf("err: %v", err)
93+
}
94+
defer conn.Close()
95+
96+
// Read blocks forever if there is no ReadHeaderTimeout
97+
recv := make([]byte, 4)
98+
_, err = conn.Read(recv)
99+
100+
}
101+
64102
func TestParse_ipv4(t *testing.T) {
65103
l, err := net.Listen("tcp", "127.0.0.1:0")
66104
if err != nil {

0 commit comments

Comments
 (0)