Sfoglia il codice sorgente

tsnet: add test for packet filter generation from netmap

This is an integration test that covers all the code in Direct, Auto, and
LocalBackend that processes NetMaps and creates a Filter. The test uses
tsnet as a convenient proxy for setting up all the client pieces correctly,
but is not actually a test specific to tsnet.

Updates tailscale/corp#20514

Signed-off-by: James Sanderson <[email protected]>
James Sanderson 11 mesi fa
parent
commit
85a7abef0c
3 ha cambiato i file con 254 aggiunte e 2 eliminazioni
  1. 5 1
      ipn/ipnlocal/local.go
  2. 1 1
      ipn/ipnlocal/local_test.go
  3. 248 0
      tsnet/packet_filter_test.go

+ 5 - 1
ipn/ipnlocal/local.go

@@ -1510,7 +1510,11 @@ func (nb *nodeBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap {
 	return nil
 }
 
-func (nb *nodeBackend) GetFilterForTest() *filter.Filter {
+func (b *LocalBackend) GetFilterForTest() *filter.Filter {
+	if !testenv.InTest() {
+		panic("GetFilterForTest called outside of test")
+	}
+	nb := b.currentNode()
 	return nb.filterAtomic.Load()
 }
 

+ 1 - 1
ipn/ipnlocal/local_test.go

@@ -5328,7 +5328,7 @@ func TestSrcCapPacketFilter(t *testing.T) {
 		}},
 	})
 
