Browse Source

Firewall types and cross-stack subnet stuff (#1509)

* firewall can distinguish if the host connecting has an overlapping network, is a VPN peer without an overlapping network, or is a unsafe network

* Cross stack subnet stuff (#1512)

* experiment with not filtering out non-common addresses in hostinfo.networks

* allow handshakes without overlaps

* unsafe network test

* change HostInfo.buildNetworks argument to reference the cert
Jack Doan 3 weeks ago
parent
commit
a89f95182c
11 changed files with 580 additions and 114 deletions
  1. 4 0
      control_tester.go
  2. 170 0
      e2e/handshakes_test.go
  3. 52 10
      e2e/helpers_test.go
  4. 47 0
      e2e/tunnels_test.go
  5. 22 7
      firewall.go
  6. 192 11
      firewall_test.go
  7. 52 66
      handshake_ix.go
  8. 2 2
      handshake_manager.go
  9. 29 11
      hostmap.go
  10. 6 5
      inside.go
  11. 4 2
      lighthouse.go

+ 4 - 0
control_tester.go

@@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
 
+func (c *Control) GetF() *Interface {
+	return c.f
+}
+
 func (c *Control) GetCertState() *CertState {
 	return c.f.pki.getCertState()
 }

+ 170 - 0
e2e/handshakes_test.go

@@ -97,6 +97,41 @@ func TestGoodHandshake(t *testing.T) {
 	theirControl.Stop()
 }
 
+func TestGoodHandshakeNoOverlap(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
+
+	// Put their info in our lighthouse
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	empty := []byte{}
+	t.Log("do something to cause a handshake")
+	myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
+
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Get their stage 1 packet")
+	stage1Packet := theirControl.GetFromUDP(true)
+
+	t.Log("Have me consume their stage 1 packet. I have a tunnel now")
+	myControl.InjectUDPPacket(stage1Packet)
+
+	t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
+	myControl.WaitForType(header.Test, 0, theirControl)
+
+	t.Log("Make sure our host infos are correct")
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
@@ -464,6 +499,35 @@ func TestRelays(t *testing.T) {
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 }
 
+func TestRelaysDontCareAboutIps(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay  ", "2001::9999/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
+	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
+}
+
 func TestReestablishRelays(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1227,3 +1291,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh  ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
+
+	o := m{
+		"static_host_map": m{
+			lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
+		},
+		"lighthouse": m{
+			"hosts": []string{lhVpnIpNet[0].Addr().String()},
+			"local_allow_list": m{
+				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
+				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
+				"10.0.0.0/24": true,
+				"::/0":        false,
+			},
+		},
+	}
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me  ", "10.128.0.2/24, ff::2/64", o)
+	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, lhControl, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	lhControl.Start()
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up an ipv6 tunnel between me and them")
+	assert.True(t, myVpnIpNet[1].Addr().Is6())
+	assert.True(t, theirVpnIpNet[1].Addr().Is6())
+	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
+
+	lhControl.Stop()
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestGoodHandshakeUnsafeDest(t *testing.T) {
+	unsafePrefix := "192.168.6.0/24"
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
+	route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
+	myCfg := m{
+		"tun": m{
+			"unsafe_routes": []m{route},
+		},
+	}
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
+	t.Logf("my config %v", myConfig)
+	// Put their info in our lighthouse
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	spookyDest := netip.MustParseAddr("192.168.6.4")
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
+	myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Get their stage 1 packet so that we can play with it")
+	stage1Packet := theirControl.GetFromUDP(true)
+
+	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
+	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
+	badPacket := stage1Packet.Copy()
+	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
+	myControl.InjectUDPPacket(badPacket)
+
+	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
+	myControl.InjectUDPPacket(stage1Packet)
+
+	t.Log("Wait until we see my cached packet come through")
+	myControl.WaitForType(1, 0, theirControl)
+
+	t.Log("Make sure our host infos are correct")
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+
+	t.Log("Get that cached packet and make sure it looks right")
+	myCachedPacket := theirControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
+
+	//reply
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
+	//wait for reply
+	theirControl.WaitForType(1, 0, myControl)
+	theirCachedPacket := myControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
+
+	t.Log("Do a bidirectional tunnel test")
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 52 - 10
e2e/helpers_test.go

@@ -22,6 +22,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"go.yaml.in/yaml/v3"
 )
 
@@ -29,8 +30,6 @@ type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
 func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-	l := NewTestLogger()
-
 	var vpnNetworks []netip.Prefix
 	for _, sn := range strings.Split(sVpnNetworks, ",") {
 		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
-	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
+	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
+}
+
+func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
+}
+
+func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
+	}
+
+	firewallInbound := []m{{
+		"proto": "any",
+		"port":  "any",
+		"host":  "any",
+	}}
+
+	var unsafeNetworks []netip.Prefix
+	if sUnsafeNetworks != "" {
+		firewallInbound = []m{{
+			"proto":      "any",
+			"port":       "any",
+			"host":       "any",
+			"local_cidr": "0.0.0.0/0",
+		}}
+
+		for _, sn := range strings.Split(sUnsafeNetworks, ",") {
+			x, err := netip.ParsePrefix(strings.TrimSpace(sn))
+			if err != nil {
+				panic(err)
+			}
+			unsafeNetworks = append(unsafeNetworks, x)
+		}
+	}
+
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
 
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
@@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 				"port":  "any",
 				"host":  "any",
 			}},
