Skip to content

Commit 22bc614

Browse files
authored
Merge pull request #61 from pires/feature/proxy_unknown
Support v1 UNKNOWN and v2 UNSPEC when command is LOCAL
2 parents adbbabe + 4b450a5 commit 22bc614

File tree

7 files changed

+412
-282
lines changed

7 files changed

+412
-282
lines changed

addr_proto.go

-9
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,6 @@ const (
1313
UnixDatagram AddressFamilyAndProtocol = '\x32'
1414
)
1515

16-
var supportedTransportProtocol = map[AddressFamilyAndProtocol]bool{
17-
TCPv4: true,
18-
UDPv4: true,
19-
TCPv6: true,
20-
UDPv6: true,
21-
UnixStream: true,
22-
UnixDatagram: true,
23-
}
24-
2516
// IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise.
2617
func (ap AddressFamilyAndProtocol) IsIPv4() bool {
2718
return 0x10 == ap&0xF0

header.go

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ var (
1616
SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'}
1717
SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}
1818

19+
ErrLineMustEndWithCrlf = errors.New("proxyproto: header is invalid, must end with \\r\\n")
1920
ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
2021
ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol")
2122
ErrCantReadLength = errors.New("proxyproto: can't read length")

protocol.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ func (p *Conn) RemoteAddr() net.Addr {
157157
// Raw returns the underlying connection which can be casted to
158158
// a concrete type, allowing access to specialized functions.
159159
//
160-
// Use this ONLY if you know exactly what you are doing.
160+
// Use this ONLY if you know exactly what you are doing.
161161
func (p *Conn) Raw() net.Conn {
162162
return p.conn
163163
}
164164

165165
// TCPConn returns the underlying TCP connection,
166166
// allowing access to specialized functions.
167167
//
168-
// Use this ONLY if you know exactly what you are doing.
168+
// Use this ONLY if you know exactly what you are doing.
169169
func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
170170
conn, ok = p.conn.(*net.TCPConn)
171171
return
@@ -174,7 +174,7 @@ func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
174174
// UnixConn returns the underlying Unix socket connection,
175175
// allowing access to specialized functions.
176176
//
177-
// Use this ONLY if you know exactly what you are doing.
177+
// Use this ONLY if you know exactly what you are doing.
178178
func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
179179
conn, ok = p.conn.(*net.UnixConn)
180180
return
@@ -183,7 +183,7 @@ func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
183183
// UDPConn returns the underlying UDP connection,
184184
// allowing access to specialized functions.
185185
//
186-
// Use this ONLY if you know exactly what you are doing.
186+
// Use this ONLY if you know exactly what you are doing.
187187
func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
188188
conn, ok = p.conn.(*net.UDPConn)
189189
return

v1.go

+59-38
Original file line numberDiff line numberDiff line change
@@ -22,52 +22,73 @@ func initVersion1() *Header {
2222
}
2323

2424
func parseVersion1(reader *bufio.Reader) (*Header, error) {
25-
// Make sure we have a v1 header
25+
// Read until LF shows up, otherwise fail.
26+
// At this point, can't be sure CR precedes LF which will be validated next.
2627
line, err := reader.ReadString('\n')
28+
if err != nil {
29+
return nil, ErrLineMustEndWithCrlf
30+
}
2731
if !strings.HasSuffix(line, crlf) {
28-
return nil, ErrCantReadProtocolVersionAndCommand
32+
return nil, ErrLineMustEndWithCrlf
2933
}
34+
// Check full signature.
3035
tokens := strings.Split(line[:len(line)-2], separator)
31-
if len(tokens) < 6 {
32-
return nil, ErrCantReadProtocolVersionAndCommand
36+
transportProtocol := UNSPEC // doesn't exist in v1 but fits UNKNOWN.
37+
if len(tokens) > 0 {
38+
// Read address family and protocol
39+
switch tokens[1] {
40+
case "TCP4":
41+
transportProtocol = TCPv4
42+
case "TCP6":
43+
transportProtocol = TCPv6
44+
case "UNKNOWN": // no-op as UNSPEC is set already
45+
default:
46+
return nil, ErrCantReadAddressFamilyAndProtocol
47+
}
48+
49+
// Expect 6 tokens only when UNKNOWN is not present.
50+
if !transportProtocol.IsUnspec() && len(tokens) < 6 {
51+
return nil, ErrCantReadAddressFamilyAndProtocol
52+
}
3353
}
3454

55+
// Allocation only happens when a signature is found.
3556
header := initVersion1()
57+
// If UNKNOWN is present, set Command to LOCAL.
58+
// Command is not present in v1 but set it for other parts of
59+
// this library to rely on it for determining connection details.
60+
header.Command = LOCAL
3661

37-
// Read address family and protocol
38-
switch tokens[1] {
39-
case "TCP4":
40-
header.TransportProtocol = TCPv4
41-
case "TCP6":
42-
header.TransportProtocol = TCPv6
43-
default:
44-
header.TransportProtocol = UNSPEC
45-
}
62+
// Transport protocol has been processed already.
63+
header.TransportProtocol = transportProtocol
4664

47-
// Read addresses and ports
48-
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
49-
if err != nil {
50-
return nil, err
51-
}
52-
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
53-
if err != nil {
54-
return nil, err
55-
}
56-
sourcePort, err := parseV1PortNumber(tokens[4])
57-
if err != nil {
58-
return nil, err
59-
}
60-
destPort, err := parseV1PortNumber(tokens[5])
61-
if err != nil {
62-
return nil, err
63-
}
64-
header.SourceAddr = &net.TCPAddr{
65-
IP: sourceIP,
66-
Port: sourcePort,
67-
}
68-
header.DestinationAddr = &net.TCPAddr{
69-
IP: destIP,
70-
Port: destPort,
65+
// Only process further if UNKNOWN is not present.
66+
if header.TransportProtocol != UNSPEC {
67+
// Read addresses and ports
68+
sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2])
69+
if err != nil {
70+
return nil, err
71+
}
72+
destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3])
73+
if err != nil {
74+
return nil, err
75+
}
76+
sourcePort, err := parseV1PortNumber(tokens[4])
77+
if err != nil {
78+
return nil, err
79+
}
80+
destPort, err := parseV1PortNumber(tokens[5])
81+
if err != nil {
82+
return nil, err
83+
}
84+
header.SourceAddr = &net.TCPAddr{
85+
IP: sourceIP,
86+
Port: sourcePort,
87+
}
88+
header.DestinationAddr = &net.TCPAddr{
89+
IP: destIP,
90+
Port: destPort,
91+
}
7192
}
7293

7394
return header, nil
@@ -84,7 +105,7 @@ func (header *Header) formatVersion1() ([]byte, error) {
84105
proto = "TCP6"
85106
default:
86107
// Unknown connection (short form)
87-
return []byte("PROXY UNKNOWN\r\n"), nil
108+
return []byte("PROXY UNKNOWN" + crlf), nil
88109
}
89110

90111
sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr)

