Skip to content

Commit

Permalink
keep listener after erroring with invalid upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
pires committed Oct 8, 2024
1 parent 2df67b4 commit b323cec
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 51 deletions.
102 changes: 59 additions & 43 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package proxyproto

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)

// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
var DefaultReadHeaderTimeout = 10 * time.Second
var (
// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
DefaultReadHeaderTimeout = 10 * time.Second

// ErrInvalidUpstream should be returned when an upstream connection address
// is not trusted, and therefore is invalid.
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
Expand Down Expand Up @@ -73,53 +81,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// Accept waits for and returns the next connection to the listener.
// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
for {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()

newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
if errors.Is(err, ErrInvalidUpstream) {
// keep listening for other connections
continue
}

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout
newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

return newConn, nil
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}
}

// Close closes the underlying listener.
Expand Down
86 changes: 78 additions & 8 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"fmt"
"io"
"net"
"net/http"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -82,7 +84,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -137,7 +138,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -847,6 +847,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
t.Fatalf("client error: %v", err)
}
}

func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down Expand Up @@ -1274,6 +1275,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
}
}

func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8080")
if err != nil {
t.Fatalf("error creating listener: %v", err)
}

var connectionCounter atomic.Int32

newLn := &Listener{
Listener: l,
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
// Return the invalid upstream error on the first call, the listener
// should remain open and accepting.
times := connectionCounter.Load()
if times == 0 {
connectionCounter.Store(times + 1)
return REJECT, ErrInvalidUpstream
}

return REJECT, ErrNoProxyProtocol
},
}

// Kick off the listener and return any error via the chanel.
errCh := make(chan error)
defer close(errCh)
go func(t *testing.T) {
_, err := newLn.Accept()
errCh <- err
}(t)

// Make two calls to trigger the listener's accept, the first should experience
// the ErrInvalidUpstream and keep the listener open, the second should experience
// a different error which will cause the listener to close.
_, _ = http.Get("http://localhost:8080")
// Wait a few seconds to ensure we didn't get anything back on our channel.
select {
case err := <-errCh:
if err != nil {
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
}
case <-time.After(2 * time.Second):
// No error returned (as expected, we're still listening though)
}

_, _ = http.Get("http://localhost:8080")
// Wait a few seconds before we fail the test as we should have received an
// error that was not invalid upstream.
select {
case err := <-errCh:
if err == nil {
t.Fatalf("errors other than invalid upstream should error")
}
if !errors.Is(ErrNoProxyProtocol, err) {
t.Fatalf("unexpected error type: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("timed out waiting for listener")
}
}

type TestTLSServer struct {
Listener net.Listener

Expand Down Expand Up @@ -1482,9 +1544,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
b, err := io.ReadAll(r)
return int64(len(b)), err
}

func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}

func (c *testConn) Read(p []byte) (int, error) {
if c.reads == 0 {
return 0, io.EOF
Expand Down Expand Up @@ -1533,7 +1597,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
}

func benchmarkTCPProxy(size int, b *testing.B) {
//create and start the echo backend
// create and start the echo backend
backend, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand All @@ -1554,7 +1618,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
}
}()

//start the proxyprotocol enabled tcp proxy
// start the proxyprotocol enabled tcp proxy
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand Down Expand Up @@ -1603,7 +1667,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
},
}

//now for the actual benchmark
// now for the actual benchmark
b.ResetTimer()
for n := 0; n < b.N; n++ {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand All @@ -1614,16 +1678,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
if _, err := header.WriteTo(conn); err != nil {
b.Fatalf("err: %v", err)
}
//send data
// send data
go func() {
_, err = conn.Write(data)
_ = conn.(*net.TCPConn).CloseWrite()
if err != nil {
panic(fmt.Sprintf("Failed to write data: %v", err))
}

}()
//receive data
// receive data
n, err := io.Copy(io.Discard, conn)
if n != int64(len(data)) {
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
Expand All @@ -1638,24 +1701,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
func BenchmarkTCPProxy16KB(b *testing.B) {
benchmarkTCPProxy(16*1024, b)
}

func BenchmarkTCPProxy32KB(b *testing.B) {
benchmarkTCPProxy(32*1024, b)
}

func BenchmarkTCPProxy64KB(b *testing.B) {
benchmarkTCPProxy(64*1024, b)
}

func BenchmarkTCPProxy128KB(b *testing.B) {
benchmarkTCPProxy(128*1024, b)
}

func BenchmarkTCPProxy256KB(b *testing.B) {
benchmarkTCPProxy(256*1024, b)
}

func BenchmarkTCPProxy512KB(b *testing.B) {
benchmarkTCPProxy(512*1024, b)
}

func BenchmarkTCPProxy1024KB(b *testing.B) {
benchmarkTCPProxy(1024*1024, b)
}

func BenchmarkTCPProxy2048KB(b *testing.B) {
benchmarkTCPProxy(2048*1024, b)
}
Expand Down

0 comments on commit b323cec

Please sign in to comment.