Selaa lähdekoodia

firewall can distinguish if the host connecting has an overlapping network, is a VPN peer without an overlapping network, or is a unsafe network

JackDoan 9 tuntia sitten
vanhempi
sitoutus
f597aa71e3
4 muutettua tiedostoa jossa 74 lisäystä ja 26 poistoa
  1. 20 6
      firewall.go
  2. 26 11
      firewall_test.go
  3. 2 2
      handshake_ix.go
  4. 26 7
      hostmap.go

+ 20 - 6
firewall.go

@@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 	return nil
 }
 
+var ErrUnknownNetworkType = errors.New("unknown network type")
+var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
 var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
 var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
 var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
@@ -429,19 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 		return nil
 	}
 
-	// TODO if we don't have a network in common with this packet's source IP, (and it's not for an unsafe_network), do we reject it?
-	// Make sure remote address matches nebula certificate
-	if h.networks != nil {
-		if !h.networks.Contains(fp.RemoteAddr) {
+	// Make sure remote address matches nebula certificate, and determine how to treat it
+	if h.networks == nil {
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one address and no unsafe networks
-		if h.vpnAddrs[0] != fp.RemoteAddr {
+		nwType, ok := h.networks.Lookup(fp.RemoteAddr)
+		if !ok {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
+		switch nwType {
+		case NetworkTypeVPN:
+			break // nothing special
+		case NetworkTypeVPNPeer:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrPeerRejected // reject for now, one day this may have different FW rules
+		case NetworkTypeUnsafe:
+			break // nothing special, one day this may have different FW rules
+		default:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrUnknownNetworkType //should never happen
+		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address

+ 26 - 11
firewall_test.go

@@ -8,6 +8,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
@@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
-
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
 		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.buildNetworks(c.networks, c.unsafeNetworks)
+	h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) {
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
+
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("fd12::34"),
 		RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
 	}
-	h.buildNetworks(c.networks, c.unsafeNetworks)
+	h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) {
 			peerCert: &c1,
 		},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
+	h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
 
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
+	h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("fd12::34"),
@@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	// Test a remote address match
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
 
 	c := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 

+ 2 - 2
handshake_ix.go

@@ -323,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -632,7 +632,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	// Build up the radix for the firewall if we have subnets in the cert
 	hostinfo.vpnAddrs = vpnAddrs
-	hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)

+ 26 - 7
hostmap.go

@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.relayForByIdx[idx] = r
 }
 
+type NetworkType uint8
+
+const (
+	NetworkTypeUnknown NetworkType = iota
+	// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
+	NetworkTypeVPN
+	// NetworkTypeVPNPeer is a network that does not overlap one of our networks
+	NetworkTypeVPNPeer
+	// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
+	NetworkTypeUnsafe
+)
+
 type HostInfo struct {
 	remote          netip.AddrPort
 	remotes         *RemoteList
@@ -224,7 +236,7 @@ type HostInfo struct {
 	vpnAddrs []netip.Addr
 
 	// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
-	networks   *bart.Lite
+	networks   *bart.Table[NetworkType]
 	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -728,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 	return false
 }
 
-func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
 	if len(networks) == 1 && len(unsafeNetworks) == 0 {
-		// Simple case, no CIDRTree needed
-		return
+		if myVpnNetworksTable.Contains(networks[0].Addr()) {
+			return // Simple case, no CIDRTree needed
+		}
 	}
 
-	i.networks = new(bart.Lite)
+	i.networks = new(bart.Table[NetworkType])
 	for _, network := range networks {
+		var nwType NetworkType
+		if myVpnNetworksTable.Contains(network.Addr()) {
+			nwType = NetworkTypeVPN
+		} else {
+			nwType = NetworkTypeVPNPeer
+		}
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		i.networks.Insert(nprefix)
+		i.networks.Insert(nprefix, nwType)
 	}
 
 	for _, network := range unsafeNetworks {
-		i.networks.Insert(network)
+		i.networks.Insert(network, NetworkTypeUnsafe)
 	}
 }