-			"inbound": []m{{
-				"proto": "any",
-				"port":  "any",
-				"host":  "any",
-			}},
+			"inbound": firewallInbound,
 		},
 		//"handshakes": m{
 		//	"try_interval": "1s",
@@ -266,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
 	// Get both host infos
 	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
 	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
+	require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
 	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
+	require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 	// Check that both vpn and real addr are correct
 	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")

+ 47 - 0
e2e/tunnels_test.go

@@ -318,3 +318,50 @@ func TestCertMismatchCorrection(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCrossStackRelaysWork(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me     ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay  ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
+	theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them   ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
+
+	//myVpnV4 := myVpnIpNet[0]
+	myVpnV6 := myVpnIpNet[1]
+	relayVpnV4 := relayVpnIpNet[0]
+	relayVpnV6 := relayVpnIpNet[1]
+	theirVpnV6 := theirVpnIpNet[0]
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
+
+	t.Log("reply?")
+	theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
+	p = r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
+
+	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
+	//t.Log("finish up")
+	//myControl.Stop()
+	//theirControl.Stop()
+	//relayControl.Stop()
+}

+ 22 - 7
firewall.go

@@ -417,8 +417,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 	return nil
 }
 
-var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
-var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
+var ErrUnknownNetworkType = errors.New("unknown network type")
+var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
+var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
+var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
 var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 		return nil
 	}
 
