From 8a8a4d7100ce058f461f4b111f97a49fffa4f256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Tue, 15 Oct 2024 12:07:13 +0200 Subject: [PATCH] protocol: limit use of buffered reader Fix bug introduced in #116 where io.MultiReader only reads from buffered reader. Move buffer reader management to readHeader(). Remove Conn.bufReader, and make Conn.reader nil until readHeader() is called. --- protocol.go | 54 ++++++++++++++++++++++------------------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/protocol.go b/protocol.go index 270b90d..170e12b 100644 --- a/protocol.go +++ b/protocol.go @@ -2,6 +2,7 @@ package proxyproto import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -51,7 +52,6 @@ type Conn struct { once sync.Once readErr error conn net.Conn - bufReader *bufio.Reader reader io.Reader header *Header ProxyHeaderPolicy Policy @@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { - // For v1 the header length is at most 108 bytes. - // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 - br := bufio.NewReaderSize(conn, bufSize) - pConn := &Conn{ - bufReader: br, - reader: io.MultiReader(br, conn), - conn: conn, + conn: conn, } for _, opt := range opts { @@ -297,7 +289,23 @@ func (p *Conn) readHeader() error { } } - header, err := Read(p.bufReader) + // For v1 the header length is at most 108 bytes. + // For v2 the header length is at most 52 bytes plus the length of the TLVs. + // We use 256 bytes to be safe. + const bufSize = 256 + br := bufio.NewReaderSize(p.conn, bufSize) + + header, err := Read(br) + + if br.Buffered() != 0 { + buf := make([]byte, br.Buffered()) + if _, err := br.Read(buf); err != nil { + return err // this should never as we read buffered data + } + p.reader = io.MultiReader(bytes.NewReader(buf), p.conn) + } else { + p.reader = p.conn + } // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. Because we retain the readDeadline as part of our @@ -364,26 +372,8 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { return 0, p.readErr } - b := make([]byte, p.bufReader.Buffered()) - if _, err := p.bufReader.Read(b); err != nil { - return 0, err // this should never as we read buffered data - } - - var n int64 - { - nn, err := w.Write(b) - n += int64(nn) - if err != nil { - return n, err - } + if wt, ok := p.reader.(io.WriterTo); ok { + return wt.WriteTo(w) } - { - nn, err := io.Copy(w, p.conn) - n += nn - if err != nil { - return n, err - } - } - - return n, nil + return io.Copy(w, p.reader) }