Skip to content

Commit

Permalink
Add support for validating the downstream ip of the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
kmala committed Apr 18, 2024
1 parent e5b291b commit f446ee1
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
13 changes: 13 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var DefaultReadHeaderTimeout = 10 * time.Second
type Listener struct {
Listener net.Listener
Policy PolicyFunc
DownstreamPolicy PolicyFunc
ValidateHeader Validator
ReadHeaderTimeout time.Duration
}
Expand Down Expand Up @@ -79,6 +80,18 @@ func (p *Listener) Accept() (net.Conn, error) {
return conn, nil
}
}
if p.DownstreamPolicy != nil {
proxyHeaderPolicy, err = p.DownstreamPolicy(conn.LocalAddr())
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

newConn := NewConn(
conn,
Expand Down
85 changes: 85 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,91 @@ func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
}
}

func TestIgnoreUpstreamPolicyIgnoresIpFromProxyHeader(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

policyFunc := func(downstream net.Addr) (Policy, error) { return IGNORE, nil }

pl := &Listener{Listener: l, DownstreamPolicy: policyFunc}

cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()

// Write out the header!
header := &Header{
Version: 2,
Command: PROXY,
TransportProtocol: TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},
}
if _, err := header.WriteTo(conn); err != nil {
cliResult <- err
return
}

if _, err := conn.Write([]byte("ping")); err != nil {
cliResult <- err
return
}

recv := make([]byte, 4)
if _, err = conn.Read(recv); err != nil {
cliResult <- err
return
}
if !bytes.Equal(recv, []byte("pong")) {
cliResult <- fmt.Errorf("bad: %v", recv)
return
}

close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

recv := make([]byte, 4)
if _, err = conn.Read(recv); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}

if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}

// Check the remote addr
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}

func Test_AllOptionsAreRecognized(t *testing.T) {
recognizedOpt1 := false
opt1 := func(c *Conn) {
Expand Down

0 comments on commit f446ee1

Please sign in to comment.