@@ -2,6 +2,7 @@ package proxyproto
22
33import (
44 "bufio"
5+ "bytes"
56 "errors"
67 "fmt"
78 "io"
@@ -51,7 +52,7 @@ type Conn struct {
5152 once sync.Once
5253 readErr error
5354 conn net.Conn
54- bufReader * bufio .Reader
55+ reader io .Reader
5556 header * Header
5657 ProxyHeaderPolicy Policy
5758 Validate Validator
@@ -150,14 +151,8 @@ func (p *Listener) Addr() net.Addr {
150151// NewConn is used to wrap a net.Conn that may be speaking
151152// the proxy protocol into a proxyproto.Conn
152153func NewConn (conn net.Conn , opts ... func (* Conn )) * Conn {
153- // For v1 the header length is at most 108 bytes.
154- // For v2 the header length is at most 52 bytes plus the length of the TLVs.
155- // We use 256 bytes to be safe.
156- const bufSize = 256
157-
158154 pConn := & Conn {
159- bufReader : bufio .NewReaderSize (conn , bufSize ),
160- conn : conn ,
155+ conn : conn ,
161156 }
162157
163158 for _ , opt := range opts {
@@ -178,7 +173,7 @@ func (p *Conn) Read(b []byte) (int, error) {
178173 return 0 , p .readErr
179174 }
180175
181- return p .bufReader .Read (b )
176+ return p .reader .Read (b )
182177}
183178
184179// Write wraps original conn.Write
@@ -294,7 +289,26 @@ func (p *Conn) readHeader() error {
294289 }
295290 }
296291
297- header , err := Read (p .bufReader )
292+ // For v1 the header length is at most 108 bytes.
293+ // For v2 the header length is at most 52 bytes plus the length of the TLVs.
294+ // We use 256 bytes to be safe.
295+ const bufSize = 256
296+
297+ bb := bytes .NewBuffer (make ([]byte , 0 , bufSize ))
298+ tr := io .TeeReader (p .conn , bb )
299+ br := bufio .NewReaderSize (tr , bufSize )
300+
301+ header , err := Read (br )
302+
303+ if err == nil {
304+ _ , err = io .CopyN (io .Discard , bb , int64 (header .length ))
305+ }
306+
307+ if bb .Len () == 0 {
308+ p .reader = p .conn
309+ } else {
310+ p .reader = io .MultiReader (bb , p .conn )
311+ }
298312
299313 // If the connection's readHeaderTimeout is more than 0, undo the change to the
300314 // deadline that we made above. Because we retain the readDeadline as part of our
@@ -360,5 +374,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
360374 if p .readErr != nil {
361375 return 0 , p .readErr
362376 }
363- return p . bufReader . WriteTo ( w )
377+ return io . Copy ( w , p . reader )
364378}
0 commit comments