Selaa lähdekoodia

Fix incorrect CIDR construction in hostmap (#1493)

* Fix incorrect CIDR construction in hostmap

* Introduce a regression test for incorrect hostmap CIDR
Gary Guo 1 viikko sitten
vanhempi
sitoutus
634181ba66
2 muutettua tiedostoa jossa 46 lisäystä ja 1 poistoa
  1. 44 0
      firewall_test.go
  2. 2 1
      hostmap.go

+ 44 - 0
firewall_test.go

@@ -692,6 +692,50 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
 }
 
+func TestFirewall_DropIPSpoofing(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
+		},
+	}
+
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:           "host",
+			networks:       []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
+			unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
+		},
+	}
+	h1 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c1,
+		},
+		vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
+	}
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	cp := cert.NewCAPool()
+
+	// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("192.0.2.1"),
+		RemoteAddr: netip.MustParseAddr("192.0.2.3"),
+		LocalPort:  1,
+		RemotePort: 1,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+	assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {

+ 2 - 1
hostmap.go

@@ -738,7 +738,8 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
 
 	i.networks = new(bart.Lite)
 	for _, network := range networks {
-		i.networks.Insert(network)
+		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+		i.networks.Insert(nprefix)
 	}
 
 	for _, network := range unsafeNetworks {