v1_test.go

+101-49
Original file line numberDiff line numberDiff line change
@@ -9,63 +9,88 @@ import (
99
)
1010

1111
var (
12-
TCP4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
13-
TCP4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
14-
TCP6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
12+
IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
13+
IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator)
14+
IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator)
1515

16-
fixtureTCP4V1 = "PROXY TCP4 " + TCP4AddressesAndPorts + crlf + "GET /"
17-
fixtureTCP6V1 = "PROXY TCP6 " + TCP6AddressesAndPorts + crlf + "GET /"
16+
fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /"
17+
fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /"
18+
19+
fixtureUnknown = "PROXY UNKNOWN" + crlf
20+
fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf
1821
)
1922

2023
var invalidParseV1Tests = []struct {
24+
desc string
2125
reader *bufio.Reader
2226
expectedError error
2327
}{
2428
{
25-
newBufioReader([]byte("PROX")),
26-
ErrNoProxyProtocol,
29+
desc: "no signature",
30+
reader: newBufioReader([]byte(NO_PROTOCOL)),
31+
expectedError: ErrNoProxyProtocol,
32+
},
33+
{
34+
desc: "prox",
35+
reader: newBufioReader([]byte("PROX")),
36+
expectedError: ErrNoProxyProtocol,
37+
},
38+
{
39+
desc: "proxy lf",
40+
reader: newBufioReader([]byte("PROXY \n")),
41+
expectedError: ErrLineMustEndWithCrlf,
2742
},
2843
{
29-
newBufioReader([]byte(NO_PROTOCOL)),
30-
ErrNoProxyProtocol,
44+
desc: "proxy crlf",
45+
reader: newBufioReader([]byte("PROXY " + crlf)),
46+
expectedError: ErrCantReadAddressFamilyAndProtocol,
3147
},
3248
{
33-
newBufioReader([]byte("PROXY \r\n")),
34-
ErrCantReadProtocolVersionAndCommand,
49+
desc: "proxy something crlf",
50+
reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)),
51+
expectedError: ErrCantReadAddressFamilyAndProtocol,
3552
},
3653
{
37-
newBufioReader([]byte("PROXY TCP4 " + TCP4AddressesAndPorts)),
38-
ErrCantReadProtocolVersionAndCommand,
54+
desc: "incomplete signature TCP4",
55+
reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)),
56+
expectedError: ErrLineMustEndWithCrlf,
3957
},
4058
{
41-
newBufioReader([]byte("PROXY TCP6 " + TCP4AddressesAndPorts + crlf)),
42-
ErrInvalidAddress,
59+
desc: "TCP6 with IPv4 addresses",
60+
reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)),
61+
expectedError: ErrInvalidAddress,
4362
},
4463
{
45-
newBufioReader([]byte("PROXY TCP4 " + TCP6AddressesAndPorts + crlf)),
46-
ErrInvalidAddress,
64+
desc: "TCP4 with IPv6 addresses",
65+
reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)),
66+
expectedError: ErrInvalidAddress,
4767
},
48-
// PROXY TCP IPv4
49-
{newBufioReader([]byte("PROXY TCP4 " + TCP4AddressesAndInvalidPorts + crlf)),
50-
ErrInvalidPortNumber,
68+
{
69+
desc: "TCP4 with invalid port",
70+
reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)),
71+
expectedError: ErrInvalidPortNumber,
5172
},
5273
}
5374

