Browse Source

net/socks5: optimize UDP relay

Key changes:
- No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns
- Use socksAddr as map key for better type safety
- Add test for multi udp target

Updates #7581

Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c
Signed-off-by: VimT <[email protected]>
VimT 1 year ago
parent
commit
43138c7a5c
2 changed files with 119 additions and 99 deletions
  1. 25 27
      net/socks5/socks5.go
  2. 94 72
      net/socks5/socks5_test.go

+ 25 - 27
net/socks5/socks5.go

@@ -22,7 +22,6 @@ import (
 	"log"
 	"net"
 	"strconv"
-	"tailscale.com/syncs"
 	"time"
 
 	"tailscale.com/types/logger"
@@ -151,7 +150,7 @@ type Conn struct {
 	request    *request
 
 	udpClientAddr  net.Addr
-	udpTargetConns syncs.Map[string, net.Conn]
+	udpTargetConns map[socksAddr]net.Conn
 }
 
 // Run starts the new connection.
@@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	// close all target udp connections when the client connection is closed
-	defer func() {
-		c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
-			_ = conn.Close()
-			return true
-		})
-	}()
-
 	// client -> target
 	go func() {
 		defer cancel()
+
+		c.udpTargetConns = make(map[socksAddr]net.Conn)
+		// close all target udp connections when the client connection is closed
+		defer func() {
+			for _, conn := range c.udpTargetConns {
+				_ = conn.Close()
+			}
+		}()
+
 		buf := make([]byte, bufferSize)
 		for {
 			select {
@@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
 func (c *Conn) getOrDialTargetConn(
 	ctx context.Context,
 	clientConn net.PacketConn,
-	targetAddr string,
+	targetAddr socksAddr,
 ) (net.Conn, error) {
-	host, port, err := splitHostPort(targetAddr)
-	if err != nil {
-		return nil, err
-	}
-
-	conn, loaded := c.udpTargetConns.Load(targetAddr)
-	if loaded {
+	conn, exist := c.udpTargetConns[targetAddr]
+	if exist {
 		return conn, nil
 	}
-	conn, err = c.srv.dial(ctx, "udp", targetAddr)
+	conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
 	if err != nil {
 		return nil, err
 	}
-	c.udpTargetConns.Store(targetAddr, conn)
+	c.udpTargetConns[targetAddr] = conn
 
 	// target -> client
 	go func() {
 		buf := make([]byte, bufferSize)
-		addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
 		for {
 			select {
 			case <-ctx.Done():
 				return
 			default:
-				err := c.handleUDPResponse(clientConn, addr, conn, buf)
+				err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
 				if err != nil {
 					if isTimeout(err) {
 						continue
@@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest(
 		return fmt.Errorf("parse udp request: %w", err)
 	}
 
-	targetAddr := req.addr.hostPort()
-	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
+	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
 	if err != nil {
-		return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
+		return fmt.Errorf("dial target %s fail: %w", req.addr, err)
 	}
 
 	nn, err := targetConn.Write(data)
 	if err != nil {
-		return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
+		return fmt.Errorf("write to target %s fail: %w", req.addr, err)
 	}
 	if nn != len(data) {
-		return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
+		return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
 	}
 	return nil
 }
@@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
 	pkt = binary.BigEndian.AppendUint16(pkt, s.port)
 	return pkt, nil
 }
+
 func (s socksAddr) hostPort() string {
 	return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
 }
 
+func (s socksAddr) String() string {
+	return s.hostPort()
+}
+
 // response contains the contents of
 // a response packet sent from the proxy
 // to the client.

+ 94 - 72
net/socks5/socks5_test.go

@@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) {
 
 func TestUDP(t *testing.T) {
 	// backend UDP server which we'll use SOCKS5 to connect to
-	listener, err := net.ListenPacket("udp", ":0")
-	if err != nil {
-		t.Fatal(err)
+	newUDPEchoServer := func() net.PacketConn {
+		listener, err := net.ListenPacket("udp", ":0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		go udpEchoServer(listener)
+		return listener
 	}
-	backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
-	go udpEchoServer(listener)
+
+	const echoServerNumber = 3
+	echoServerListener := make([]net.PacketConn, echoServerNumber)
+	for i := 0; i < echoServerNumber; i++ {
+		echoServerListener[i] = newUDPEchoServer()
+	}
+	defer func() {
+		for i := 0; i < echoServerNumber; i++ {
+			_ = echoServerListener[i].Close()
+		}
+	}()
 
 	// SOCKS5 server
 	socks5, err := net.Listen("tcp", ":0")
@@ -184,84 +197,93 @@ func TestUDP(t *testing.T) {
 	socks5Port := socks5.Addr().(*net.TCPAddr).Port
 	go socks5Server(socks5)
 
-	// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
-	conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
-	if err != nil {
-		t.Fatal(err)
-	}
-	buf := make([]byte, 1024)
-	n, err := conn.Read(buf) // server hello
-	if err != nil {
-		t.Fatal(err)
-	}
-	if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
-		t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
-	}
+	// make a socks5 udpAssociate conn
+	newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
+		// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
+		conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
+		if err != nil {
+			t.Fatal(err)
+		}
+		buf := make([]byte, 1024)
+		n, err := conn.Read(buf) // server hello
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
+			t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
+		}
 
-	targetAddr := socksAddr{
-		addrType: domainName,
-		addr:     "localhost",
-		port:     uint16(backendServerPort),
-	}
-	targetAddrPkt, err := targetAddr.marshal()
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
-	if err != nil {
-		t.Fatal(err)
-	}
+		targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
+		targetAddrPkt, err := targetAddr.marshal()
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
+		if err != nil {
+			t.Fatal(err)
+		}
 
-	n, err = conn.Read(buf) // server response
-	if err != nil {
-		t.Fatal(err)
-	}
-	if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
-		t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+		n, err = conn.Read(buf) // server response
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
+			t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+		}
+		udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		return conn, udpProxySocksAddr
 	}
-	udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
-	if err != nil {
-		t.Fatal(err)
+
+	conn, udpProxySocksAddr := newUdpAssociateConn()
+	defer conn.Close()
+
+	sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
+		udpPayload, err := (&udpRequest{addr: addr}).marshal()
+		if err != nil {
+			t.Fatal(err)
+		}
+		udpPayload = append(udpPayload, body...)
+		_, err = socks5UDPConn.Write(udpPayload)
+		if err != nil {
+			t.Fatal(err)
+		}
+		buf := make([]byte, 1024)
+		n, err := socks5UDPConn.Read(buf)
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, responseBody, err = parseUDPRequest(buf[:n])
+		if err != nil {
+			t.Fatal(err)
+		}
+		return responseBody
 	}
 
 	udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
 	if err != nil {
 		t.Fatal(err)
 	}
-	udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
-	if err != nil {
-		t.Fatal(err)
-	}
-	udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
-	if err != nil {
-		t.Fatal(err)
-	}
-	udpPayload = append(udpPayload, []byte("Test")...)
-	_, err = udpConn.Write(udpPayload) // send udp package
-	if err != nil {
-		t.Fatal(err)
-	}
-	n, _, err = udpConn.ReadFrom(buf)
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
-	if err != nil {
-		t.Fatal(err)
-	}
-	if string(responseBody) != "Test" {
-		t.Fatalf("got: %q want: Test", responseBody)
-	}
-	err = udpConn.Close()
+	socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = conn.Close()
-	if err != nil {
-		t.Fatal(err)
+	defer socks5UDPConn.Close()
+
+	for i := 0; i < echoServerNumber; i++ {
+		port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
+		addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
+		requestBody := []byte(fmt.Sprintf("Test %d", i))
+		responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
+		if !bytes.Equal(requestBody, responseBody) {
+			t.Fatalf("got: %q want: %q", responseBody, requestBody)
+		}
 	}
 }