瀏覽代碼

wgengine/netstack: implement UDP relaying to advertised subnets

TCP was done in 662fbd4a09664e849f0b898d1e8df13325d36efa.

This does the same for UDP.

Tested by hand. Integration tests will have to come later. I'd wanted
to do it in this commit, but the SOCKS5 server needed for interop
testing between two userspace nodes doesn't yet support UDP and I
didn't want to invent some whole new userspace packet injection
interface at this point, as SOCKS seems like a better route, but
that's its own bug.

Fixes #2302

RELNOTE=netstack mode can now UDP relay to subnets

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 年之前
父節點
當前提交
95a9adbb97
共有 1 個文件被更改,包括 96 次插入46 次删除
  1. 96 46
      wgengine/netstack/netstack.go

+ 96 - 46
wgengine/netstack/netstack.go

@@ -136,6 +136,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
 	return ns, nil
 }
 
+// wrapProtoHandler returns protocol handler h wrapped in a version
+// that dynamically reconfigures ns's subnet addresses as needed for
+// outbound traffic.
+func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, *stack.PacketBuffer) bool) func(stack.TransportEndpointID, *stack.PacketBuffer) bool {
+	return func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
+		addr := tei.LocalAddress
+		ip, ok := netaddr.FromStdIP(net.IP(addr))
+		if !ok {
+			ns.logf("netstack: could not parse local address for incoming connection")
+			return false
+		}
+		if !ns.isLocalIP(ip) {
+			ns.addSubnetAddress(ip)
+		}
+		return h(tei, pb)
+	}
+}
+
 // Start sets up all the handlers so netstack can start working. Implements
 // wgengine.FakeImpl.
 func (ns *Impl) Start() error {
@@ -145,25 +163,8 @@ func (ns *Impl) Start() error {
 	const maxInFlightConnectionAttempts = 16
 	tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
 	udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
-	ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
-		addr := tei.LocalAddress
-		var pn tcpip.NetworkProtocolNumber
-		if addr.To4() != "" {
-			pn = ipv4.ProtocolNumber
-		} else {
-			pn = ipv6.ProtocolNumber
-		}
-		ip, ok := netaddr.FromStdIP(net.IP(addr))
-		if !ok {
-			ns.logf("netstack: could not parse local address %s for incoming TCP connection", ip)
-			return false
-		}
-		if !ns.isLocalIP(ip) {
-			ns.addSubnetAddress(pn, ip)
-		}
-		return tcpFwd.HandlePacket(tei, pb)
-	})
-	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
+	ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket))
+	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket))
 	go ns.injectOutbound()
 	ns.tundev.PostFilterIn = ns.injectInbound
 	return nil
@@ -214,13 +215,19 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
 	ns.dns = DNSMapFromNetworkMap(nm)
 }
 
-func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) {
+func (ns *Impl) addSubnetAddress(ip netaddr.IP) {
 	ns.mu.Lock()
 	ns.connsOpenBySubnetIP[ip]++
 	needAdd := ns.connsOpenBySubnetIP[ip] == 1
 	ns.mu.Unlock()
 	// Only register address into netstack for first concurrent connection.
 	if needAdd {
+		var pn tcpip.NetworkProtocolNumber
+		if ip.Is4() {
+			pn = ipv4.ProtocolNumber
+		} else if ip.Is6() {
+			pn = ipv6.ProtocolNumber
+		}
 		ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP))
 	}
 }
@@ -543,9 +550,9 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp
 }
 
 func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
-	reqDetails := r.ID()
+	sess := r.ID()
 	if debugNetstack {
-		ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(reqDetails))
+		ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(sess))
 	}
 	var wq waiter.Queue
 	ep, err := r.CreateEndpoint(&wq)
@@ -553,30 +560,50 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
 		ns.logf("acceptUDP: could not create endpoint: %v", err)
 		return
 	}
