Bläddra i källkod

net/socks5: fix UDP relay in userspace-networking mode

This commit addresses an issue with the SOCKS5 UDP relay functionality
when using the --tun=userspace-networking option. Previously, UDP packets
were not being correctly routed into the Tailscale network in this mode.

Key changes:
- Replace single UDP connection with a map of connections per target
- Use c.srv.dial for creating connections to ensure proper routing

Updates #7581

Change-Id: Iaaa66f9de6a3713218014cf3f498003a7cac9832
Signed-off-by: VimT <[email protected]>
VimT 1 år sedan
förälder
incheckning
b0626ff84c
1 ändrade filer med 63 tillägg och 38 borttagningar
  1. 63 38
      net/socks5/socks5.go

+ 63 - 38
net/socks5/socks5.go

@@ -22,6 +22,7 @@ import (
 	"log"
 	"net"
 	"strconv"
+	"tailscale.com/syncs"
 	"time"
 
 	"tailscale.com/types/logger"
@@ -81,6 +82,12 @@ const (
 	addrTypeNotSupported replyCode = 8
 )
 
+// UDP conn default buffer size and read timeout.
+const (
+	bufferSize  = 8 * 1024
+	readTimeout = 5 * time.Second
+)
+
 // Server is a SOCKS5 proxy server.
 type Server struct {
 	// Logf optionally specifies the logger to use.
@@ -143,7 +150,8 @@ type Conn struct {
 	clientConn net.Conn
 	request    *request
 
-	udpClientAddr net.Addr
+	udpClientAddr  net.Addr
+	udpTargetConns syncs.Map[string, net.Conn]
 }
 
 // Run starts the new connection.
@@ -276,15 +284,6 @@ func (c *Conn) handleUDP() error {
 	}
 	defer clientUDPConn.Close()
 
-	serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
-	if err != nil {
-		res := errorResponse(generalFailure)
-		buf, _ := res.marshal()
-		c.clientConn.Write(buf)
-		return err
-	}
-	defer serverUDPConn.Close()
-
 	bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
 	if err != nil {
 		return err
@@ -305,14 +304,20 @@ func (c *Conn) handleUDP() error {
 	}
 	c.clientConn.Write(buf)
 
-	return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
+	return c.transferUDP(c.clientConn, clientUDPConn)
 }
 
-func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
+func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	const bufferSize = 8 * 1024
-	const readTimeout = 5 * time.Second
+
+	// close all target udp connections when the client connection is closed
+	defer func() {
+		c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
+			_ = conn.Close()
+			return true
+		})
+	}()
 
 	// client -> target
 	go func() {
@@ -323,7 +328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
 			case <-ctx.Done():
 				return
 			default:
-				err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
+				err := c.handleUDPRequest(ctx, clientConn, buf)
 				if err != nil {
 					if isTimeout(err) {
 						continue
@@ -337,21 +342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
 		}
 	}()
 
+	// A UDP association terminates when the TCP connection that the UDP
+	// ASSOCIATE request arrived on terminates. RFC1928
+	_, err := io.Copy(io.Discard, associatedTCP)
+	if err != nil {
+		err = fmt.Errorf("udp associated tcp conn: %w", err)
+	}
+	return err
+}
+
+func (c *Conn) getOrDialTargetConn(
+	ctx context.Context,
+	clientConn net.PacketConn,
+	targetAddr string,
+) (net.Conn, error) {
+	host, port, err := splitHostPort(targetAddr)
+	if err != nil {
+		return nil, err
+	}
+
+	conn, loaded := c.udpTargetConns.Load(targetAddr)
+	if loaded {
+		return conn, nil
+	}
+	conn, err = c.srv.dial(ctx, "udp", targetAddr)
+	if err != nil {
+		return nil, err
+	}
+	c.udpTargetConns.Store(targetAddr, conn)
+
 	// target -> client
 	go func() {
-		defer cancel()
 		buf := make([]byte, bufferSize)
+		addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
 		for {
 			select {
 			case <-ctx.Done():
 				return
 			default:
-				err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
+				err := c.handleUDPResponse(clientConn, addr, conn, buf)
 				if err != nil {
 					if isTimeout(err) {
 						continue
 					}
-					if errors.Is(err, net.ErrClosed) {
+					if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
 						return
 					}
 					c.logf("udp transfer: handle udp response fail: %v", err)
@@ -360,20 +394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
 		}
 	}()
 
-	// A UDP association terminates when the TCP connection that the UDP
-	// ASSOCIATE request arrived on terminates. RFC1928
-	_, err := io.Copy(io.Discard, associatedTCP)
-	if err != nil {
-		err = fmt.Errorf("udp associated tcp conn: %w", err)
-	}
-	return err
+	return conn, nil
 }
 
 func (c *Conn) handleUDPRequest(
+	ctx context.Context,
 	clientConn net.PacketConn,
-	targetConn net.PacketConn,
 	buf []byte,
-	readTimeout time.Duration,
 ) error {
 	// add a deadline for the read to avoid blocking forever
 	_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
@@ -386,12 +413,14 @@ func (c *Conn) handleUDPRequest(
 	if err != nil {
 		return fmt.Errorf("parse udp request: %w", err)
 	}
-	targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
+
+	targetAddr := req.addr.hostPort()
+	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
 	if err != nil {
-		c.logf("resolve target addr fail: %v", err)
+		return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
 	}
 
-	nn, err := targetConn.WriteTo(data, targetAddr)
+	nn, err := targetConn.Write(data)
 	if err != nil {
 		return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
 	}
@@ -402,22 +431,18 @@ func (c *Conn) handleUDPRequest(
 }
 
 func (c *Conn) handleUDPResponse(
-	targetConn net.PacketConn,
 	clientConn net.PacketConn,
+	targetAddr socksAddr,
+	targetConn net.Conn,
 	buf []byte,
-	readTimeout time.Duration,
 ) error {
 	// add a deadline for the read to avoid blocking forever
 	_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
-	n, addr, err := targetConn.ReadFrom(buf)
+	n, err := targetConn.Read(buf)
 	if err != nil {
 		return fmt.Errorf("read from target: %w", err)
 	}
-	host, port, err := splitHostPort(addr.String())
-	if err != nil {
-		return fmt.Errorf("split host port: %w", err)
-	}
-	hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
+	hdr := udpRequest{addr: targetAddr}
 	pkt, err := hdr.marshal()
 	if err != nil {
 		return fmt.Errorf("marshal udp request: %w", err)