Browse Source

tsnet: add tests to TestListenService for user-supplied TUN devices

This resolves a gap in test coverage, ensuring Server.ListenService
functions as expected in combination with user-supplied TUN devices

Fixes tailscale/corp#36603

Co-authored-by: Harry Harpham <[email protected]>
Signed-off-by: Harry Harpham <[email protected]>
James Tucker 1 month ago
parent
commit
569caefeb5
1 changed files with 103 additions and 92 deletions
  1. 103 92
      tsnet/tsnet_test.go

+ 103 - 92
tsnet/tsnet_test.go

@@ -1141,83 +1141,91 @@ func TestListenService(t *testing.T) {
 		// This ends up also testing the Service forwarding logic in
 		// LocalBackend, but that's useful too.
 		t.Run(tt.name, func(t *testing.T) {
-			ctx := t.Context()
-
-			controlURL, control := startControl(t)
-			serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host")
-			serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client")
-
-			const serviceName = tailcfg.ServiceName("svc:foo")
-			const serviceVIP = "100.11.22.33"
-
-			// == Set up necessary state in our mock ==
-
-			// The Service host must have the 'service-host' capability, which
-			// is a mapping from the Service name to the Service VIP.
-			var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr]
-			mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)}))
-			j := must.Get(json.Marshal(serviceHostCaps))
-			cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap()
-			mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)})
-			control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm)
-
-			// The Service host must be allowed to advertise the Service VIP.
-			control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{
-				netip.MustParsePrefix(serviceVIP + `/32`),
-			})
-
-			// The Service host must be a tagged node (any tag will do).
-			serviceHostNode := control.Node(serviceHost.lb.NodeKey())
-			serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
-			control.UpdateNode(serviceHostNode)
-
-			// The service client must accept routes advertised by other nodes
-			// (RouteAll is equivalent to --accept-routes).
-			must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{
-				RouteAllSet: true,
-				Prefs: ipn.Prefs{
-					RouteAll: true,
-				},
-			}))
-
-			// Set up DNS for our Service.
-			control.AddDNSRecords(tailcfg.DNSRecord{
-				Name:  serviceName.WithoutPrefix() + "." + control.MagicDNSDomain,
-				Value: serviceVIP,
-			})
+			// We run each test with and without a TUN device ([Server.Tun]).
+			// Note that this TUN device is distinct from TUN mode for Services.
+			doTest := func(t *testing.T, withTUNDevice bool) {
+				ctx := t.Context()
+
+				lt := setupTwoClientTest(t, withTUNDevice)
+				serviceHost := lt.s2
+				serviceClient := lt.s1
+				control := lt.control
+
+				const serviceName = tailcfg.ServiceName("svc:foo")
+				const serviceVIP = "100.11.22.33"
+
+				// == Set up necessary state in our mock ==
+
+				// The Service host must have the 'service-host' capability, which
+				// is a mapping from the Service name to the Service VIP.
+				var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr]
+				mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)}))
+				j := must.Get(json.Marshal(serviceHostCaps))
+				cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap()
+				mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)})
+				control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm)
+
+				// The Service host must be allowed to advertise the Service VIP.
+				control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{
+					netip.MustParsePrefix(serviceVIP + `/32`),
+				})
 
-			if tt.extraSetup != nil {
-				tt.extraSetup(t, control)
-			}
+				// The Service host must be a tagged node (any tag will do).
+				serviceHostNode := control.Node(serviceHost.lb.NodeKey())
+				serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
+				control.UpdateNode(serviceHostNode)
+
+				// The service client must accept routes advertised by other nodes
+				// (RouteAll is equivalent to --accept-routes).
+				must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{
+					RouteAllSet: true,
+					Prefs: ipn.Prefs{
+						RouteAll: true,
+					},
+				}))
 
-			// Force netmap updates to avoid race conditions. The nodes need to
-			// see our control updates before we can start the test.
-			must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey()))
-			must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey()))
-			netmapUpToDate := func(s *Server) bool {
-				nm := s.lb.NetMap()
-				return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool {
-					return r.Value == serviceVIP
+				// Set up DNS for our Service.
+				control.AddDNSRecords(tailcfg.DNSRecord{
+					Name:  serviceName.WithoutPrefix() + "." + control.MagicDNSDomain,
+					Value: serviceVIP,
 				})
-			}
-			for !netmapUpToDate(serviceClient) {
-				time.Sleep(10 * time.Millisecond)
-			}
-			for !netmapUpToDate(serviceHost) {
-				time.Sleep(10 * time.Millisecond)
-			}
 
-			// == Done setting up mock state ==
+				if tt.extraSetup != nil {
+					tt.extraSetup(t, control)
+				}
+
+				// Force netmap updates to avoid race conditions. The nodes need to
+				// see our control updates before we can start the test.
+				must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey()))
+				must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey()))
+				netmapUpToDate := func(s *Server) bool {
+					nm := s.lb.NetMap()
+					return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool {
+						return r.Value == serviceVIP
+					})
+				}
+				for !netmapUpToDate(serviceClient) {
+					time.Sleep(10 * time.Millisecond)
+				}
+				for !netmapUpToDate(serviceHost) {
+					time.Sleep(10 * time.Millisecond)
+				}
 
-			// Start the Service listeners.
-			listeners := make([]*ServiceListener, 0, len(tt.modes))
-			for _, input := range tt.modes {
-				ln := must.Get(serviceHost.ListenService(serviceName.String(), input))
-				defer ln.Close()
-				listeners = append(listeners, ln)
+				// == Done setting up mock state ==
+
+				// Start the Service listeners.
+				listeners := make([]*ServiceListener, 0, len(tt.modes))
+				for _, input := range tt.modes {
+					ln := must.Get(serviceHost.ListenService(serviceName.String(), input))
+					defer ln.Close()
+					listeners = append(listeners, ln)
+				}
+
+				tt.run(t, listeners, serviceClient)
 			}
 
