Skip to content

Commit

Permalink
keep listener after erroring with invalid upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
pires committed Oct 8, 2024
1 parent 9814f02 commit 10741be
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 51 deletions.
102 changes: 59 additions & 43 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package proxyproto

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)

// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
var DefaultReadHeaderTimeout = 10 * time.Second
var (
// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
DefaultReadHeaderTimeout = 10 * time.Second

// ErrInvalidUpstream should be returned when an upstream connection address
// is not trusted, and therefore is invalid.
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
Expand Down Expand Up @@ -72,53 +80,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// Accept waits for and returns the next connection to the listener.
// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
for {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()

newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
if errors.Is(err, ErrInvalidUpstream) {
// keep listening for other connections
continue
}

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout
newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

return newConn, nil
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}
}

// Close closes the underlying listener.
Expand Down
86 changes: 78 additions & 8 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"io"
"io/ioutil"

Check failure on line 14 in protocol_test.go

View workflow job for this annotation

GitHub Actions / lint (1.19)

SA1019: "io/ioutil" has been deprecated since Go 1.19: As of Go 1.16, the same functionality is now provided by package io or package os, and those implementations should be preferred in new code. See the specific function documentation for details. (staticcheck)

Check failure on line 14 in protocol_test.go

View workflow job for this annotation

GitHub Actions / lint (1.20)

SA1019: "io/ioutil" has been deprecated since Go 1.19: As of Go 1.16, the same functionality is now provided by package io or package os, and those implementations should be preferred in new code. See the specific function documentation for details. (staticcheck)
"net"
"net/http"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -83,7 +85,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -138,7 +139,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
start := time.Now()

l, err := net.Listen("tcp", "127.0.0.1:0")

if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -848,6 +848,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
t.Fatalf("client error: %v", err)
}
}

func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down Expand Up @@ -1275,6 +1276,67 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
}
}

func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8080")
if err != nil {
t.Fatalf("error creating listener: %v", err)
}

var connectionCounter atomic.Int32

newLn := &Listener{
Listener: l,
ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) {
// Return the invalid upstream error on the first call, the listener
// should remain open and accepting.
times := connectionCounter.Load()
if times == 0 {
connectionCounter.Store(times + 1)
return REJECT, ErrInvalidUpstream
}

return REJECT, ErrNoProxyProtocol
},
}

// Kick off the listener and return any error via the chanel.
errCh := make(chan error)
defer close(errCh)
go func(t *testing.T) {
_, err := newLn.Accept()
errCh <- err
}(t)

// Make two calls to trigger the listener's accept, the first should experience
// the ErrInvalidUpstream and keep the listener open, the second should experience
// a different error which will cause the listener to close.
_, _ = http.Get("http://localhost:8080")
// Wait a few seconds to ensure we didn't get anything back on our channel.
select {
case err := <-errCh:
if err != nil {
t.Fatalf("invalid upstream shouldn't return an error: %v", err)
}
case <-time.After(2 * time.Second):
// No error returned (as expected, we're still listening though)
}

_, _ = http.Get("http://localhost:8080")
// Wait a few seconds before we fail the test as we should have received an
// error that was not invalid upstream.
select {
case err := <-errCh:
if err == nil {
t.Fatalf("errors other than invalid upstream should error")
}
if !errors.Is(ErrNoProxyProtocol, err) {
t.Fatalf("unexpected error type: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("timed out waiting for listener")
}
}

type TestTLSServer struct {
Listener net.Listener

Expand Down Expand Up @@ -1483,9 +1545,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
b, err := ioutil.ReadAll(r)
return int64(len(b)), err
}

func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}

func (c *testConn) Read(p []byte) (int, error) {
if c.reads == 0 {
return 0, io.EOF
Expand Down Expand Up @@ -1534,7 +1598,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
}

func benchmarkTCPProxy(size int, b *testing.B) {
//create and start the echo backend
// create and start the echo backend
backend, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand All @@ -1555,7 +1619,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
}
}()

//start the proxyprotocol enabled tcp proxy
// start the proxyprotocol enabled tcp proxy
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
Expand Down Expand Up @@ -1604,7 +1668,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
},
}

//now for the actual benchmark
// now for the actual benchmark
b.ResetTimer()
for n := 0; n < b.N; n++ {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand All @@ -1615,16 +1679,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
if _, err := header.WriteTo(conn); err != nil {
b.Fatalf("err: %v", err)
}
//send data
// send data
go func() {
_, err = conn.Write(data)
_ = conn.(*net.TCPConn).CloseWrite()
if err != nil {
panic(fmt.Sprintf("Failed to write data: %v", err))
}

}()
//receive data
// receive data
n, err := io.Copy(ioutil.Discard, conn)
if n != int64(len(data)) {
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
Expand All @@ -1639,24 +1702,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
func BenchmarkTCPProxy16KB(b *testing.B) {
benchmarkTCPProxy(16*1024, b)
}

func BenchmarkTCPProxy32KB(b *testing.B) {
benchmarkTCPProxy(32*1024, b)
}

func BenchmarkTCPProxy64KB(b *testing.B) {
benchmarkTCPProxy(64*1024, b)
}

func BenchmarkTCPProxy128KB(b *testing.B) {
benchmarkTCPProxy(128*1024, b)
}

func BenchmarkTCPProxy256KB(b *testing.B) {
benchmarkTCPProxy(256*1024, b)
}

func BenchmarkTCPProxy512KB(b *testing.B) {
benchmarkTCPProxy(512*1024, b)
}

func BenchmarkTCPProxy1024KB(b *testing.B) {
benchmarkTCPProxy(1024*1024, b)
}

func BenchmarkTCPProxy2048KB(b *testing.B) {
benchmarkTCPProxy(2048*1024, b)
}
Expand Down

0 comments on commit 10741be

Please sign in to comment.