-	// Make sure remote address matches nebula certificate
-	if h.networks != nil {
-		if !h.networks.Contains(fp.RemoteAddr) {
+	// Make sure remote address matches nebula certificate, and determine how to treat it
+	if h.networks == nil {
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one address and no unsafe networks
-		if h.vpnAddrs[0] != fp.RemoteAddr {
+		nwType, ok := h.networks.Lookup(fp.RemoteAddr)
+		if !ok {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
+		switch nwType {
+		case NetworkTypeVPN:
+			break // nothing special
+		case NetworkTypeVPNPeer:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrPeerRejected // reject for now, one day this may have different FW rules
+		case NetworkTypeUnsafe:
+			break // nothing special, one day this may have different FW rules
+		default:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrUnknownNetworkType //should never happen
+		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address

+ 192 - 11
firewall_test.go

@@ -8,6 +8,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/gaissmai/bart"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
@@ -149,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
-
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
 		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -174,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.buildNetworks(c.networks, c.unsafeNetworks)
+	h.buildNetworks(myVpnNetworksTable, &c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -226,6 +229,9 @@ func TestFirewall_DropV6(t *testing.T) {
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
+
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("fd12::34"),
 		RemoteAddr: netip.MustParseAddr("fd12::34"),
@@ -250,7 +256,7 @@ func TestFirewall_DropV6(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
 	}
-	h.buildNetworks(c.networks, c.unsafeNetworks)
+	h.buildNetworks(myVpnNetworksTable, &c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -453,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -478,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
 
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -493,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) {
 			peerCert: &c1,
 		},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -510,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -541,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
 
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -556,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
+	h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
 
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -571,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
+	h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -597,6 +607,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("fd12::34"),
@@ -620,7 +632,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
 
 	// Test a remote address match
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@@ -633,6 +645,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -659,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -696,6 +710,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
 
 	c := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -717,7 +733,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 
@@ -1047,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) {
 	assert.Equal(t, "group1", r.Group)
 }
 
+type testcase struct {
+	h   *HostInfo
+	p   firewall.Packet
+	c   cert.Certificate
+	err error
+}
+
+func (c *testcase) Test(t *testing.T, fw *Firewall) {
+	t.Helper()
+	cp := cert.NewCAPool()
+	resetConntrack(fw)
+	err := fw.Drop(c.p, true, c.h, cp, nil)
+	if c.err == nil {
+		require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
+	} else {
+		require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
+	}
+}
+
+func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
+	c1 := dummyCert{
+		name:     "host1",
+		networks: theirPrefixes,
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c1,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
+		},
+		vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
+	}
+	for i := range theirPrefixes {
+		h.vpnAddrs[i] = theirPrefixes[i].Addr()
+	}
+	h.buildNetworks(setup.myVpnNetworksTable, &c1)
+	p := firewall.Packet{
+		LocalAddr:  setup.c.Networks()[0].Addr(), //todo?
+		RemoteAddr: theirPrefixes[0].Addr(),
+		LocalPort:  10,
+		RemotePort: 90,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+	return testcase{
+		h:   &h,
+		p:   p,
+		c:   &c1,
+		err: err,
+	}
+}
+
+type testsetup struct {
+	c                  dummyCert
+	myVpnNetworksTable *bart.Lite
+	fw                 *Firewall
+}
+
+func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
+	c := dummyCert{
+		name:     "me",
+		networks: myPrefixes,
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+
+	return newSetupFromCert(t, l, c)
+}
+
+func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
+	myVpnNetworksTable := new(bart.Lite)
+	for _, prefix := range c.Networks() {
+		myVpnNetworksTable.Insert(prefix)
+	}
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+
+	return testsetup{
+		c:                  c,
+		fw:                 fw,
+		myVpnNetworksTable: myVpnNetworksTable,
+	}
+}
+
+func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
+	t.Parallel()
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	myPrefix := netip.MustParsePrefix("1.1.1.1/8")
+	// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
+	t.Run("allow inbound all matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.Test(t, setup.fw)
+	})
+	t.Run("allow inbound local matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("block inbound remote mismatched", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("Block a vpn peer packet", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
+		tc.Test(t, setup.fw)
+	})
+	twoPrefixes := []netip.Prefix{
+		netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
+	}
+	t.Run("allow inbound one matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, nil, twoPrefixes...)
+		tc.Test(t, setup.fw)
+	})
+	t.Run("block inbound multimismatch", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
+		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("allow inbound 2nd one matching", func(t *testing.T) {
+		t.Parallel()
+		setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
+		tc := buildTestCase(setup2, nil, twoPrefixes...)
+		tc.p.RemoteAddr = twoPrefixes[1].Addr()
+		tc.Test(t, setup2.fw)
+	})
+	t.Run("allow inbound unsafe route", func(t *testing.T) {
+		t.Parallel()
+		unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
+		c := dummyCert{
+			name:           "me",
+			networks:       []netip.Prefix{myPrefix},
+			unsafeNetworks: []netip.Prefix{unsafePrefix},
+			groups:         []string{"default-group"},
+			issuer:         "signer-shasum",
+		}
+		unsafeSetup := newSetupFromCert(t, l, c)
+		tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
+		tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
+		tc.err = ErrNoMatchingRule
+		tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
+		require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
+		tc.err = nil
+		tc.Test(t, unsafeSetup.fw) //should pass
+	})
+}
+
 type addRuleCall struct {
 	incoming  bool
 	proto     uint8

+ 52 - 66
handshake_ix.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"net/netip"
-	"slices"
 	"time"
 
 	"github.com/flynn/noise"
@@ -192,17 +191,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
-	var vpnAddrs []netip.Addr
-	var filteredNetworks []netip.Prefix
 	certName := remoteCert.Certificate.Name()
 	certVersion := remoteCert.Certificate.Version()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
+	vpnNetworks := remoteCert.Certificate.Networks()
 
-	for _, network := range remoteCert.Certificate.Networks() {
-		vpnAddr := network.Addr()
-		if f.myVpnAddrsTable.Contains(vpnAddr) {
-			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+	anyVpnAddrsInCommon := false
+	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+	for i, network := range vpnNetworks {
+		if f.myVpnAddrsTable.Contains(network.Addr()) {
+			f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -210,24 +209,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
 			return
 		}
-
-		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		if !f.myVpnNetworksTable.Contains(vpnAddr) {
-			continue
+		vpnAddrs[i] = network.Addr()
+		if f.myVpnNetworksTable.Contains(network.Addr()) {
+			anyVpnAddrsInCommon = true
 		}
-
-		filteredNetworks = append(filteredNetworks, network)
-		vpnAddrs = append(vpnAddrs, vpnAddr)
-	}
-
-	if len(vpnAddrs) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
-		return
 	}
 
 	if addr.IsValid() {
@@ -264,26 +249,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		},
 	}
 
-	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-		WithField("certName", certName).
-		WithField("certVersion", certVersion).
-		WithField("fingerprint", fingerprint).
-		WithField("issuer", issuer).
-		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-		Info("Handshake message received")
+	msgRxL := f.l.WithFields(m{
+		"vpnAddrs":       vpnAddrs,
+		"udpAddr":        addr,
+		"certName":       certName,
+		"certVersion":    certVersion,
+		"fingerprint":    fingerprint,
+		"issuer":         issuer,
+		"initiatorIndex": hs.Details.InitiatorIndex,
+		"responderIndex": hs.Details.ResponderIndex,
+		"remoteIndex":    h.RemoteIndex,
+		"handshake":      m{"stage": 1, "style": "ix_psk0"},
+	})
+
+	if anyVpnAddrsInCommon {
+		msgRxL.Info("Handshake message received")
+	} else {
+		//todo warn if not lighthouse or relay?
+		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
+	}
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
 	if hs.Details.Cert == nil {
-		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			WithField("certVersion", ci.myCert.Version()).
+		msgRxL.WithField("myCertVersion", ci.myCert.Version()).
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
 		return
 	}
@@ -341,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -582,31 +571,22 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 	}
 
-	var vpnAddrs []netip.Addr
-	var filteredNetworks []netip.Prefix
-	for _, network := range vpnNetworks {
-		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		vpnAddr := network.Addr()
-		if !f.myVpnNetworksTable.Contains(vpnAddr) {
-			continue
+	correctHostResponded := false
+	anyVpnAddrsInCommon := false
+	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+	for i, network := range vpnNetworks {
+		vpnAddrs[i] = network.Addr()
+		if f.myVpnNetworksTable.Contains(network.Addr()) {
+			anyVpnAddrsInCommon = true
+		}
+		if hostinfo.vpnAddrs[0] == network.Addr() {
+			// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
+			correctHostResponded = true
 		}
-
-		filteredNetworks = append(filteredNetworks, network)
-		vpnAddrs = append(vpnAddrs, vpnAddr)
-	}
-
-	if len(vpnAddrs) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
-		return true
 	}
 
 	// Ensure the right host responded
-	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+	if !correctHostResponded {
 		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
 			WithField("udpAddr", addr).
 			WithField("certName", certName).
@@ -618,6 +598,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
+		//TODO is hostinfo.vpnAddrs[0] always the address to use?
 		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
@@ -644,7 +625,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+	msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
@@ -652,12 +633,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 		WithField("durationNs", duration).
-		WithField("sentCachedPackets", len(hh.packetStore)).
-		Info("Handshake message received")
+		WithField("sentCachedPackets", len(hh.packetStore))
+	if anyVpnAddrsInCommon {
+		msgRxL.Info("Handshake message received")
+	} else {
+		//todo warn if not lighthouse or relay?
+		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
+	}
 
 	// Build up the radix for the firewall if we have subnets in the cert
 	hostinfo.vpnAddrs = vpnAddrs
-	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)

+ 2 - 2
handshake_manager.go

@@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself
+			// Don't relay through the host I'm trying to connect to
 			if relay == vpnIp {
 				continue
 			}
 
-			// Don't relay through the host I'm trying to connect to
+			// Don't relay to myself
 			if hm.f.myVpnAddrsTable.Contains(relay) {
 				continue
 			}

+ 29 - 11
hostmap.go

@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.relayForByIdx[idx] = r
 }
 
+type NetworkType uint8
+
+const (
+	NetworkTypeUnknown NetworkType = iota
+	// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
+	NetworkTypeVPN
+	// NetworkTypeVPNPeer is a network that does not overlap one of our networks
+	NetworkTypeVPNPeer
+	// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
+	NetworkTypeUnsafe
+)
+
 type HostInfo struct {
 	remote          netip.AddrPort
 	remotes         *RemoteList
@@ -225,8 +237,8 @@ type HostInfo struct {
 	// vpn networks but were removed because they are not usable
 	vpnAddrs []netip.Addr
 
-	// networks are both all vpn and unsafe networks assigned to this host
-	networks   *bart.Lite
+	// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
+	networks   *bart.Table[NetworkType]
 	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -730,20 +742,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 	return false
 }
 
-func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
-	if len(networks) == 1 && len(unsafeNetworks) == 0 {
-		// Simple case, no CIDRTree needed
-		return
+// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
+func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
+	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
+		if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
+			return // Simple case, no BART needed
+		}
 	}
 
-	i.networks = new(bart.Lite)
-	for _, network := range networks {
+	i.networks = new(bart.Table[NetworkType])
+	for _, network := range c.Networks() {
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		i.networks.Insert(nprefix)
+		if myVpnNetworksTable.Contains(network.Addr()) {
+			i.networks.Insert(nprefix, NetworkTypeVPN)
+		} else {
+			i.networks.Insert(nprefix, NetworkTypeVPNPeer)
+		}
 	}
 
-	for _, network := range unsafeNetworks {
-		i.networks.Insert(network)
+	for _, network := range c.UnsafeNetworks() {
+		i.networks.Insert(network, NetworkTypeUnsafe)
 	}
 }
 

+ 6 - 5
inside.go

@@ -120,9 +120,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
-// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
+// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
+// it does not check if it is within our vpn networks!
 func (f *Interface) Handshake(vpnAddr netip.Addr) {
-	f.getOrHandshakeNoRouting(vpnAddr, nil)
+	f.handshakeManager.GetOrHandshake(vpnAddr, nil)
 }
 
 // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -138,7 +139,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
 // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
 func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-
 	destinationAddr := fwPacket.RemoteAddr
 
 	hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -231,9 +231,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
-// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
+// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
 func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
+	hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 

+ 4 - 2
lighthouse.go

@@ -360,7 +360,8 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
 		}
 
 		if !lh.myVpnNetworksTable.Contains(addr) {
-			return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
+			lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
+				Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
 		}
 		out[i] = addr
 	}
@@ -431,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 		}
 
 		if !lh.myVpnNetworksTable.Contains(vpnAddr) {
-			return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
+			lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
+				Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
 		}
 
 		vals, ok := v.([]any)