Skip to content

Commit

Permalink
Add support for validating the downstream ip of the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
kmala authored and pires committed Apr 19, 2024
1 parent 8a2480a commit b455b79
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 25 deletions.
30 changes: 25 additions & 5 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import (
"strings"
)

// PolicyFunc can be used to decide whether to trust the PROXY info from
// upstream. If set, the connecting address is passed in as an argument.
// PolicyFunc can be used to decide whether to trust the PROXY info based on
// upstream/downstream IP. If set, the connecting addresses(remote and local)
// are passed in as arguments.
//
// See below for the different policies.
//
// In case an error is returned the connection is denied.
type PolicyFunc func(upstream net.Addr) (Policy, error)
type PolicyFunc func(upstream net.Addr, downstream net.Addr) (Policy, error)

// Policy defines how a connection with a PROXY header address is treated.
type Policy int
Expand Down Expand Up @@ -43,7 +44,7 @@ const (
// Kubernetes pods local traffic. The def is a policy to use when an upstream
// address doesn't match the skipHeaderCIDR.
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
return func(upstream net.Addr) (Policy, error) {
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
ip, err := ipFromAddr(upstream)
if err != nil {
return def, err
Expand All @@ -57,6 +58,25 @@ func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
}
}

// IgnoreProxyHeaderNotOnInterface retuns a PolicyFunc which can be used to
// decide whether to use or ignore PROXY headers depending on the connection
// being made on a specific interface. This policy can be used when the server
// is bound to multiple interfaces but wants to allow on only one interface.
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) PolicyFunc {
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
ip, err := ipFromAddr(downstream)
if err != nil {
return REJECT, err
}

if allowedIP.Equal(ip) {
return USE, nil
}

return IGNORE, nil
}
}

// WithPolicy adds given policy to a connection when passed as option to NewConn()
func WithPolicy(p Policy) func(*Conn) {
return func(c *Conn) {
Expand Down Expand Up @@ -117,7 +137,7 @@ func MustStrictWhiteListPolicy(allowed []string) PolicyFunc {
}

func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc {
return func(upstream net.Addr) (Policy, error) {
return func(upstream net.Addr, downstream net.Addr) (Policy, error) {
upstreamIP, err := ipFromAddr(upstream)
if err != nil {
// something is wrong with the source IP, better reject the connection
Expand Down
50 changes: 43 additions & 7 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.policy(failingAddr{})
_, err := tc.policy(failingAddr{}, nil)
if err == nil {
t.Fatal("Expected error, got none")
}
Expand All @@ -37,7 +37,7 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t
t.Fatalf("err: %v", err)
}

policy, err := p(upstream)
policy, err := p(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand All @@ -55,7 +55,7 @@ func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *test
t.Fatalf("err: %v", err)
}

policy, err := p(upstream)
policy, err := p(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand All @@ -81,7 +81,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy, err := tc.policy(upstream)
policy, err := tc.policy(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy, err := tc.policy(upstream)
policy, err := tc.policy(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
f := SkipProxyHeaderForCIDR(cidr, REJECT)

upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
policy, err := f(upstream)
policy, err := f(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand All @@ -203,11 +203,47 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
}

upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
policy, err = f(upstream)
policy, err = f(upstream, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if policy != REJECT {
t.Errorf("Expected a REJECT policy for the %s address", upstream)
}
}

func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738")
if err != nil {
t.Fatalf("err: %v", err)
}

var cases = []struct {
name string
policy PolicyFunc
downstreamAddress net.Addr
expectedPolicy Policy
expectError bool
}{
{"ignore header for requests non on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false},
{"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), downstream, USE, false},
{"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy, err := tc.policy(nil, tc.downstreamAddress)
if !tc.expectError && err != nil {
t.Fatalf("err: %v", err)
}
if tc.expectError && err == nil {
t.Fatal("Expected error, got none")
}

if policy != tc.expectedPolicy {
t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy)
}
})
}

}
2 changes: 1 addition & 1 deletion protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (p *Listener) Accept() (net.Conn, error) {

proxyHeaderPolicy := USE
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr(), conn.LocalAddr())
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
Expand Down
24 changes: 12 additions & 12 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
pl := &Listener{
Listener: l,
ReadHeaderTimeout: time.Millisecond * time.Duration(duration),
Policy: func(upstream net.Addr) (Policy, error) {
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
return REQUIRE, nil
},
}
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
pl := &Listener{
Listener: l,
ReadHeaderTimeout: time.Millisecond * time.Duration(duration),
Policy: func(upstream net.Addr) (Policy, error) {
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
return USE, nil
},
}
Expand Down Expand Up @@ -645,7 +645,7 @@ func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) {
}

expectedErr := fmt.Errorf("failure")
policyFunc := func(upstream net.Addr) (Policy, error) { return USE, expectedErr }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return USE, expectedErr }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -681,7 +681,7 @@ func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -724,7 +724,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REJECT, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -778,7 +778,7 @@ func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return IGNORE, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return IGNORE, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -891,7 +891,7 @@ func TestReadingIsRefusedOnErrorWhenRemoteAddrRequestedFirst(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -935,7 +935,7 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -979,7 +979,7 @@ func TestSkipProxyProtocolPolicy(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return SKIP, nil }

pl := &Listener{
Listener: l,
Expand Down Expand Up @@ -1036,7 +1036,7 @@ func Test_ConnectionCasts(t *testing.T) {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
policyFunc := func(upstream net.Addr, downstream net.Addr) (Policy, error) { return REQUIRE, nil }

pl := &Listener{Listener: l, Policy: policyFunc}

Expand Down Expand Up @@ -1198,7 +1198,7 @@ func Test_TLSServer(t *testing.T) {
s := NewTestTLSServer(l)
s.Listener = &Listener{
Listener: s.Listener,
Policy: func(upstream net.Addr) (Policy, error) {
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
return REQUIRE, nil
},
}
Expand Down Expand Up @@ -1269,7 +1269,7 @@ func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) {
s := NewTestTLSServer(l)
s.Listener = &Listener{
Listener: s.Listener,
Policy: func(upstream net.Addr) (Policy, error) {
Policy: func(upstream net.Addr, downstream net.Addr) (Policy, error) {
return REQUIRE, nil
},
}
Expand Down

0 comments on commit b455b79

Please sign in to comment.