5475
func TestReadV1Invalid(t *testing.T) {
5576
for _, tt := range invalidParseV1Tests {
56-
if _, err := Read(tt.reader); err != tt.expectedError {
57-
t.Fatalf("TestReadV1Invalid: expected %s, actual %s", tt.expectedError, err.Error())
58-
}
77+
t.Run(tt.desc, func(t *testing.T) {
78+
if _, err := Read(tt.reader); err != tt.expectedError {
79+
t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error())
80+
}
81+
})
5982
}
6083
}
6184

6285
var validParseAndWriteV1Tests = []struct {
86+
desc string
6387
reader *bufio.Reader
6488
expectedHeader *Header
6589
}{
6690
{
67-
bufio.NewReader(strings.NewReader(fixtureTCP4V1)),
68-
&Header{
91+
desc: "TCP4",
92+
reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)),
93+
expectedHeader: &Header{
6994
Version: 1,
7095
Command: PROXY,
7196
TransportProtocol: TCPv4,
@@ -74,47 +99,74 @@ var validParseAndWriteV1Tests = []struct {
7499
},
75100
},
76101
{
77-
bufio.NewReader(strings.NewReader(fixtureTCP6V1)),
78-
&Header{
102+
desc: "TCP6",
103+
reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)),
104+
expectedHeader: &Header{
79105
Version: 1,
80106
Command: PROXY,
81107
TransportProtocol: TCPv6,
82108
SourceAddr: v6addr,
83109
DestinationAddr: v6addr,
84110
},
85111
},
112+
{
113+
desc: "unknown",
114+
reader: bufio.NewReader(strings.NewReader(fixtureUnknown)),
115+
expectedHeader: &Header{
116+
Version: 1,
117+
Command: PROXY,
118+
TransportProtocol: UNSPEC,
119+
SourceAddr: nil,
120+
DestinationAddr: nil,
121+
},
122+
},
123+
{
124+
desc: "unknown with addresses and ports",
125+
reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)),
126+
expectedHeader: &Header{
127+
Version: 1,
128+
Command: PROXY,
129+
TransportProtocol: UNSPEC,
130+
SourceAddr: nil,
131+
DestinationAddr: nil,
132+
},
133+
},
86134
}
87135

88136
func TestParseV1Valid(t *testing.T) {
89137
for _, tt := range validParseAndWriteV1Tests {
90-
header, err := Read(tt.reader)
91-
if err != nil {
92-
t.Fatal("TestParseV1Valid: unexpected error", err.Error())
93-
}
94-
if !header.EqualsTo(tt.expectedHeader) {
95-
t.Fatalf("TestParseV1Valid: expected %#v, actual %#v", tt.expectedHeader, header)
96-
}
138+
t.Run(tt.desc, func(t *testing.T) {
139+
header, err := Read(tt.reader)
140+
if err != nil {
141+
t.Fatal("unexpected error", err.Error())
142+
}
143+
if !header.EqualsTo(tt.expectedHeader) {
144+
t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header)
145+
}
146+
})
97147
}
98148
}
99149

100150
func TestWriteV1Valid(t *testing.T) {
101151
for _, tt := range validParseAndWriteV1Tests {
102-
var b bytes.Buffer
103-
w := bufio.NewWriter(&b)
104-
if _, err := tt.expectedHeader.WriteTo(w); err != nil {
105-
t.Fatal("TestWriteV1Valid: Unexpected error ", err)
106-
}
107-
w.Flush()
152+
t.Run(tt.desc, func(t *testing.T) {
153+
var b bytes.Buffer
154+
w := bufio.NewWriter(&b)
155+
if _, err := tt.expectedHeader.WriteTo(w); err != nil {
156+
t.Fatal("unexpected error ", err)
157+
}
158+
w.Flush()
108159

109-
// Read written bytes to validate written header
110-
r := bufio.NewReader(&b)
111-
newHeader, err := Read(r)
112-
if err != nil {
113-
t.Fatal("TestWriteV1Valid: Unexpected error ", err)
114-
}
160+
// Read written bytes to validate written header
161+
r := bufio.NewReader(&b)
162+
newHeader, err := Read(r)
163+
if err != nil {
164+
t.Fatal("unexpected error ", err)
165+
}
115166

116-
if !newHeader.EqualsTo(tt.expectedHeader) {
117-
t.Fatalf("TestWriteV1Valid: expected %#v, actual %#v", tt.expectedHeader, newHeader)
118-
}
167+
if !newHeader.EqualsTo(tt.expectedHeader) {
168+
t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader)
169+
}
170+
})
119171
}
120172
}

0 commit comments

Comments
 (0)