-	localAddr, err := ep.GetLocalAddress()
-	if err != nil {
+	dstAddr, ok := ipPortOfNetstackAddr(sess.LocalAddress, sess.LocalPort)
+	if !ok {
 		return
 	}
-	remoteAddr, err := ep.GetRemoteAddress()
-	if err != nil {
+	srcAddr, ok := ipPortOfNetstackAddr(sess.RemoteAddress, sess.RemotePort)
+	if !ok {
 		return
 	}
+
 	c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
-	go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
+	go ns.forwardUDP(c, &wq, srcAddr, dstAddr)
 }
 
-func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
-	port := clientLocalAddr.Port
+// forwardUDP proxies between client (with addr clientAddr) and dstAddr.
+//
+// dstAddr may be either a local Tailscale IP, in which we case we proxy to
+// 127.0.0.1, or any other IP (from an advertised subnet), in which case we
+// proxy to it directly.
+func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientAddr, dstAddr netaddr.IPPort) {
+	port, srcPort := dstAddr.Port(), clientAddr.Port()
 	ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
-	backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)}
-	backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
-	backendConn, err := net.ListenUDP("udp4", backendListenAddr)
+
+	var backendListenAddr *net.UDPAddr
+	var backendRemoteAddr *net.UDPAddr
+	isLocal := ns.isLocalIP(dstAddr.IP())
+	if isLocal {
+		backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
+		backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
+	} else {
+		backendRemoteAddr = dstAddr.UDPAddr()
+		if dstAddr.IP().Is4() {
+			backendListenAddr = &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: int(srcPort)}
+		} else {
+			backendListenAddr = &net.UDPAddr{IP: net.ParseIP("::"), Port: int(srcPort)}
+		}
+	}
+
+	backendConn, err := net.ListenUDP("udp", backendListenAddr)
 	if err != nil {
-		ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
+		ns.logf("netstack: could not bind local port %v: %v, trying again with random port", backendListenAddr.Port, err)
 		backendListenAddr.Port = 0
-		backendConn, err = net.ListenUDP("udp4", backendListenAddr)
+		backendConn, err = net.ListenUDP("udp", backendListenAddr)
 		if err != nil {
-			ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
+			ns.logf("netstack: could not create UDP socket, preventing forwarding to %v: %v", dstAddr, err)
 			return
 		}
 	}
@@ -585,28 +612,47 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA
 	if !ok {
 		ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
 	}
-	clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String()))
-	ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
+	if isLocal {
+		ns.e.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.IP())
+	}
 	ctx, cancel := context.WithCancel(context.Background())
-	timer := time.AfterFunc(2*time.Minute, func() {
-		ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
-		ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
+
+	idleTimeout := 2 * time.Minute
+	if port == 53 {
+		// Make DNS packet copies time out much sooner.
+		//
+		// TODO(bradfitz): make DNS queries over UDP forwarding even
+		// cheaper by adding an additional idleTimeout post-DNS-reply.
+		// For instance, after the DNS response goes back out, then only
+		// wait a few seconds (or zero, really)
+		idleTimeout = 30 * time.Second
+	}
+	timer := time.AfterFunc(idleTimeout, func() {
+		if isLocal {
+			ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
+		}
+		ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr)
 		cancel()
 		client.Close()
 		backendConn.Close()
 	})
 	extend := func() {
-		timer.Reset(2 * time.Minute)
+		timer.Reset(idleTimeout)
 	}
-	startPacketCopy(ctx, cancel, client, &net.UDPAddr{
-		IP:   net.ParseIP(clientRemoteAddr.Addr.String()),
-		Port: int(clientRemoteAddr.Port),
-	}, backendConn, ns.logf, extend)
+	startPacketCopy(ctx, cancel, client, clientAddr.UDPAddr(), backendConn, ns.logf, extend)
 	startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
-
+	if isLocal {
+		// Wait for the copies to be done before decrementing the
+		// subnet address count to potentially remove the route.
+		<-ctx.Done()
+		ns.removeSubnetAddress(dstAddr.IP())
+	}
 }
 
 func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
+	if debugNetstack {
+		logf("[v2] netstack: startPacketCopy to %v (%T) from %T", dstAddr, dst, src)
+	}
 	go func() {
 		defer cancel() // tear down the other direction's copy
 		pkt := make([]byte, mtu)
@@ -643,3 +689,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string {
 	remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort)))
 	return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort)
 }
+
+func ipPortOfNetstackAddr(a tcpip.Address, port uint16) (ipp netaddr.IPPort, ok bool) {
+	return netaddr.FromStdAddr(net.IP(a), int(port), "") // TODO(bradfitz): can do without allocs
+}