|
@@ -1042,9 +1042,13 @@ func TestQuotaScansRole(t *testing.T) {
|
|
|
|
|
|
func TestProxyPolicy(t *testing.T) {
|
|
|
addr := net.TCPAddr{}
|
|
|
+ downstream := net.TCPAddr{IP: net.ParseIP("1.1.1.1")}
|
|
|
p := getProxyPolicy(nil, nil, proxyproto.IGNORE)
|
|
|
- policy, err := p(&addr)
|
|
|
- assert.Error(t, err)
|
|
|
+ policy, err := p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &addr,
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
+ assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
|
|
|
assert.Equal(t, proxyproto.REJECT, policy)
|
|
|
ip1 := net.ParseIP("10.8.1.1")
|
|
|
ip2 := net.ParseIP("10.8.1.2")
|
|
@@ -1054,30 +1058,54 @@ func TestProxyPolicy(t *testing.T) {
|
|
|
skipped, err := util.ParseAllowedIPAndRanges([]string{ip2.String(), ip3.String()})
|
|
|
assert.NoError(t, err)
|
|
|
p = getProxyPolicy(allowed, skipped, proxyproto.IGNORE)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip1})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip1},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.USE, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip2})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip2},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.SKIP, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip3})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip3},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.SKIP, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: net.ParseIP("10.8.1.4")})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: net.ParseIP("10.8.1.4")},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.IGNORE, policy)
|
|
|
p = getProxyPolicy(allowed, skipped, proxyproto.REQUIRE)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip1})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip1},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.REQUIRE, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip2})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip2},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.SKIP, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: ip3})
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: ip3},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
assert.NoError(t, err)
|
|
|
assert.Equal(t, proxyproto.SKIP, policy)
|
|
|
- policy, err = p(&net.TCPAddr{IP: net.ParseIP("10.8.1.5")})
|
|
|
- assert.NoError(t, err)
|
|
|
+ policy, err = p(proxyproto.ConnPolicyOptions{
|
|
|
+ Upstream: &net.TCPAddr{IP: net.ParseIP("10.8.1.5")},
|
|
|
+ Downstream: &downstream,
|
|
|
+ })
|
|
|
+ assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
|
|
|
assert.Equal(t, proxyproto.REJECT, policy)
|
|
|
}
|
|
|
|
|
@@ -1094,14 +1122,14 @@ func TestProxyProtocolVersion(t *testing.T) {
|
|
|
assert.NoError(t, err)
|
|
|
proxyListener, ok := listener.(*proxyproto.Listener)
|
|
|
require.True(t, ok)
|
|
|
- assert.NotNil(t, proxyListener.Policy)
|
|
|
+ assert.NotNil(t, proxyListener.ConnPolicy)
|
|
|
|
|
|
c.ProxyProtocol = 2
|
|
|
listener, err = c.GetProxyListener(nil)
|
|
|
assert.NoError(t, err)
|
|
|
proxyListener, ok = listener.(*proxyproto.Listener)
|
|
|
require.True(t, ok)
|
|
|
- assert.NotNil(t, proxyListener.Policy)
|
|
|
+ assert.NotNil(t, proxyListener.ConnPolicy)
|
|
|
}
|
|
|
|
|
|
func TestStartupHook(t *testing.T) {
|