diff --git a/protocol.go b/protocol.go index 7bbb8e5..6e4fdd3 100644 --- a/protocol.go +++ b/protocol.go @@ -44,6 +44,7 @@ type Conn struct { readErr error conn net.Conn bufReader *bufio.Reader + reader io.Reader header *Header ProxyHeaderPolicy Policy Validate Validator @@ -138,9 +139,11 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { // For v2 the header length is at most 52 bytes plus the length of the TLVs. // We use 256 bytes to be safe. const bufSize = 256 + br := bufio.NewReaderSize(conn, bufSize) pConn := &Conn{ - bufReader: bufio.NewReaderSize(conn, bufSize), + bufReader: br, + reader: io.MultiReader(br, conn), conn: conn, } @@ -162,7 +165,7 @@ func (p *Conn) Read(b []byte) (int, error) { return 0, p.readErr } - return p.bufReader.Read(b) + return p.reader.Read(b) } // Write wraps original conn.Write @@ -344,5 +347,27 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - return p.bufReader.WriteTo(w) + + b := make([]byte, p.bufReader.Buffered()) + if _, err := p.bufReader.Read(b); err != nil { + return 0, err // this should never as we read buffered data + } + + var n int64 + { + nn, err := w.Write(b) + n += int64(nn) + if err != nil { + return n, err + } + } + { + nn, err := io.Copy(w, p.conn) + n += nn + if err != nil { + return n, err + } + } + + return n, nil }