Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep listener after erroring with invalid upstream #117

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package proxyproto

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)

// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
var DefaultReadHeaderTimeout = 10 * time.Second
var (
// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
DefaultReadHeaderTimeout = 10 * time.Second

// ErrInvalidUpstream should be returned when an upstream connection address
// is not trusted, and therefore is invalid.
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
Expand Down Expand Up @@ -73,53 +81,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// Accept waits for and returns the next connection to the listener.
// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
for {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()

newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
if errors.Is(err, ErrInvalidUpstream) {
// keep listening for other connections
continue
}

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout
newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

return newConn, nil
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}
}

// Close closes the underlying listener.
Expand Down
86 changes: 78 additions & 8 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"fmt"
"io"
"net"
"net/http"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -82,7 +84,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -137,7 +138,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -847,6 +847,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
t.Fatalf("client error: %v", err)
}
}

func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down Expand Up @@ -1274,6 +1275,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
}
}

func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8080")
if err != nil {
t.Fatalf("error creating listener: %v", err)
}

var connectionCounter atomic.Int32

newLn := &Listener{
Listener: l,
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
// Return the invalid upstream error on the first call, the listener
// should remain open and accepting.
times := connectionCounter.Load()
if times == 0 {
connectionCounter.Store(times + 1)
return REJECT, ErrInvalidUpstream
}

return REJECT, ErrNoProxyProtocol
},
}

// Kick off the listener and return any error via the chanel.
errCh := make(chan error)
defer close(errCh)
go func(t *testing.T) {
_, err := newLn.Accept()
errCh <- err
}(t)

// Make two calls to trigger the listener's accept, the first should experience
// the ErrInvalidUpstream and keep the listener open, the second should experience
// a different error which will cause the listener to close.
_, _ = http.Get("http://localhost:8080")
// Wait a few seconds to ensure we didn't get anything back on our channel.
select {
case err := <-errCh:
if err != nil {
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
}
case <-time.After(2 * time.Second):
// No error returned (as expected, we're still listening though)
}

_, _ = http.Get("http://localhost:8080")
// Wait a few seconds before we fail the test as we should have received an
// error that was not invalid upstream.
select {
case err := <-errCh:
if err == nil {
t.Fatalf("errors other than invalid upstream should error")
}
if !errors.Is(ErrNoProxyProtocol, err) {
t.Fatalf("unexpected error type: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("timed out waiting for listener")
}
}

type TestTLSServer struct {
Listener net.Listener

Expand Down Expand Up @@ -1482,9 +1544,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
b, err := io.ReadAll(r)
return int64(len(b)), err
}

func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}

func (c *testConn) Read(p []byte) (int, error) {
if c.reads == 0 {
return 0, io.EOF
Expand Down Expand Up @@ -1533,7 +1597,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
}

func benchmarkTCPProxy(size int, b *testing.B) {
//create and start the echo backend
// create and start the echo backend
backend, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand All @@ -1554,7 +1618,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
}
}()

//start the proxyprotocol enabled tcp proxy
// start the proxyprotocol enabled tcp proxy
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand Down Expand Up @@ -1603,7 +1667,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
},
}

//now for the actual benchmark
// now for the actual benchmark
b.ResetTimer()
for n := 0; n < b.N; n++ {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand All @@ -1614,16 +1678,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
if _, err := header.WriteTo(conn); err != nil {
b.Fatalf("err: %v", err)
}
//send data
// send data
go func() {
_, err = conn.Write(data)
_ = conn.(*net.TCPConn).CloseWrite()
if err != nil {
panic(fmt.Sprintf("Failed to write data: %v", err))
}

}()
//receive data
// receive data
n, err := io.Copy(io.Discard, conn)
if n != int64(len(data)) {
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
Expand All @@ -1638,24 +1701,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
func BenchmarkTCPProxy16KB(b *testing.B) {
benchmarkTCPProxy(16*1024, b)
}

func BenchmarkTCPProxy32KB(b *testing.B) {
benchmarkTCPProxy(32*1024, b)
}

func BenchmarkTCPProxy64KB(b *testing.B) {
benchmarkTCPProxy(64*1024, b)
}

func BenchmarkTCPProxy128KB(b *testing.B) {
benchmarkTCPProxy(128*1024, b)
}

func BenchmarkTCPProxy256KB(b *testing.B) {
benchmarkTCPProxy(256*1024, b)
}

func BenchmarkTCPProxy512KB(b *testing.B) {
benchmarkTCPProxy(512*1024, b)
}

func BenchmarkTCPProxy1024KB(b *testing.B) {
benchmarkTCPProxy(1024*1024, b)
}

func BenchmarkTCPProxy2048KB(b *testing.B) {
benchmarkTCPProxy(2048*1024, b)
}
Expand Down
Loading