Browse Source

tsnet: add support for a user-supplied tun.Device

tsnet users can now provide a tun.Device, including any custom
implementation that conforms to the interface.

netstack has a new option, CheckLocalTransportEndpoints, that, when used
alongside a TUN, lets netstack listeners and dials correctly capture the
traffic associated with their sockets. tsnet with a TUN sets this
option, while all other builds leave it false to preserve existing
performance.

Updates #18423

Signed-off-by: James Tucker <[email protected]>
James Tucker 1 month ago
parent
commit
63d563e734
3 changed files with 842 additions and 5 deletions
  1. 84 4
      tsnet/tsnet.go
  2. 673 0
      tsnet/tsnet_test.go
  3. 85 1
      wgengine/netstack/netstack.go

+ 84 - 4
tsnet/tsnet.go

@@ -26,6 +26,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/tailscale/wireguard-go/tun"
 	"tailscale.com/client/local"
 	"tailscale.com/control/controlclient"
 	"tailscale.com/envknob"
@@ -167,6 +168,11 @@ type Server struct {
 	// that the control server will allow the node to adopt that tag.
 	AdvertiseTags []string
 
+	// Tun, if non-nil, specifies a custom tun.Device to use for packet I/O.
+	//
+	// This field must be set before calling Start.
+	Tun tun.Device
+
 	initOnce             sync.Once
 	initErr              error
 	lb                   *ipnlocal.LocalBackend
@@ -659,6 +665,7 @@ func (s *Server) start() (reterr error) {
 	s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used)
 	s.dialer.SetBus(sys.Bus.Get())
 	eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{
+		Tun:           s.Tun,
 		EventBus:      sys.Bus.Get(),
 		ListenPort:    s.Port,
 		NetMon:        s.netMon,
@@ -682,8 +689,16 @@ func (s *Server) start() (reterr error) {
 	}
 	sys.Tun.Get().Start()
 	sys.Set(ns)
-	ns.ProcessLocalIPs = true
-	ns.ProcessSubnets = true
+	if s.Tun == nil {
+		// Only process packets in netstack when using the default fake TUN.
+		// When a TUN is provided, let packets flow through it instead.
+		ns.ProcessLocalIPs = true
+		ns.ProcessSubnets = true
+	} else {
+		// When using a TUN, check gVisor for registered endpoints to handle
+		// packets for tsnet listeners and outbound connection replies.
+		ns.CheckLocalTransportEndpoints = true
+	}
 	ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
 	ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow
 	s.netstack = ns
@@ -1072,10 +1087,34 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
 			network = "udp6"
 		}
 	}
-	if err := s.Start(); err != nil {
+
+	netLn, err := s.listen(network, addr, listenOnTailnet)
+	if err != nil {
 		return nil, err
 	}
-	return s.netstack.ListenPacket(network, ap.String())
+	ln := netLn.(*listener)
+
+	pc, err := s.netstack.ListenPacket(network, ap.String())
+	if err != nil {
+		ln.Close()
+		return nil, err
+	}
+
+	return &udpPacketConn{
+		PacketConn: pc,
+		ln:         ln,
+	}, nil
+}
+
+// udpPacketConn wraps a net.PacketConn to unregister from s.listeners on Close.
+//
+// ln is the tsnet listener that ListenPacket registered for the same network
+// and address before opening the packet conn; the two are closed together.
+type udpPacketConn struct {
+	net.PacketConn
+	ln *listener
+}
+
+// Close closes the associated listener registration and then the underlying
+// PacketConn. Only the PacketConn's close error is returned; the result of
+// ln.Close is discarded.
+func (c *udpPacketConn) Close() error {
+	c.ln.Close()
+	return c.PacketConn.Close()
 }
 
 // ListenTLS announces only on the Tailscale network.
@@ -1611,10 +1650,37 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro
 		closedc: make(chan struct{}),
 		conn:    make(chan net.Conn),
 	}
+
+	// When using a TUN with TCP, create a gVisor TCP listener.
+	if s.Tun != nil && (network == "" || network == "tcp" || network == "tcp4" || network == "tcp6") {
+		var nsNetwork string
+		nsAddr := host
+		switch {
+		case network == "tcp4" || network == "tcp6":
+			nsNetwork = network
+		case host.Addr().Is4():
+			nsNetwork = "tcp4"
+		case host.Addr().Is6():
+			nsNetwork = "tcp6"
+		default:
+			// Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6).
+			nsNetwork = "tcp6"
+			nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port())
+		}
+		gonetLn, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String())
+		if err != nil {
+			return nil, fmt.Errorf("tsnet: %w", err)
+		}
+		ln.gonetLn = gonetLn
+	}
+
 	s.mu.Lock()
 	for _, key := range keys {
 		if _, ok := s.listeners[key]; ok {
 			s.mu.Unlock()
+			if ln.gonetLn != nil {
+				ln.gonetLn.Close()
+			}
 			return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr)
 		}
 	}
