@@ -3,19 +3,25 @@ package proxyproto
3
3
import (
4
4
"bufio"
5
5
"bytes"
6
+ "io"
7
+ "net"
6
8
"strconv"
7
9
"strings"
8
10
"testing"
11
+ "time"
9
12
)
10
13
11
14
var (
12
15
IPv4AddressesAndPorts = strings .Join ([]string {IP4_ADDR , IP4_ADDR , strconv .Itoa (PORT ), strconv .Itoa (PORT )}, separator )
13
16
IPv4AddressesAndInvalidPorts = strings .Join ([]string {IP4_ADDR , IP4_ADDR , strconv .Itoa (INVALID_PORT ), strconv .Itoa (INVALID_PORT )}, separator )
14
17
IPv6AddressesAndPorts = strings .Join ([]string {IP6_ADDR , IP6_ADDR , strconv .Itoa (PORT ), strconv .Itoa (PORT )}, separator )
18
+ IPv6LongAddressesAndPorts = strings .Join ([]string {IP6_LONG_ADDR , IP6_LONG_ADDR , strconv .Itoa (PORT ), strconv .Itoa (PORT )}, separator )
15
19
16
20
fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /"
17
21
fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /"
18
22
23
+ fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts
24
+
19
25
fixtureUnknown = "PROXY UNKNOWN" + crlf
20
26
fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf
21
27
)
@@ -58,7 +64,7 @@ var invalidParseV1Tests = []struct {
58
64
{
59
65
desc : "incomplete signature TCP4" ,
60
66
reader : newBufioReader ([]byte ("PROXY TCP4 " + IPv4AddressesAndPorts )),
61
- expectedError : ErrLineMustEndWithCrlf ,
67
+ expectedError : ErrCantReadVersion1Header ,
62
68
},
63
69
{
64
70
desc : "TCP6 with IPv4 addresses" ,
@@ -75,13 +81,18 @@ var invalidParseV1Tests = []struct {
75
81
reader : newBufioReader ([]byte ("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf )),
76
82
expectedError : ErrInvalidPortNumber ,
77
83
},
84
+ {
85
+ desc : "header too long" ,
86
+ reader : newBufioReader ([]byte ("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf )),
87
+ expectedError : ErrVersion1HeaderTooLong ,
88
+ },
78
89
}
79
90
80
91
func TestReadV1Invalid (t * testing.T ) {
81
92
for _ , tt := range invalidParseV1Tests {
82
93
t .Run (tt .desc , func (t * testing.T ) {
83
94
if _ , err := Read (tt .reader ); err != tt .expectedError {
84
- t .Fatalf ("expected %s, actual %s " , tt .expectedError , err . Error () )
95
+ t .Fatalf ("expected %s, actual %v " , tt .expectedError , err )
85
96
}
86
97
})
87
98
}
@@ -175,3 +186,150 @@ func TestWriteV1Valid(t *testing.T) {
175
186
})
176
187
}
177
188
}
189
+
190
+ // Tests for parseVersion1 overflow - issue #69.
191
+
192
+ type dataSource struct {
193
+ NBytes int
194
+ NRead int
195
+ }
196
+
197
+ func (ds * dataSource ) Read (b []byte ) (int , error ) {
198
+ if ds .NRead >= ds .NBytes {
199
+ return 0 , io .EOF
200
+ }
201
+ avail := ds .NBytes - ds .NRead
202
+ if len (b ) < avail {
203
+ avail = len (b )
204
+ }
205
+ for i := 0 ; i < avail ; i ++ {
206
+ b [i ] = 0x20
207
+ }
208
+ ds .NRead += avail
209
+ return avail , nil
210
+ }
211
+
212
+ func TestParseVersion1Overflow (t * testing.T ) {
213
+ ds := & dataSource {}
214
+ reader := bufio .NewReader (ds )
215
+ bufSize := reader .Size ()
216
+ ds .NBytes = bufSize * 16
217
+ parseVersion1 (reader )
218
+ if ds .NRead > bufSize {
219
+ t .Fatalf ("read: expected max %d bytes, actual %d\n " , bufSize , ds .NRead )
220
+ }
221
+ }
222
+
223
+ func listen (t * testing.T ) * Listener {
224
+ l , err := net .Listen ("tcp" , "127.0.0.1:0" )
225
+ if err != nil {
226
+ t .Fatalf ("listen: %v" , err )
227
+ }
228
+ return & Listener {Listener : l }
229
+ }
230
+
231
+ func client (t * testing.T , addr , header string , length int , terminate bool , wait time.Duration , done chan struct {}) {
232
+ c , err := net .Dial ("tcp" , addr )
233
+ if err != nil {
234
+ t .Fatalf ("dial: %v" , err )
235
+ }
236
+ defer c .Close ()
237
+
238
+ if terminate && length < 2 {
239
+ length = 2
240
+ }
241
+
242
+ buf := make ([]byte , len (header )+ length )
243
+ copy (buf , []byte (header ))
244
+ for i := 0 ; i < length - 2 ; i ++ {
245
+ buf [i + len (header )] = 0x20
246
+ }
247
+ if terminate {
248
+ copy (buf [len (header )+ length - 2 :], []byte (crlf ))
249
+ }
250
+
251
+ n , err := c .Write (buf )
252
+ if err != nil {
253
+ t .Fatalf ("write: %v" , err )
254
+ }
255
+ if n != len (buf ) {
256
+ t .Fatalf ("write; short write" )
257
+ }
258
+
259
+ time .Sleep (wait )
260
+ close (done )
261
+ }
262
+
263
+ func TestVersion1Overflow (t * testing.T ) {
264
+ done := make (chan struct {})
265
+
266
+ l := listen (t )
267
+ go client (t , l .Addr ().String (), fixtureTCP6V1Overflow , 10240 , true , 10 * time .Second , done )
268
+
269
+ c , err := l .Accept ()
270
+ if err != nil {
271
+ t .Fatalf ("accept: %v" , err )
272
+ }
273
+
274
+ b := []byte {}
275
+ _ , err = c .Read (b )
276
+ if err == nil {
277
+ t .Fatalf ("net.Conn: no error reported for oversized header" )
278
+ }
279
+ }
280
+
281
+ func TestVersion1SlowLoris (t * testing.T ) {
282
+ done := make (chan struct {})
283
+ timeout := make (chan error )
284
+
285
+ l := listen (t )
286
+ go client (t , l .Addr ().String (), fixtureTCP6V1Overflow , 0 , false , 10 * time .Second , done )
287
+
288
+ c , err := l .Accept ()
289
+ if err != nil {
290
+ t .Fatalf ("accept: %v" , err )
291
+ }
292
+
293
+ go func () {
294
+ b := []byte {}
295
+ _ , err = c .Read (b )
296
+ timeout <- err
297
+ }()
298
+
299
+ select {
300
+ case <- done :
301
+ t .Fatalf ("net.Conn: reader still blocked after 10 seconds" )
302
+ case err := <- timeout :
303
+ if err == nil {
304
+ t .Fatalf ("net.Conn: no error reported for incomplete header" )
305
+ }
306
+ }
307
+ }
308
+
309
+ func TestVersion1SlowLorisOverflow (t * testing.T ) {
310
+ done := make (chan struct {})
311
+ timeout := make (chan error )
312
+
313
+ l := listen (t )
314
+ go client (t , l .Addr ().String (), fixtureTCP6V1Overflow , 10240 , false , 10 * time .Second , done )
315
+
316
+ c , err := l .Accept ()
317
+ if err != nil {
318
+ t .Fatalf ("accept: %v" , err )
319
+ }
320
+
321
+ go func () {
322
+ b := []byte {}
323
+ _ , err = c .Read (b )
324
+ timeout <- err
325
+ }()
326
+
327
+ select {
328
+ case <- done :
329
+ t .Fatalf ("net.Conn: reader still blocked after 10 seconds" )
330
+ case err := <- timeout :
331
+ if err == nil {
332
+ t .Fatalf ("net.Conn: no error reported for incomplete and overflowed header" )
333
+ }
334
+ }
335
+ }
0 commit comments