|
|
@@ -8,6 +8,8 @@ import (
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
+ "github.com/gaissmai/bart"
|
|
|
+ "github.com/sirupsen/logrus"
|
|
|
"github.com/slackhq/nebula/cert"
|
|
|
"github.com/slackhq/nebula/config"
|
|
|
"github.com/slackhq/nebula/firewall"
|
|
|
@@ -149,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
-
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
|
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
|
|
@@ -174,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
|
|
}
|
|
|
- h.buildNetworks(c.networks, c.unsafeNetworks)
|
|
|
+ h.buildNetworks(myVpnNetworksTable, &c)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
@@ -226,6 +229,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
|
+
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
|
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
|
|
@@ -250,7 +256,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
|
|
}
|
|
|
- h.buildNetworks(c.networks, c.unsafeNetworks)
|
|
|
+ h.buildNetworks(myVpnNetworksTable, &c)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
@@ -453,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
|
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
|
|
@@ -478,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
|
|
+ h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
|
|
|
|
|
c1 := cert.CachedCertificate{
|
|
|
Certificate: &dummyCert{
|
|
|
@@ -493,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|
|
peerCert: &c1,
|
|
|
},
|
|
|
}
|
|
|
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
|
|
+ h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
@@ -510,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
|
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
|
|
@@ -541,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
|
|
+ h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
|
|
|
|
|
c2 := cert.CachedCertificate{
|
|
|
Certificate: &dummyCert{
|
|
|
@@ -556,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
|
|
+ h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
|
|
|
|
|
|
c3 := cert.CachedCertificate{
|
|
|
Certificate: &dummyCert{
|
|
|
@@ -571,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
|
|
+ h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
@@ -597,6 +607,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
|
|
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
|
|
@@ -620,7 +632,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
|
|
+ h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
|
|
|
|
|
// Test a remote address match
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
|
@@ -633,6 +645,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
|
|
|
|
|
p := firewall.Packet{
|
|
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
|
|
@@ -659,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{network.Addr()},
|
|
|
}
|
|
|
- h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
|
|
+ h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
@@ -696,6 +710,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|
|
l := test.NewLogger()
|
|
|
ob := &bytes.Buffer{}
|
|
|
l.SetOutput(ob)
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
|
|
|
|
|
c := cert.CachedCertificate{
|
|
|
Certificate: &dummyCert{
|
|
|
@@ -717,7 +733,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|
|
},
|
|
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
|
|
}
|
|
|
- h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
|
|
+ h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
|
|
|
|
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
|
|
|
|
|
@@ -1047,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) {
|
|
|
assert.Equal(t, "group1", r.Group)
|
|
|
}
|
|
|
|
|
|
+type testcase struct {
|
|
|
+ h *HostInfo
|
|
|
+ p firewall.Packet
|
|
|
+ c cert.Certificate
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|
|
+ t.Helper()
|
|
|
+ cp := cert.NewCAPool()
|
|
|
+ resetConntrack(fw)
|
|
|
+ err := fw.Drop(c.p, true, c.h, cp, nil)
|
|
|
+ if c.err == nil {
|
|
|
+ require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
|
|
+ } else {
|
|
|
+ require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
|
|
+ c1 := dummyCert{
|
|
|
+ name: "host1",
|
|
|
+ networks: theirPrefixes,
|
|
|
+ groups: []string{"default-group"},
|
|
|
+ issuer: "signer-shasum",
|
|
|
+ }
|
|
|
+ h := HostInfo{
|
|
|
+ ConnectionState: &ConnectionState{
|
|
|
+ peerCert: &cert.CachedCertificate{
|
|
|
+ Certificate: &c1,
|
|
|
+ InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
|
|
|
+ }
|
|
|
+ for i := range theirPrefixes {
|
|
|
+ h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
|
|
+ }
|
|
|
+ h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
|
|
+ p := firewall.Packet{
|
|
|
+ LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
|
|
+ RemoteAddr: theirPrefixes[0].Addr(),
|
|
|
+ LocalPort: 10,
|
|
|
+ RemotePort: 90,
|
|
|
+ Protocol: firewall.ProtoUDP,
|
|
|
+ Fragment: false,
|
|
|
+ }
|
|
|
+ return testcase{
|
|
|
+ h: &h,
|
|
|
+ p: p,
|
|
|
+ c: &c1,
|
|
|
+ err: err,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type testsetup struct {
|
|
|
+ c dummyCert
|
|
|
+ myVpnNetworksTable *bart.Lite
|
|
|
+ fw *Firewall
|
|
|
+}
|
|
|
+
|
|
|
+func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
|
|
|
+ c := dummyCert{
|
|
|
+ name: "me",
|
|
|
+ networks: myPrefixes,
|
|
|
+ groups: []string{"default-group"},
|
|
|
+ issuer: "signer-shasum",
|
|
|
+ }
|
|
|
+
|
|
|
+ return newSetupFromCert(t, l, c)
|
|
|
+}
|
|
|
+
|
|
|
+func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
|
|
+ myVpnNetworksTable := new(bart.Lite)
|
|
|
+ for _, prefix := range c.Networks() {
|
|
|
+ myVpnNetworksTable.Insert(prefix)
|
|
|
+ }
|
|
|
+ fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
|
|
+ require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
|
|
+
|
|
|
+ return testsetup{
|
|
|
+ c: c,
|
|
|
+ fw: fw,
|
|
|
+ myVpnNetworksTable: myVpnNetworksTable,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ l := test.NewLogger()
|
|
|
+ ob := &bytes.Buffer{}
|
|
|
+ l.SetOutput(ob)
|
|
|
+
|
|
|
+ myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
|
|
+ // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
|
|
|
+ t.Run("allow inbound all matching", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ t.Run("allow inbound local matching", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
|
+ tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ t.Run("block inbound remote mismatched", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
|
|
|
+ tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ t.Run("Block a vpn peer packet", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ twoPrefixes := []netip.Prefix{
|
|
|
+ netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
|
|
|
+ }
|
|
|
+ t.Run("allow inbound one matching", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, nil, twoPrefixes...)
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ t.Run("block inbound multimismatch", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup := newSetup(t, l, myPrefix)
|
|
|
+ tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
|
|
|
+ tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
|
|
|
+ tc.Test(t, setup.fw)
|
|
|
+ })
|
|
|
+ t.Run("allow inbound 2nd one matching", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
|
|
|
+ tc := buildTestCase(setup2, nil, twoPrefixes...)
|
|
|
+ tc.p.RemoteAddr = twoPrefixes[1].Addr()
|
|
|
+ tc.Test(t, setup2.fw)
|
|
|
+ })
|
|
|
+ t.Run("allow inbound unsafe route", func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
|
|
+ c := dummyCert{
|
|
|
+ name: "me",
|
|
|
+ networks: []netip.Prefix{myPrefix},
|
|
|
+ unsafeNetworks: []netip.Prefix{unsafePrefix},
|
|
|
+ groups: []string{"default-group"},
|
|
|
+ issuer: "signer-shasum",
|
|
|
+ }
|
|
|
+ unsafeSetup := newSetupFromCert(t, l, c)
|
|
|
+ tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
|
|
+ tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
|
|
+ tc.err = ErrNoMatchingRule
|
|
|
+ tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
|
|
+ require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
|
|
|
+ tc.err = nil
|
|
|
+ tc.Test(t, unsafeSetup.fw) //should pass
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
type addRuleCall struct {
|
|
|
incoming bool
|
|
|
proto uint8
|