@@ -1684,9 +1750,17 @@ type listener struct {
 	conn    chan net.Conn // unbuffered, never closed
 	closedc chan struct{} // closed on [listener.Close]
 	closed  bool          // guarded by s.mu
+
+	// gonetLn, if non-nil, is the netstack (gonet) listener that accepts
+	// connections on behalf of this listener. It is set by [Server.listen]
+	// only for TCP networks when a TUN is in use, and it is closed when the
+	// listener is closed. It is nil otherwise.
+	gonetLn net.Listener
 }
 
 func (ln *listener) Accept() (net.Conn, error) {
+	if ln.gonetLn != nil {
+		return ln.gonetLn.Accept()
+	}
 	select {
 	case c := <-ln.conn:
 		return c, nil
@@ -1696,6 +1770,9 @@ func (ln *listener) Accept() (net.Conn, error) {
 }
 
 func (ln *listener) Addr() net.Addr {
+	if ln.gonetLn != nil {
+		return ln.gonetLn.Addr()
+	}
 	return addr{
 		network: ln.keys[0].network,
 		addr:    ln.addr,
@@ -1721,6 +1798,9 @@ func (ln *listener) closeLocked() error {
 	}
 	close(ln.closedc)
 	ln.closed = true
+	if ln.gonetLn != nil {
+		ln.gonetLn.Close()
+	}
 	return nil
 }
 

+ 673 - 0
tsnet/tsnet_test.go

@@ -39,6 +39,7 @@ import (
 	"github.com/google/go-cmp/cmp"
 	dto "github.com/prometheus/client_model/go"
 	"github.com/prometheus/common/expfmt"
+	"github.com/tailscale/wireguard-go/tun"
 	"golang.org/x/net/proxy"
 
 	"tailscale.com/client/local"
@@ -48,11 +49,13 @@ import (
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/ipn/store/mem"
 	"tailscale.com/net/netns"
+	"tailscale.com/net/packet"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tstest"
 	"tailscale.com/tstest/deptest"
 	"tailscale.com/tstest/integration"
 	"tailscale.com/tstest/integration/testcontrol"
+	"tailscale.com/types/ipproto"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/views"
@@ -1860,6 +1863,676 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) {
 	t.Error("magicsock did not find a direct path from lc1 to lc2")
 }
 
+// chanTUN is a tun.Device for testing that uses channels for packet I/O.
+// Inbound receives packets written to the TUN (from the perspective of the network stack).
+// Outbound is for injecting packets to be read from the TUN.
+//
+// Close closes Inbound so that consumers ranging over it terminate. Because
+// sending on a closed channel panics, mu serializes that close against
+// in-flight Write calls, and closeOnce makes Close safe to call repeatedly
+// and concurrently.
+type chanTUN struct {
+	Inbound  chan []byte // packets written to TUN
+	Outbound chan []byte // packets to read from TUN
+	closed   chan struct{}
+	events   chan tun.Event
+
+	closeOnce sync.Once
+	// mu is read-locked by Write while it may send on Inbound and
+	// write-locked by Close while it closes Inbound.
+	mu sync.RWMutex
+}
+
+// newChanTUN returns a chanTUN with buffered packet channels and a single
+// tun.EventUp already queued on its event channel.
+func newChanTUN() *chanTUN {
+	t := &chanTUN{
+		Inbound:  make(chan []byte, 10),
+		Outbound: make(chan []byte, 10),
+		closed:   make(chan struct{}),
+		events:   make(chan tun.Event, 1),
+	}
+	t.events <- tun.EventUp
+	return t
+}
+
+// File is not supported by chanTUN and panics if called.
+func (t *chanTUN) File() *os.File { panic("not implemented") }
+
+// Close marks the device closed and closes Inbound. It is idempotent and
+// always returns nil.
+func (t *chanTUN) Close() error {
+	t.closeOnce.Do(func() {
+		// Signal closure first so that a Write blocked on a full Inbound
+		// wakes up and releases its read lock; only then is it safe to
+		// wait for the write lock.
+		close(t.closed)
+		t.mu.Lock()
+		defer t.mu.Unlock()
+		close(t.Inbound)
+	})
+	return nil
+}
+
+// Read blocks until a packet is available on Outbound or the device is
+// closed. It fills only bufs[0] (BatchSize is 1), copying the packet at the
+// given offset and recording its length in sizes[0]. After Close it returns
+// io.EOF.
+func (t *chanTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+	select {
+	case <-t.closed:
+		return 0, io.EOF
+	case pkt := <-t.Outbound:
+		sizes[0] = copy(bufs[0][offset:], pkt)
+		return 1, nil
+	}
+}
+
+// Write delivers a copy of each non-empty packet in bufs (starting at offset)
+// to Inbound. It returns an error once the device has been closed.
+func (t *chanTUN) Write(bufs [][]byte, offset int) (int, error) {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	// Checked under mu: if closed has not been signaled yet, Close cannot
+	// close Inbound until this call returns, so the sends below are safe.
+	select {
+	case <-t.closed:
+		return 0, errors.New("closed")
+	default:
+	}
+	for _, buf := range bufs {
+		pkt := buf[offset:]
+		if len(pkt) == 0 {
+			continue
+		}
+		select {
+		case <-t.closed:
+			return 0, errors.New("closed")
+		case t.Inbound <- slices.Clone(pkt):
+		}
+	}
+	return len(bufs), nil
+}
+
+func (t *chanTUN) MTU() (int, error)        { return 1280, nil }
+func (t *chanTUN) Name() (string, error)    { return "chantun", nil }
+func (t *chanTUN) Events() <-chan tun.Event { return t.events }
+func (t *chanTUN) BatchSize() int           { return 1 }
+
+// listenTest provides common setup for listener and TUN tests.
+//
+// s1 and s2 are the two tsnet servers created by setupListenTest, and the
+// s1ip*/s2ip* fields hold their Tailscale addresses. s2ip6 is the zero
+// netip.Addr if s2 reported fewer than two Tailscale IPs.
+type listenTest struct {
+	s1, s2       *Server
+	s1ip4, s1ip6 netip.Addr
+	s2ip4, s2ip6 netip.Addr
+	tun          *chanTUN // nil for netstack mode
+}
+
+// setupListenTest creates two tsnet servers for testing.
+// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only.
+//
+// The returned listenTest carries both servers, their Tailscale IPs, and the
+// chanTUN (nil when useTUN is false). s2 is closed via t.Cleanup.
+func setupListenTest(t *testing.T, useTUN bool) *listenTest {
+	t.Helper()
+	tstest.Shard(t)
+	tstest.ResourceCheck(t)
+	ctx := t.Context()
+	controlURL, _ := startControl(t)
+	s1, _, _ := startServer(t, ctx, controlURL, "s1")
+
+	// s2 is constructed by hand so that Tun can be assigned before the
+	// server is brought up by Up below.
+	tmp := filepath.Join(t.TempDir(), "s2")
+	must.Do(os.MkdirAll(tmp, 0755))
+	s2 := &Server{
+		Dir:        tmp,
+		ControlURL: controlURL,
+		Hostname:   "s2",
+		Store:      new(mem.Store),
+		Ephemeral:  true,
+	}
+
+	// NOTE: this local shadows the imported tun package for the rest of the
+	// function.
+	var tun *chanTUN
+	if useTUN {
+		tun = newChanTUN()
+		s2.Tun = tun
+	}
+
+	if *verboseNodes {
+		s2.Logf = t.Logf
+	}
+	t.Cleanup(func() { s2.Close() })
+
+	s2status, err := s2.Up(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	s1ip4, s1ip6 := s1.TailscaleIPs()
+	s2ip4 := s2status.TailscaleIPs[0]
+	// s2ip6 stays the zero Addr if s2 reported only one Tailscale IP.
+	var s2ip6 netip.Addr
+	if len(s2status.TailscaleIPs) > 1 {
+		s2ip6 = s2status.TailscaleIPs[1]
+	}
+
+	// TSMP-ping s2 from s1 before returning; must.Get aborts on failure.
+	// Presumably this ensures the peers can reach each other before a test
+	// body runs — confirm.
+	lc1 := must.Get(s1.LocalClient())
+	must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP))
+
+	return &listenTest{
+		s1:    s1,
+		s2:    s2,
+		s1ip4: s1ip4,
+		s1ip6: s1ip6,
+		s2ip4: s2ip4,
+		s2ip6: s2ip6,
+		tun:   tun,
+	}
+}
+
+// echoUDP returns an IP packet with src/dst and ports swapped, with checksums recomputed.
+//
+// The payload is carried over unchanged. It returns nil if pkt does not
+// decode as UDP, or if its IP version is neither 4 nor 6.
+func echoUDP(pkt []byte) []byte {
+	var p packet.Parsed
+	p.Decode(pkt)
+	if p.IPProto != ipproto.UDP {
+		return nil
+	}
+	switch p.IPVersion {
+	case 4:
+		h := p.UDP4Header()
+		h.ToResponse()
+		return packet.Generate(h, p.Payload())
+	case 6:
+		// The IPv6 UDP header is assembled by hand from the parsed IP
+		// header and ports before being flipped into a response.
+		h := packet.UDP6Header{
+			IP6Header: p.IP6Header(),
+			SrcPort:   p.Src.Port(),
+			DstPort:   p.Dst.Port(),
+		}
+		h.ToResponse()
+		return packet.Generate(h, p.Payload())
+	}
+	return nil
+}
+
+// TestTUN verifies that UDP packets sent from s1 to s2 are delivered to s2's
+// user-supplied TUN, and that packets injected into the TUN make it back to
+// s1. A goroutine acts as the "application" behind the TUN, echoing every UDP
+// packet addressed to port 9999.
+func TestTUN(t *testing.T) {
+	tt := setupListenTest(t, true)
+
+	go func() {
+		for pkt := range tt.tun.Inbound {
+			var p packet.Parsed
+			p.Decode(pkt)
+			if p.Dst.Port() != 9999 {
+				continue
+			}
+			// echoUDP returns nil for anything that is not a UDP packet;
+			// never inject an empty packet into the TUN.
+			if resp := echoUDP(pkt); resp != nil {
+				tt.tun.Outbound <- resp
+			}
+		}
+	}()
+
+	// test dials s2ip:9999 over UDP from s1, sends one datagram, and expects
+	// the identical payload back within five seconds.
+	test := func(t *testing.T, s2ip netip.Addr) {
+		conn, err := tt.s1.Dial(t.Context(), "udp", netip.AddrPortFrom(s2ip, 9999).String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer conn.Close()
+
+		want := "hello from s1"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatal(err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			t.Fatalf("reading echo response: %v", err)
+		}
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("IPv4", func(t *testing.T) { test(t, tt.s2ip4) })
+	t.Run("IPv6", func(t *testing.T) { test(t, tt.s2ip6) })
+}
+
+// TestTUNDNS tests that a TUN can send DNS queries to quad-100 and receive
+// responses. This verifies that handleLocalPackets intercepts outbound traffic
+// to the service IP.
+func TestTUNDNS(t *testing.T) {
+	tt := setupListenTest(t, true)
+
+	// test injects a DNS query from srcIP into the TUN and waits for a UDP
+	// packet from serviceIP:53 of the same IP version to be written back.
+	test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) {
+		tt.tun.Outbound <- buildDNSQuery("s2", srcIP)
+
+		ipVersion := uint8(4)
+		if srcIP.Is6() {
+			ipVersion = 6
+		}
+		// A single deadline for the whole wait: creating the timer inside
+		// the loop would restart the timeout on every unrelated packet.
+		timeout := time.After(5 * time.Second)
+		for {
+			select {
+			case pkt, ok := <-tt.tun.Inbound:
+				if !ok {
+					t.Fatal("TUN closed while waiting for DNS response")
+				}
+				var p packet.Parsed
+				p.Decode(pkt)
+				if p.IPVersion != ipVersion || p.IPProto != ipproto.UDP {
+					continue
+				}
+				if p.Src.Addr() == serviceIP && p.Src.Port() == 53 {
+					// 12 bytes is the size of the fixed DNS header.
+					if len(p.Payload()) < 12 {
+						t.Fatalf("DNS response too short: %d bytes", len(p.Payload()))
+					}
+					return // success
+				}
+			case <-timeout:
+				t.Fatal("timeout waiting for DNS response")
+			}
+		}
+	}
+
+	t.Run("IPv4", func(t *testing.T) {
+		test(t, tt.s2ip4, netip.MustParseAddr("100.100.100.100"))
+	})
+	t.Run("IPv6", func(t *testing.T) {
+		test(t, tt.s2ip6, netip.MustParseAddr("fd7a:115c:a1e0::53"))
+	})
+}
+
+// TestListenPacket tests UDP listeners (ListenPacket) in both netstack and TUN modes.
+func TestListenPacket(t *testing.T) {
+	// testListenPacket opens a UDP listener on s2 at listenIP, echoes one
+	// datagram back to its sender, and checks that s1 receives the echo.
+	testListenPacket := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
+		// The listener is opened with port 0; the address to dial is taken
+		// from pc.LocalAddr below.
+		pc, err := lt.s2.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer pc.Close()
+
+		// echoErr carries a failure from the echo goroutine. It is buffered
+		// so the goroutine's single send never blocks.
+		echoErr := make(chan error, 1)
+		go func() {
+			buf := make([]byte, 1500)
+			n, addr, err := pc.ReadFrom(buf)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			_, err = pc.WriteTo(buf[:n], addr)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+		}()
+
+		conn, err := lt.s1.Dial(t.Context(), "udp", pc.LocalAddr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer conn.Close()
+
+		want := "hello udp"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatal(err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			// Include the echo goroutine's error if it already reported one.
+			select {
+			case e := <-echoErr:
+				t.Fatalf("echo error: %v; read error: %v", e, err)
+			default:
+				t.Fatalf("Read failed: %v", err)
+			}
+		}
+
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("Netstack", func(t *testing.T) {
+		lt := setupListenTest(t, false)
+		t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
+		t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
+	})
+
+	t.Run("TUN", func(t *testing.T) {
+		lt := setupListenTest(t, true)
+		t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
+		t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
+	})
+}
+
+// TestListenTCP tests TCP listeners with concrete addresses in both netstack
+// and TUN modes.
+func TestListenTCP(t *testing.T) {
+	// testListenTCP opens a TCP listener on s2 at listenIP, accepts one
+	// connection and echoes a single read back, then checks that s1
+	// receives the echo.
+	testListenTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
+		// The listener is opened with port 0; the address to dial is taken
+		// from ln.Addr below.
+		ln, err := lt.s2.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ln.Close()
+
+		// echoErr carries a failure from the echo goroutine. It is buffered
+		// so the goroutine's single send never blocks.
+		echoErr := make(chan error, 1)
+		go func() {
+			conn, err := ln.Accept()
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			defer conn.Close()
+			buf := make([]byte, 1024)
+			n, err := conn.Read(buf)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			_, err = conn.Write(buf[:n])
+			if err != nil {
+				echoErr <- err
+				return
+			}
+		}()
+
+		conn, err := lt.s1.Dial(t.Context(), "tcp", ln.Addr().String())
+		if err != nil {
+			t.Fatalf("Dial failed: %v", err)
+		}
+		defer conn.Close()
+
+		want := "hello tcp"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatalf("Write failed: %v", err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			// Include the echo goroutine's error if it already reported one.
+			select {
+			case e := <-echoErr:
+				t.Fatalf("echo error: %v; read error: %v", e, err)
+			default:
+				t.Fatalf("Read failed: %v", err)
+			}
+		}
+
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("Netstack", func(t *testing.T) {
+		lt := setupListenTest(t, false)
+		t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
+		t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
+	})
+
+	t.Run("TUN", func(t *testing.T) {
+		lt := setupListenTest(t, true)
+		t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
+		t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
+	})
+}
+
+// TestListenTCPDualStack tests TCP listeners with wildcard addresses (dual-stack)
+// in both netstack and TUN modes.
+func TestListenTCPDualStack(t *testing.T) {
+	// testListenTCPDualStack listens on s2 with no host (":0"), then dials
+	// the listener's port at dialIP from s1 and expects a one-shot echo.
+	testListenTCPDualStack := func(t *testing.T, lt *listenTest, dialIP netip.Addr) {
+		ln, err := lt.s2.Listen("tcp", ":0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ln.Close()
+
+		// Only the port of the listener address is used; the host part is
+		// replaced with dialIP when dialing.
+		_, portStr, err := net.SplitHostPort(ln.Addr().String())
+		if err != nil {
+			t.Fatalf("parsing listener address %q: %v", ln.Addr().String(), err)
+		}
+
+		// echoErr carries a failure from the echo goroutine. It is buffered
+		// so the goroutine's single send never blocks.
+		echoErr := make(chan error, 1)
+		go func() {
+			conn, err := ln.Accept()
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			defer conn.Close()
+			buf := make([]byte, 1024)
+			n, err := conn.Read(buf)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			_, err = conn.Write(buf[:n])
+			if err != nil {
+				echoErr <- err
+				return
+			}
+		}()
+
+		dialAddr := net.JoinHostPort(dialIP.String(), portStr)
+		conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr)
+		if err != nil {
+			t.Fatalf("Dial(%q) failed: %v", dialAddr, err)
+		}
+		defer conn.Close()
+
+		want := "hello tcp dualstack"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatalf("Write failed: %v", err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			// Include the echo goroutine's error if it already reported one.
+			select {
+			case e := <-echoErr:
+				t.Fatalf("echo error: %v; read error: %v", e, err)
+			default:
+				t.Fatalf("Read failed: %v", err)
+			}
+		}
+
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("Netstack", func(t *testing.T) {
+		lt := setupListenTest(t, false)
+		t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
+		t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
+	})
+
+	t.Run("TUN", func(t *testing.T) {
+		lt := setupListenTest(t, true)
+		t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
+		t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
+	})
+}
+
+// TestDialTCP tests TCP dialing from s2 to s1 in both netstack and TUN modes.
+// In TUN mode, this verifies that outbound TCP connections and their replies
+// are handled by netstack without packets escaping to the TUN.
+func TestDialTCP(t *testing.T) {
+	// testDialTCP listens on s1 at listenIP with a one-shot echo server,
+	// dials it from s2, and checks that the payload comes back.
+	testDialTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
+		ln, err := lt.s1.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ln.Close()
+
+		// echoErr carries a failure from the echo goroutine. It is buffered
+		// so the goroutine's single send never blocks.
+		echoErr := make(chan error, 1)
+		go func() {
+			conn, err := ln.Accept()
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			defer conn.Close()
+			buf := make([]byte, 1024)
+			n, err := conn.Read(buf)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			_, err = conn.Write(buf[:n])
+			if err != nil {
+				echoErr <- err
+				return
+			}
+		}()
+
+		conn, err := lt.s2.Dial(t.Context(), "tcp", ln.Addr().String())
+		if err != nil {
+			t.Fatalf("Dial failed: %v", err)
+		}
+		defer conn.Close()
+
+		want := "hello tcp dial"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatalf("Write failed: %v", err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			// Include the echo goroutine's error if it already reported one.
+			select {
+			case e := <-echoErr:
+				t.Fatalf("echo error: %v; read error: %v", e, err)
+			default:
+				t.Fatalf("Read failed: %v", err)
+			}
+		}
+
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("Netstack", func(t *testing.T) {
+		lt := setupListenTest(t, false)
+		t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
+		t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
+	})
+
+	t.Run("TUN", func(t *testing.T) {
+		lt := setupListenTest(t, true)
+
+		// Count every TCP packet that is written to the TUN while the
+		// subtests run; any such packet bypassed netstack.
+		var escapedTCPPackets atomic.Int32
+		var wg sync.WaitGroup
+		wg.Go(func() {
+			for pkt := range lt.tun.Inbound {
+				var p packet.Parsed
+				p.Decode(pkt)
+				if p.IPProto == ipproto.TCP {
+					escapedTCPPackets.Add(1)
+					t.Logf("TCP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
+				}
+			}
+		})
+
+		t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
+		t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
+
+		// Closing the TUN closes Inbound, which ends the counting goroutine
+		// so that wg.Wait can return before the counter is inspected.
+		lt.tun.Close()
+		wg.Wait()
+		if escaped := escapedTCPPackets.Load(); escaped > 0 {
+			t.Errorf("%d TCP packets escaped to TUN", escaped)
+		}
+	})
+}
+
+// TestDialUDP tests UDP dialing from s2 to s1 in both netstack and TUN modes.
+// In TUN mode, this verifies that outbound UDP connections register endpoints
+// with gVisor, allowing reply packets to be routed through netstack instead of
+// escaping to the TUN.
+func TestDialUDP(t *testing.T) {
+	// testDialUDP opens a UDP listener on s1 at listenIP that echoes one
+	// datagram, dials it from s2, and checks that the payload comes back.
+	testDialUDP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
+		pc, err := lt.s1.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer pc.Close()
+
+		// echoErr carries a failure from the echo goroutine. It is buffered
+		// so the goroutine's single send never blocks.
+		echoErr := make(chan error, 1)
+		go func() {
+			buf := make([]byte, 1500)
+			n, addr, err := pc.ReadFrom(buf)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+			_, err = pc.WriteTo(buf[:n], addr)
+			if err != nil {
+				echoErr <- err
+				return
+			}
+		}()
+
+		conn, err := lt.s2.Dial(t.Context(), "udp", pc.LocalAddr().String())
+		if err != nil {
+			t.Fatalf("Dial failed: %v", err)
+		}
+		defer conn.Close()
+
+		want := "hello udp dial"
+		if _, err := conn.Write([]byte(want)); err != nil {
+			t.Fatalf("Write failed: %v", err)
+		}
+
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
+		got := make([]byte, 1024)
+		n, err := conn.Read(got)
+		if err != nil {
+			// Include the echo goroutine's error if it already reported one.
+			select {
+			case e := <-echoErr:
+				t.Fatalf("echo error: %v; read error: %v", e, err)
+			default:
+				t.Fatalf("Read failed: %v", err)
+			}
+		}
+
+		if string(got[:n]) != want {
+			t.Errorf("got %q, want %q", got[:n], want)
+		}
+	}
+
+	t.Run("Netstack", func(t *testing.T) {
+		lt := setupListenTest(t, false)
+		t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
+		t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
+	})
+
+	t.Run("TUN", func(t *testing.T) {
+		lt := setupListenTest(t, true)
+
+		// Count every UDP packet that is written to the TUN while the
+		// subtests run; any such packet bypassed netstack.
+		var escapedUDPPackets atomic.Int32
+		var wg sync.WaitGroup
+		wg.Go(func() {
+			for pkt := range lt.tun.Inbound {
+				var p packet.Parsed
+				p.Decode(pkt)
+				if p.IPProto == ipproto.UDP {
+					escapedUDPPackets.Add(1)
+					t.Logf("UDP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
+				}
+			}
+		})
+
+		t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
+		t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
+
+		// Closing the TUN closes Inbound, which ends the counting goroutine
+		// so that wg.Wait can return before the counter is inspected.
+		lt.tun.Close()
+		wg.Wait()
+		if escaped := escapedUDPPackets.Load(); escaped > 0 {
+			t.Errorf("%d UDP packets escaped to TUN", escaped)
+		}
+	})
+}
+
+// buildDNSQuery builds a UDP/IP packet containing a DNS query for name to the
+// Tailscale service IP (100.100.100.100 for IPv4, fd7a:115c:a1e0::53 for IPv6).
+//
+// The IP version and the query type both follow srcIP: an IPv6 source yields
+// an AAAA query in an IPv6 packet, anything else an A query. The packet is
+// sent from source port 12345 to port 53.
+func buildDNSQuery(name string, srcIP netip.Addr) []byte {
+	qtype := byte(0x01) // Type A for IPv4
+	if srcIP.Is6() {
+		qtype = 0x1c // Type AAAA for IPv6
+	}
+	// Fixed 12-byte DNS header.
+	dns := []byte{
+		0x12, 0x34, // ID
+		0x01, 0x00, // Flags: standard query, recursion desired
+		0x00, 0x01, // QDCOUNT: 1
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT
+	}
+	// Question name: each dot-separated label is written as a length byte
+	// followed by the label bytes.
+	for _, label := range strings.Split(name, ".") {
+		dns = append(dns, byte(len(label)))
+		dns = append(dns, label...)
+	}
+	// 0x00 terminates the name; the type and class are two bytes each.
+	dns = append(dns, 0x00, 0x00, qtype, 0x00, 0x01) // null, Type A/AAAA, Class IN
+
+	if srcIP.Is4() {
+		h := packet.UDP4Header{
+			IP4Header: packet.IP4Header{
+				Src: srcIP,
+				Dst: netip.MustParseAddr("100.100.100.100"),
+			},
+			SrcPort: 12345,
+			DstPort: 53,
+		}
+		return packet.Generate(h, dns)
+	}
+	h := packet.UDP6Header{
+		IP6Header: packet.IP6Header{
+			Src: srcIP,
+			Dst: netip.MustParseAddr("fd7a:115c:a1e0::53"),
+		},
+		SrcPort: 12345,
+		DstPort: 53,
+	}
+	return packet.Generate(h, dns)
+}
+
 func TestDeps(t *testing.T) {
 	tstest.Shard(t)
 	deptest.DepChecker{

+ 85 - 1
wgengine/netstack/netstack.go

@@ -165,6 +165,17 @@ type Impl struct {
 	// over the UDP flow.
 	GetUDPHandlerForFlow func(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool)
 
+	// CheckLocalTransportEndpoints, if true, causes netstack to check if gVisor
+	// has a registered endpoint for incoming packets to local IPs. This is used
+	// by tsnet to intercept packets for registered listeners and outbound
+	// connections when ProcessLocalIPs is false (i.e., when using a TUN).
+	// It can only be set before calling Start.
+	// TODO(raggi): refactor the way we handle both CheckLocalTransportEndpoints
+	// and the earlier netstack registrations for serve, funnel, peerAPI and so
+	// on. Currently this optimizes away cost for tailscaled in TUN mode, while
+	// enabling extension support when using tsnet in TUN mode. See #18423.
+	CheckLocalTransportEndpoints bool
+
 	// ProcessLocalIPs is whether netstack should handle incoming
 	// traffic directed at the Node.Addresses (local IPs).
 	// It can only be set before calling Start.
@@ -1109,6 +1120,45 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
 	if ns.ProcessSubnets && !isLocal {
 		return true
 	}
+	if isLocal && ns.CheckLocalTransportEndpoints {
+		// Handle packets to registered listeners and replies to outbound
+		// connections by checking if gVisor has a registered endpoint.
+		// This covers TCP and UDP listeners as well as replies to outbound
+		// TCP and UDP connections.
+		if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
+			var netProto tcpip.NetworkProtocolNumber
+			var id stack.TransportEndpointID
+			if p.Dst.Addr().Is4() {
+				netProto = ipv4.ProtocolNumber
+				id = stack.TransportEndpointID{
+					LocalAddress:  tcpip.AddrFrom4(p.Dst.Addr().As4()),
+					LocalPort:     p.Dst.Port(),
+					RemoteAddress: tcpip.AddrFrom4(p.Src.Addr().As4()),
+					RemotePort:    p.Src.Port(),
+				}
+			} else {
+				netProto = ipv6.ProtocolNumber
+				id = stack.TransportEndpointID{
+					LocalAddress:  tcpip.AddrFrom16(p.Dst.Addr().As16()),
+					LocalPort:     p.Dst.Port(),
+					RemoteAddress: tcpip.AddrFrom16(p.Src.Addr().As16()),
+					RemotePort:    p.Src.Port(),
+				}
+			}
+			var transProto tcpip.TransportProtocolNumber
+			if p.IPProto == ipproto.TCP {
+				transProto = tcp.ProtocolNumber
+			} else {
+				transProto = udp.ProtocolNumber
+			}
+			ep := ns.ipstack.FindTransportEndpoint(netProto, transProto, id, nicID)
+			if debugNetstack() {
+				ns.logf("[v2] FindTransportEndpoint: id=%+v found=%v", id, ep != nil)
+			}
+			if ep != nil {
+				return true
+			}
+		}
+	}
 	return false
 }
 
@@ -1575,7 +1625,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
 func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
 	ap, err := netip.ParseAddrPort(address)
 	if err != nil {
-		return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err)
+		return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
 	}
 
 	var networkProto tcpip.NetworkProtocolNumber
@@ -1612,6 +1662,40 @@ func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
 	return gonet.NewUDPConn(&wq, ep), nil
 }
 
+// ListenTCP listens for TCP connections on the given address.
+//
+// network must be "tcp4" or "tcp6"; any other value is an error. address is
+// an IP:port string parsed with netip.ParseAddrPort, and its IP must match
+// the address family of network. If the IP is unspecified, the listener's
+// tcpip.FullAddress is left without an address, which presumably binds to all
+// addresses on the NIC (TODO confirm against gonet.ListenTCP).
+func (ns *Impl) ListenTCP(network, address string) (*gonet.TCPListener, error) {
+	ap, err := netip.ParseAddrPort(address)
+	if err != nil {
+		return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
+	}
+
+	// Map the network name to a gVisor network protocol and reject
+	// addresses of the other family.
+	var networkProto tcpip.NetworkProtocolNumber
+	switch network {
+	case "tcp4":
+		networkProto = ipv4.ProtocolNumber
+		if ap.Addr().IsValid() && !ap.Addr().Is4() {
+			return nil, fmt.Errorf("netstack: tcp4 requires an IPv4 address")
+		}
+	case "tcp6":
+		networkProto = ipv6.ProtocolNumber
+		if ap.Addr().IsValid() && !ap.Addr().Is6() {
+			return nil, fmt.Errorf("netstack: tcp6 requires an IPv6 address")
+		}
+	default:
+		return nil, fmt.Errorf("netstack: unsupported network %q", network)
+	}
+
+	localAddress := tcpip.FullAddress{
+		NIC:  nicID,
+		Port: ap.Port(),
+	}
+	// Only a concrete (non-wildcard) IP is copied into the bind address.
+	if ap.Addr().IsValid() && !ap.Addr().IsUnspecified() {
+		localAddress.Addr = tcpip.AddrFromSlice(ap.Addr().AsSlice())
+	}
+
+	return gonet.ListenTCP(ns.ipstack, localAddress, networkProto)
+}
+
 func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
 	sess := r.ID()
 	if debugNetstack() {