diff --git a/protocol.go b/protocol.go index 6e4fdd3..270b90d 100644 --- a/protocol.go +++ b/protocol.go @@ -2,6 +2,8 @@ package proxyproto import ( "bufio" + "errors" + "fmt" "io" "net" "sync" @@ -9,11 +11,17 @@ import ( "time" ) -// DefaultReadHeaderTimeout is how long header processing waits for header to -// be read from the wire, if Listener.ReaderHeaderTimeout is not set. -// It's kept as a global variable so to make it easier to find and override, -// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" -var DefaultReadHeaderTimeout = 10 * time.Second +var ( + // DefaultReadHeaderTimeout is how long header processing waits for header to + // be read from the wire, if Listener.ReaderHeaderTimeout is not set. + // It's kept as a global variable so to make it easier to find and override, + // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" + DefaultReadHeaderTimeout = 10 * time.Second + + // ErrInvalidUpstream should be returned when an upstream connection address + // is not trusted, and therefore is invalid. + ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information") +) // Listener is used to wrap an underlying listener, // whose connections may be using the HAProxy Proxy Protocol. @@ -73,53 +81,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) { } } -// Accept waits for and returns the next connection to the listener. +// Accept waits for and returns the next valid connection to the listener. func (p *Listener) Accept() (net.Conn, error) { - // Get the underlying connection - conn, err := p.Listener.Accept() - if err != nil { - return nil, err - } - - proxyHeaderPolicy := USE - if p.Policy != nil && p.ConnPolicy != nil { - panic("only one of policy or connpolicy must be provided.") - } - if p.Policy != nil || p.ConnPolicy != nil { - if p.Policy != nil { - proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) - } else { - proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ - Upstream: conn.RemoteAddr(), - Downstream: conn.LocalAddr(), - }) - } + for { + // Get the underlying connection + conn, err := p.Listener.Accept() if err != nil { - // can't decide the policy, we can't accept the connection - conn.Close() return nil, err } - // Handle a connection as a regular one - if proxyHeaderPolicy == SKIP { - return conn, nil + + proxyHeaderPolicy := USE + if p.Policy != nil && p.ConnPolicy != nil { + panic("only one of policy or connpolicy must be provided.") } - } + if p.Policy != nil || p.ConnPolicy != nil { + if p.Policy != nil { + proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) + } else { + proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ + Upstream: conn.RemoteAddr(), + Downstream: conn.LocalAddr(), + }) + } + if err != nil { + // can't decide the policy, we can't accept the connection + conn.Close() - newConn := NewConn( - conn, - WithPolicy(proxyHeaderPolicy), - ValidateHeader(p.ValidateHeader), - ) + if errors.Is(err, ErrInvalidUpstream) { + // keep listening for other connections + continue + } - // If the ReadHeaderTimeout for the listener is unset, use the default timeout. - if p.ReadHeaderTimeout == 0 { - p.ReadHeaderTimeout = DefaultReadHeaderTimeout - } + return nil, err + } + // Handle a connection as a regular one + if proxyHeaderPolicy == SKIP { + return conn, nil + } + } - // Set the readHeaderTimeout of the new conn to the value of the listener - newConn.readHeaderTimeout = p.ReadHeaderTimeout + newConn := NewConn( + conn, + WithPolicy(proxyHeaderPolicy), + ValidateHeader(p.ValidateHeader), + ) - return newConn, nil + // If the ReadHeaderTimeout for the listener is unset, use the default timeout. + if p.ReadHeaderTimeout == 0 { + p.ReadHeaderTimeout = DefaultReadHeaderTimeout + } + + // Set the readHeaderTimeout of the new conn to the value of the listener + newConn.readHeaderTimeout = p.ReadHeaderTimeout + + return newConn, nil + } } // Close closes the underlying listener. diff --git a/protocol_test.go b/protocol_test.go index 847601f..4221149 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -12,6 +12,8 @@ import ( "fmt" "io" "net" + "net/http" + "sync/atomic" "testing" "time" ) @@ -82,7 +84,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { t.Fatalf("err: %v", err) } @@ -137,7 +138,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { t.Fatalf("err: %v", err) } @@ -847,6 +847,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { t.Fatalf("client error: %v", err) } } + func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -1274,6 +1275,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) { } } +func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8080") + if err != nil { + t.Fatalf("error creating listener: %v", err) + } + + var connectionCounter atomic.Int32 + + newLn := &Listener{ + Listener: l, + ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) { + // Return the invalid upstream error on the first call, the listener + // should remain open and accepting. + times := connectionCounter.Load() + if times == 0 { + connectionCounter.Store(times + 1) + return REJECT, ErrInvalidUpstream + } + + return REJECT, ErrNoProxyProtocol + }, + } + + // Kick off the listener and return any error via the chanel. + errCh := make(chan error) + defer close(errCh) + go func(t *testing.T) { + _, err := newLn.Accept() + errCh <- err + }(t) + + // Make two calls to trigger the listener's accept, the first should experience + // the ErrInvalidUpstream and keep the listener open, the second should experience + // a different error which will cause the listener to close. + _, _ = http.Get("http://localhost:8080") + // Wait a few seconds to ensure we didn't get anything back on our channel. + select { + case err := <-errCh: + if err != nil { + t.Fatalf("invalid upstream shouldn't return an error: %v", err) + } + case <-time.After(2 * time.Second): + // No error returned (as expected, we're still listening though) + } + + _, _ = http.Get("http://localhost:8080") + // Wait a few seconds before we fail the test as we should have received an + // error that was not invalid upstream. + select { + case err := <-errCh: + if err == nil { + t.Fatalf("errors other than invalid upstream should error") + } + if !errors.Is(ErrNoProxyProtocol, err) { + t.Fatalf("unexpected error type: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for listener") + } +} + type TestTLSServer struct { Listener net.Listener @@ -1482,9 +1544,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) { b, err := io.ReadAll(r) return int64(len(b)), err } + func (c *testConn) Write(p []byte) (int, error) { return len(p), nil } + func (c *testConn) Read(p []byte) (int, error) { if c.reads == 0 { return 0, io.EOF @@ -1533,7 +1597,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { } func benchmarkTCPProxy(size int, b *testing.B) { - //create and start the echo backend + // create and start the echo backend backend, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatalf("err: %v", err) @@ -1554,7 +1618,7 @@ func benchmarkTCPProxy(size int, b *testing.B) { } }() - //start the proxyprotocol enabled tcp proxy + // start the proxyprotocol enabled tcp proxy l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatalf("err: %v", err) @@ -1603,7 +1667,7 @@ func benchmarkTCPProxy(size int, b *testing.B) { }, } - //now for the actual benchmark + // now for the actual benchmark b.ResetTimer() for n := 0; n < b.N; n++ { conn, err := net.Dial("tcp", pl.Addr().String()) @@ -1614,16 +1678,15 @@ func benchmarkTCPProxy(size int, b *testing.B) { if _, err := header.WriteTo(conn); err != nil { b.Fatalf("err: %v", err) } - //send data + // send data go func() { _, err = conn.Write(data) _ = conn.(*net.TCPConn).CloseWrite() if err != nil { panic(fmt.Sprintf("Failed to write data: %v", err)) } - }() - //receive data + // receive data n, err := io.Copy(io.Discard, conn) if n != int64(len(data)) { b.Fatalf("Expected to receive %d bytes, got %d", len(data), n) @@ -1638,24 +1701,31 @@ func benchmarkTCPProxy(size int, b *testing.B) { func BenchmarkTCPProxy16KB(b *testing.B) { benchmarkTCPProxy(16*1024, b) } + func BenchmarkTCPProxy32KB(b *testing.B) { benchmarkTCPProxy(32*1024, b) } + func BenchmarkTCPProxy64KB(b *testing.B) { benchmarkTCPProxy(64*1024, b) } + func BenchmarkTCPProxy128KB(b *testing.B) { benchmarkTCPProxy(128*1024, b) } + func BenchmarkTCPProxy256KB(b *testing.B) { benchmarkTCPProxy(256*1024, b) } + func BenchmarkTCPProxy512KB(b *testing.B) { benchmarkTCPProxy(512*1024, b) } + func BenchmarkTCPProxy1024KB(b *testing.B) { benchmarkTCPProxy(1024*1024, b) } + func BenchmarkTCPProxy2048KB(b *testing.B) { benchmarkTCPProxy(2048*1024, b) }