-	f := lb.currentNode().GetFilterForTest()
+	f := lb.GetFilterForTest()
 	res := f.Check(netip.MustParseAddr("2.2.2.2"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP)
 	if res != filter.Accept {
 		t.Errorf("Check(2.2.2.2, ...) = %s, want %s", res, filter.Accept)

+ 248 - 0
tsnet/packet_filter_test.go

@@ -0,0 +1,248 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tsnet
+
+import (
+	"context"
+	"fmt"
+	"net/netip"
+	"testing"
+	"time"
+
+	"tailscale.com/ipn"
+	"tailscale.com/tailcfg"
+	"tailscale.com/types/ipproto"
+	"tailscale.com/types/key"
+	"tailscale.com/types/netmap"
+	"tailscale.com/util/must"
+	"tailscale.com/wgengine/filter"
+)
+
+// waitFor blocks until a NetMap is seen on the IPN bus that satisfies the given
+// function f. Note: has no timeout, should be called with a ctx that has an
+// appropriate timeout set.
+func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.NetworkMap) bool) error {
+	t.Helper()
+	watcher, err := s.localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap)
+	if err != nil {
+		t.Fatalf("error watching IPN bus: %s", err)
+	}
+	defer watcher.Close()
+
+	for {
+		n, err := watcher.Next()
+		if err != nil {
+			return fmt.Errorf("getting next ipn.Notify from IPN bus: %w", err)
+		}
+		if n.NetMap != nil {
+			if f(n.NetMap) {
+				return nil
+			}
+		}
+	}
+}
+
+// TestPacketFilterFromNetmap tests all of the client code for processing
+// netmaps and turning them into packet filters together. Only the control-plane
+// side is mocked out.
+func TestPacketFilterFromNetmap(t *testing.T) {
+	t.Parallel()
+
+	var key key.NodePublic
+	must.Do(key.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261")))
+
+	type check struct {
+		src  string
+		dst  string
+		port uint16
+		want filter.Response
+	}
+
+	tests := []struct {
+		name        string
+		mapResponse *tailcfg.MapResponse
+		waitTest    func(*netmap.NetworkMap) bool
+
+		incrementalMapResponse *tailcfg.MapResponse          // optional
+		incrementalWaitTest    func(*netmap.NetworkMap) bool // optional
+
+		checks []check
+	}{
+		{
+			name: "IP_based_peers",
+			mapResponse: &tailcfg.MapResponse{
+				Node: &tailcfg.Node{
+					Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
+				},
+				Peers: []*tailcfg.Node{{
+					ID:        2,
+					Name:      "foo",
+					Key:       key,
+					Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
+					CapMap:    nil,
+				}},
+				PacketFilter: []tailcfg.FilterRule{{
+					SrcIPs: []string{"2.2.2.2/32"},
+					DstPorts: []tailcfg.NetPortRange{{
+						IP: "1.1.1.1/32",
+						Ports: tailcfg.PortRange{
+							First: 22,
+							Last:  22,
+						},
+					}},
+					IPProto: []int{int(ipproto.TCP)},
+				}},
+			},
+			waitTest: func(nm *netmap.NetworkMap) bool {
+				return len(nm.Peers) > 0
+			},
+			checks: []check{
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
+				{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
+				{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
+			},
+		},
+		{
+			name: "capmap_based_peers",
+			mapResponse: &tailcfg.MapResponse{
+				Node: &tailcfg.Node{
+					Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
+				},
+				Peers: []*tailcfg.Node{{
+					ID:        2,
+					Name:      "foo",
+					Key:       key,
+					Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
+					CapMap:    tailcfg.NodeCapMap{"X": nil},
+				}},
+				PacketFilter: []tailcfg.FilterRule{{
+					SrcIPs: []string{"cap:X"},
+					DstPorts: []tailcfg.NetPortRange{{
+						IP: "1.1.1.1/32",
+						Ports: tailcfg.PortRange{
+							First: 22,
+							Last:  22,
+						},
+					}},
+					IPProto: []int{int(ipproto.TCP)},
+				}},
+			},
+			waitTest: func(nm *netmap.NetworkMap) bool {
+				return len(nm.Peers) > 0
+			},
+			checks: []check{
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
+				{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
+				{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
+			},
+		},
+		{
+			name: "capmap_based_peers_changed",
+			mapResponse: &tailcfg.MapResponse{
+				Node: &tailcfg.Node{
+					Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
+					CapMap:    tailcfg.NodeCapMap{"X-sigil": nil},
+				},
+				PacketFilter: []tailcfg.FilterRule{{
+					SrcIPs: []string{"cap:label-1"},
+					DstPorts: []tailcfg.NetPortRange{{
+						IP: "1.1.1.1/32",
+						Ports: tailcfg.PortRange{
+							First: 22,
+							Last:  22,
+						},
+					}},
+					IPProto: []int{int(ipproto.TCP)},
+				}},
+			},
+			waitTest: func(nm *netmap.NetworkMap) bool {
+				return nm.SelfNode.HasCap("X-sigil")
+			},
+			incrementalMapResponse: &tailcfg.MapResponse{
+				PeersChanged: []*tailcfg.Node{{
+					ID:        2,
+					Name:      "foo",
+					Key:       key,
+					Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
+					CapMap:    tailcfg.NodeCapMap{"label-1": nil},
+				}},
+			},
+			incrementalWaitTest: func(nm *netmap.NetworkMap) bool {
+				return len(nm.Peers) > 0
+			},
+			checks: []check{
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
+				{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
+				{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
+				{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
+			},
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
+			defer cancel()
+
+			controlURL, c := startControl(t)
+			s, _, pubKey := startServer(t, ctx, controlURL, "node")
+
+			if test.waitTest(s.lb.NetMap()) {
+				t.Fatal("waitTest already passes before sending initial netmap: this will be flaky")
+			}
+
+			if !c.AddRawMapResponse(pubKey, test.mapResponse) {
+				t.Fatalf("could not send map response to %s", pubKey)
+			}
+
+			if err := waitFor(t, ctx, s, test.waitTest); err != nil {
+				t.Fatalf("waitFor: %s", err)
+			}
+
+			pf := s.lb.GetFilterForTest()
+
+			for _, check := range test.checks {
+				got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
+
+				want := check.want
+				if test.incrementalMapResponse != nil {
+					want = filter.Drop
+				}
+				if got != want {
+					t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want)
+				}
+			}
+
+			if test.incrementalMapResponse != nil {
+				if test.incrementalWaitTest == nil {
+					t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set")
+				}
+
+				if test.incrementalWaitTest(s.lb.NetMap()) {
+					t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky")
+				}
+
+				if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) {
+					t.Fatalf("could not send map response to %s", pubKey)
+				}
+
+				if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil {
+					t.Fatalf("waitFor: %s", err)
+				}
+
+				pf := s.lb.GetFilterForTest()
+
+				for _, check := range test.checks {
+					got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
+					if got != check.want {
+						t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want)
+					}
+				}
+			}
+
+		})
+	}
+}