-			tt.run(t, listeners, serviceClient)
+			t.Run("TUN", func(t *testing.T) { doTest(t, true) })
+			t.Run("netstack", func(t *testing.T) { doTest(t, false) })
 		})
 	}
 }
@@ -1928,20 +1936,21 @@ func (t *chanTUN) BatchSize() int           { return 1 }
 
 // listenTest provides common setup for listener and TUN tests.
 type listenTest struct {
+	control      *testcontrol.Server
 	s1, s2       *Server
 	s1ip4, s1ip6 netip.Addr
 	s2ip4, s2ip6 netip.Addr
 	tun          *chanTUN // nil for netstack mode
 }
 
-// setupListenTest creates two tsnet servers for testing.
+// setupTwoClientTest creates two tsnet servers for testing.
 // If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only.
-func setupListenTest(t *testing.T, useTUN bool) *listenTest {
+func setupTwoClientTest(t *testing.T, useTUN bool) *listenTest {
 	t.Helper()
 	tstest.Shard(t)
 	tstest.ResourceCheck(t)
 	ctx := t.Context()
-	controlURL, _ := startControl(t)
+	controlURL, control := startControl(t)
 	s1, _, _ := startServer(t, ctx, controlURL, "s1")
 
 	tmp := filepath.Join(t.TempDir(), "s2")
@@ -1969,6 +1978,7 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest {
 	if err != nil {
 		t.Fatal(err)
 	}
+	s2.lb.ConfigureCertsForTest(testCertRoot.getCert)
 
 	s1ip4, s1ip6 := s1.TailscaleIPs()
 	s2ip4 := s2status.TailscaleIPs[0]
@@ -1981,13 +1991,14 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest {
 	must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP))
 
 	return &listenTest{
-		s1:    s1,
-		s2:    s2,
-		s1ip4: s1ip4,
-		s1ip6: s1ip6,
-		s2ip4: s2ip4,
-		s2ip6: s2ip6,
-		tun:   tun,
+		control: control,
+		s1:      s1,
+		s2:      s2,
+		s1ip4:   s1ip4,
+		s1ip6:   s1ip6,
+		s2ip4:   s2ip4,
+		s2ip6:   s2ip6,
+		tun:     tun,
 	}
 }
 
@@ -2016,7 +2027,7 @@ func echoUDP(pkt []byte) []byte {
 }
 
 func TestTUN(t *testing.T) {
-	tt := setupListenTest(t, true)
+	tt := setupTwoClientTest(t, true)
 
 	go func() {
 		for pkt := range tt.tun.Inbound {
@@ -2059,7 +2070,7 @@ func TestTUN(t *testing.T) {
 // responses. This verifies that handleLocalPackets intercepts outbound traffic
 // to the service IP.
 func TestTUNDNS(t *testing.T) {
-	tt := setupListenTest(t, true)
+	tt := setupTwoClientTest(t, true)
 
 	test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) {
 		tt.tun.Outbound <- buildDNSQuery("s2", srcIP)
@@ -2149,13 +2160,13 @@ func TestListenPacket(t *testing.T) {
 	}
 
 	t.Run("Netstack", func(t *testing.T) {
-		lt := setupListenTest(t, false)
+		lt := setupTwoClientTest(t, false)
 		t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
 		t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
 	})
 
 	t.Run("TUN", func(t *testing.T) {
-		lt := setupListenTest(t, true)
+		lt := setupTwoClientTest(t, true)
 		t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
 		t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
 	})
@@ -2221,13 +2232,13 @@ func TestListenTCP(t *testing.T) {
 	}
 
 	t.Run("Netstack", func(t *testing.T) {
-		lt := setupListenTest(t, false)
+		lt := setupTwoClientTest(t, false)
 		t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
 		t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
 	})
 
 	t.Run("TUN", func(t *testing.T) {
-		lt := setupListenTest(t, true)
+		lt := setupTwoClientTest(t, true)
 		t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
 		t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
 	})
@@ -2299,13 +2310,13 @@ func TestListenTCPDualStack(t *testing.T) {
 	}
 
 	t.Run("Netstack", func(t *testing.T) {
-		lt := setupListenTest(t, false)
+		lt := setupTwoClientTest(t, false)
 		t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
 		t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
 	})
 
 	t.Run("TUN", func(t *testing.T) {
-		lt := setupListenTest(t, true)
+		lt := setupTwoClientTest(t, true)
 		t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
 		t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
 	})
@@ -2372,13 +2383,13 @@ func TestDialTCP(t *testing.T) {
 	}
 
 	t.Run("Netstack", func(t *testing.T) {
-		lt := setupListenTest(t, false)
+		lt := setupTwoClientTest(t, false)
 		t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
 		t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
 	})
 
 	t.Run("TUN", func(t *testing.T) {
-		lt := setupListenTest(t, true)
+		lt := setupTwoClientTest(t, true)
 
 		var escapedTCPPackets atomic.Int32
 		var wg sync.WaitGroup
@@ -2460,13 +2471,13 @@ func TestDialUDP(t *testing.T) {
 	}
 
 	t.Run("Netstack", func(t *testing.T) {
-		lt := setupListenTest(t, false)
+		lt := setupTwoClientTest(t, false)
 		t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
 		t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
 	})
 
 	t.Run("TUN", func(t *testing.T) {
-		lt := setupListenTest(t, true)
+		lt := setupTwoClientTest(t, true)
 
 		var escapedUDPPackets atomic.Int32
 		var wg sync.WaitGroup