소스 검색

net/socks5: support UDP

Updates #7581

Signed-off-by: VimT <[email protected]>
VimT 1 년 전
부모
커밋
e3f047618b
2개의 변경된 파일484개의 추가작업 그리고 81개의 파일을 삭제
  1. 371 81
      net/socks5/socks5.go
  2. 113 0
      net/socks5/socks5_test.go

+ 371 - 81
net/socks5/socks5.go

@@ -13,8 +13,10 @@
 package socks5
 
 import (
+	"bytes"
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
 		}
 		go func() {
 			defer c.Close()
-			conn := &Conn{clientConn: c, srv: s}
+			conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
 			err := conn.Run()
 			if err != nil {
 				s.logf("client connection failed: %v", err)
@@ -136,9 +138,12 @@ type Conn struct {
 	// The struct is filled by each of the internal
 	// methods in turn as the transaction progresses.
 
+	logf       logger.Logf
 	srv        *Server
 	clientConn net.Conn
 	request    *request
+
+	udpClientAddr net.Addr
 }
 
 // Run starts the new connection.
@@ -172,58 +177,59 @@ func (c *Conn) Run() error {
 func (c *Conn) handleRequest() error {
 	req, err := parseClientRequest(c.clientConn)
 	if err != nil {
-		res := &response{reply: generalFailure}
+		res := errorResponse(generalFailure)
 		buf, _ := res.marshal()
 		c.clientConn.Write(buf)
 		return err
 	}
-	if req.command != connect {
-		res := &response{reply: commandNotSupported}
+
+	c.request = req
+	switch req.command {
+	case connect:
+		return c.handleTCP()
+	case udpAssociate:
+		return c.handleUDP()
+	default:
+		res := errorResponse(commandNotSupported)
 		buf, _ := res.marshal()
 		c.clientConn.Write(buf)
 		return fmt.Errorf("unsupported command %v", req.command)
 	}
-	c.request = req
+}
 
+func (c *Conn) handleTCP() error {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	srv, err := c.srv.dial(
 		ctx,
 		"tcp",
-		net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
+		c.request.destination.hostPort(),
 	)
 	if err != nil {
-		res := &response{reply: generalFailure}
+		res := errorResponse(generalFailure)
 		buf, _ := res.marshal()
 		c.clientConn.Write(buf)
 		return err
 	}
 	defer srv.Close()
-	serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
+
+	localAddr := srv.LocalAddr().String()
+	serverAddr, serverPort, err := splitHostPort(localAddr)
 	if err != nil {
 		return err
 	}
-	serverPort, _ := strconv.Atoi(serverPortStr)
 
-	var bindAddrType addrType
-	if ip := net.ParseIP(serverAddr); ip != nil {
-		if ip.To4() != nil {
-			bindAddrType = ipv4
-		} else {
-			bindAddrType = ipv6
-		}
-	} else {
-		bindAddrType = domainName
-	}
 	res := &response{
-		reply:        success,
-		bindAddrType: bindAddrType,
-		bindAddr:     serverAddr,
-		bindPort:     uint16(serverPort),
+		reply: success,
+		bindAddr: socksAddr{
+			addrType: getAddrType(serverAddr),
+			addr:     serverAddr,
+			port:     serverPort,
+		},
 	}
 	buf, err := res.marshal()
 	if err != nil {
-		res = &response{reply: generalFailure}
+		res = errorResponse(generalFailure)
 		buf, _ = res.marshal()
 	}
 	c.clientConn.Write(buf)
@@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
 	return <-errc
 }
 
+func (c *Conn) handleUDP() error {
+	// The DST.ADDR and DST.PORT fields contain the address and port that
+	// the client expects to use to send UDP datagrams on for the
+	// association. The server MAY use this information to limit access
+	// to the association.
+	// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
+	//
+	// We do NOT limit the access from the client currently in this implementation.
+	_ = c.request.destination
+
+	addr := c.clientConn.LocalAddr()
+	host, _, err := net.SplitHostPort(addr.String())
+	if err != nil {
+		return err
+	}
+	clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
+	if err != nil {
+		res := errorResponse(generalFailure)
+		buf, _ := res.marshal()
+		c.clientConn.Write(buf)
+		return err
+	}
+	defer clientUDPConn.Close()
+
+	serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
+	if err != nil {
+		res := errorResponse(generalFailure)
+		buf, _ := res.marshal()
+		c.clientConn.Write(buf)
+		return err
+	}
+	defer serverUDPConn.Close()
+
+	bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
+	if err != nil {
+		return err
+	}
+
+	res := &response{
+		reply: success,
+		bindAddr: socksAddr{
+			addrType: getAddrType(bindAddr),
+			addr:     bindAddr,
+			port:     bindPort,
+		},
+	}
+	buf, err := res.marshal()
+	if err != nil {
+		res = errorResponse(generalFailure)
+		buf, _ = res.marshal()
+	}
+	c.clientConn.Write(buf)
+
+	return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
+}
+
+func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	const bufferSize = 8 * 1024
+	const readTimeout = 5 * time.Second
+
+	// client -> target
+	go func() {
+		defer cancel()
+		buf := make([]byte, bufferSize)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
+				if err != nil {
+					if isTimeout(err) {
+						continue
+					}
+					if errors.Is(err, net.ErrClosed) {
+						return
+					}
+					c.logf("udp transfer: handle udp request fail: %v", err)
+				}
+			}
+		}
+	}()
+
+	// target -> client
+	go func() {
+		defer cancel()
+		buf := make([]byte, bufferSize)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
+				if err != nil {
+					if isTimeout(err) {
+						continue
+					}
+					if errors.Is(err, net.ErrClosed) {
+						return
+					}
+					c.logf("udp transfer: handle udp response fail: %v", err)
+				}
+			}
+		}
+	}()
+
+	// A UDP association terminates when the TCP connection that the UDP
+	// ASSOCIATE request arrived on terminates. RFC1928
+	_, err := io.Copy(io.Discard, associatedTCP)
+	if err != nil {
+		err = fmt.Errorf("udp associated tcp conn: %w", err)
+	}
+	return err
+}
+
+func (c *Conn) handleUDPRequest(
+	clientConn net.PacketConn,
+	targetConn net.PacketConn,
+	buf []byte,
+	readTimeout time.Duration,
+) error {
+	// add a deadline for the read to avoid blocking forever
+	_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
+	n, addr, err := clientConn.ReadFrom(buf)
+	if err != nil {
+		return fmt.Errorf("read from client: %w", err)
+	}
+	c.udpClientAddr = addr
+	req, data, err := parseUDPRequest(buf[:n])
+	if err != nil {
+		return fmt.Errorf("parse udp request: %w", err)
+	}
+	targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
+	if err != nil {
+		c.logf("resolve target addr fail: %v", err)
+	}
+
+	nn, err := targetConn.WriteTo(data, targetAddr)
+	if err != nil {
+		return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
+	}
+	if nn != len(data) {
+		return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
+	}
+	return nil
+}
+
+func (c *Conn) handleUDPResponse(
+	targetConn net.PacketConn,
+	clientConn net.PacketConn,
+	buf []byte,
+	readTimeout time.Duration,
+) error {
+	// add a deadline for the read to avoid blocking forever
+	_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
+	n, addr, err := targetConn.ReadFrom(buf)
+	if err != nil {
+		return fmt.Errorf("read from target: %w", err)
+	}
+	host, port, err := splitHostPort(addr.String())
+	if err != nil {
+		return fmt.Errorf("split host port: %w", err)
+	}
+	hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
+	pkt, err := hdr.marshal()
+	if err != nil {
+		return fmt.Errorf("marshal udp request: %w", err)
+	}
+	data := append(pkt, buf[:n]...)
+	// use addr from client to send back
+	nn, err := clientConn.WriteTo(data, c.udpClientAddr)
+	if err != nil {
+		return fmt.Errorf("write to client: %w", err)
+	}
+	if nn != len(data) {
+		return fmt.Errorf("write to client: %w", io.ErrShortWrite)
+	}
+	return nil
+}
+
+func isTimeout(err error) bool {
+	terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
+	return ok && terr.Timeout()
+}
+
+func splitHostPort(hostport string) (host string, port uint16, err error) {
+	host, portStr, err := net.SplitHostPort(hostport)
+	if err != nil {
+		return "", 0, err
+	}
+	portInt, err := strconv.Atoi(portStr)
+	if err != nil {
+		return "", 0, err
+	}
+	if portInt < 0 || portInt > 65535 {
+		return "", 0, fmt.Errorf("invalid port number %d", portInt)
+	}
+	return host, uint16(portInt), nil
+}
+
 // parseClientGreeting parses a request initiation packet.
 func parseClientGreeting(r io.Reader, authMethod byte) error {
 	var hdr [2]byte
@@ -295,123 +503,205 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
 	return string(usrBytes), string(pwdBytes), nil
 }
 
+func getAddrType(addr string) addrType {
+	if ip := net.ParseIP(addr); ip != nil {
+		if ip.To4() != nil {
+			return ipv4
+		}
+		return ipv6
+	}
+	return domainName
+}
+
 // request represents data contained within a SOCKS5
 // connection request packet.
 type request struct {
-	command      commandType
-	destination  string
-	port         uint16
-	destAddrType addrType
+	command     commandType
+	destination socksAddr
 }
 
 // parseClientRequest converts raw packet bytes into a
 // SOCKS5Request struct.
 func parseClientRequest(r io.Reader) (*request, error) {
-	var hdr [4]byte
+	var hdr [3]byte
 	_, err := io.ReadFull(r, hdr[:])
 	if err != nil {
 		return nil, fmt.Errorf("could not read packet header")
 	}
 	cmd := hdr[1]
-	destAddrType := addrType(hdr[3])
 
-	var destination string
-	var port uint16
+	destination, err := parseSocksAddr(r)
+	return &request{
+		command:     commandType(cmd),
+		destination: destination,
+	}, err
+}
+
+type socksAddr struct {
+	addrType addrType
+	addr     string
+	port     uint16
+}
+
+var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
+
+func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
+	var addrTypeData [1]byte
+	_, err = io.ReadFull(r, addrTypeData[:])
+	if err != nil {
+		return socksAddr{}, fmt.Errorf("could not read address type")
+	}
 
-	if destAddrType == ipv4 {
+	dstAddrType := addrType(addrTypeData[0])
+	var destination string
+	switch dstAddrType {
+	case ipv4:
 		var ip [4]byte
 		_, err = io.ReadFull(r, ip[:])
 		if err != nil {
-			return nil, fmt.Errorf("could not read IPv4 address")
+			return socksAddr{}, fmt.Errorf("could not read IPv4 address")
 		}
 		destination = net.IP(ip[:]).String()
-	} else if destAddrType == domainName {
+	case domainName:
 		var dstSizeByte [1]byte
 		_, err = io.ReadFull(r, dstSizeByte[:])
 		if err != nil {
-			return nil, fmt.Errorf("could not read domain name size")
+			return socksAddr{}, fmt.Errorf("could not read domain name size")
 		}
 		dstSize := int(dstSizeByte[0])
 		domainName := make([]byte, dstSize)
 		_, err = io.ReadFull(r, domainName)
 		if err != nil {
-			return nil, fmt.Errorf("could not read domain name")
+			return socksAddr{}, fmt.Errorf("could not read domain name")
 		}
 		destination = string(domainName)
-	} else if destAddrType == ipv6 {
+	case ipv6:
 		var ip [16]byte
 		_, err = io.ReadFull(r, ip[:])
 		if err != nil {
-			return nil, fmt.Errorf("could not read IPv6 address")
+			return socksAddr{}, fmt.Errorf("could not read IPv6 address")
 		}
 		destination = net.IP(ip[:]).String()
-	} else {
-		return nil, fmt.Errorf("unsupported address type")
+	default:
+		return socksAddr{}, fmt.Errorf("unsupported address type")
 	}
 	var portBytes [2]byte
 	_, err = io.ReadFull(r, portBytes[:])
 	if err != nil {
-		return nil, fmt.Errorf("could not read port")
+		return socksAddr{}, fmt.Errorf("could not read port")
 	}
-	port = binary.BigEndian.Uint16(portBytes[:])
-
-	return &request{
-		command:      commandType(cmd),
-		destination:  destination,
-		port:         port,
-		destAddrType: destAddrType,
+	port := binary.BigEndian.Uint16(portBytes[:])
+	return socksAddr{
+		addrType: dstAddrType,
+		addr:     destination,
+		port:     port,
 	}, nil
 }
 
+func (s socksAddr) marshal() ([]byte, error) {
+	var addr []byte
+	switch s.addrType {
+	case ipv4:
+		addr = net.ParseIP(s.addr).To4()
+		if addr == nil {
+			return nil, fmt.Errorf("invalid IPv4 address for binding")
+		}
+	case domainName:
+		if len(s.addr) > 255 {
+			return nil, fmt.Errorf("invalid domain name for binding")
+		}
+		addr = make([]byte, 0, len(s.addr)+1)
+		addr = append(addr, byte(len(s.addr)))
+		addr = append(addr, []byte(s.addr)...)
+	case ipv6:
+		addr = net.ParseIP(s.addr).To16()
+		if addr == nil {
+			return nil, fmt.Errorf("invalid IPv6 address for binding")
+		}
+	default:
+		return nil, fmt.Errorf("unsupported address type")
+	}
+
+	pkt := []byte{byte(s.addrType)}
+	pkt = append(pkt, addr...)
+	pkt = binary.BigEndian.AppendUint16(pkt, s.port)
+	return pkt, nil
+}
+func (s socksAddr) hostPort() string {
+	return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
+}
+
 // response contains the contents of
 // a response packet sent from the proxy
 // to the client.
 type response struct {
-	reply        replyCode
-	bindAddrType addrType
-	bindAddr     string
-	bindPort     uint16
+	reply    replyCode
+	bindAddr socksAddr
+}
+
+func errorResponse(code replyCode) *response {
+	return &response{reply: code, bindAddr: zeroSocksAddr}
 }
 
 // marshal converts a SOCKS5Response struct into
 // a packet. If res.reply == Success, it may throw an error on
 // receiving an invalid bind address. Otherwise, it will not throw.
 func (res *response) marshal() ([]byte, error) {
-	pkt := make([]byte, 4)
+	pkt := make([]byte, 3)
 	pkt[0] = socks5Version
 	pkt[1] = byte(res.reply)
 	pkt[2] = 0 // null reserved byte
-	pkt[3] = byte(res.bindAddrType)
 
-	if res.reply != success {
-		return pkt, nil
+	addrPkt, err := res.bindAddr.marshal()
+	if err != nil {
+		return nil, err
 	}
 
-	var addr []byte
-	switch res.bindAddrType {
-	case ipv4:
-		addr = net.ParseIP(res.bindAddr).To4()
-		if addr == nil {
-			return nil, fmt.Errorf("invalid IPv4 address for binding")
-		}
-	case domainName:
-		if len(res.bindAddr) > 255 {
-			return nil, fmt.Errorf("invalid domain name for binding")
-		}
-		addr = make([]byte, 0, len(res.bindAddr)+1)
-		addr = append(addr, byte(len(res.bindAddr)))
-		addr = append(addr, []byte(res.bindAddr)...)
-	case ipv6:
-		addr = net.ParseIP(res.bindAddr).To16()
-		if addr == nil {
-			return nil, fmt.Errorf("invalid IPv6 address for binding")
-		}
-	default:
-		return nil, fmt.Errorf("unsupported address type")
+	return append(pkt, addrPkt...), nil
+}
+
+type udpRequest struct {
+	frag byte
+	addr socksAddr
+}
+
+// +----+------+------+----------+----------+----------+
+// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
+// +----+------+------+----------+----------+----------+
+// | 2  |  1   |  1   | Variable |    2     | Variable |
+// +----+------+------+----------+----------+----------+
+func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
+	if len(data) < 4 {
+		return nil, nil, fmt.Errorf("invalid packet length")
 	}
 
-	pkt = append(pkt, addr...)
-	pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
+	// reserved bytes
+	if !(data[0] == 0 && data[1] == 0) {
+		return nil, nil, fmt.Errorf("invalid udp request header")
+	}
 
-	return pkt, nil
+	frag := data[2]
+
+	reader := bytes.NewReader(data[3:])
+	addr, err := parseSocksAddr(reader)
+	bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
+	body := data[len(data)-bodyLen:]
+	return &udpRequest{
+		frag: frag,
+		addr: addr,
+	}, body, err
+}
+
+func (u *udpRequest) marshal() ([]byte, error) {
+	pkt := make([]byte, 3)
+	pkt[0] = 0
+	pkt[1] = 0
+	pkt[2] = u.frag
+
+	addrPkt, err := u.addr.marshal()
+	if err != nil {
+		return nil, err
+	}
+
+	return append(pkt, addrPkt...), nil
 }

+ 113 - 0
net/socks5/socks5_test.go

@@ -4,6 +4,7 @@
 package socks5
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
@@ -32,6 +33,19 @@ func backendServer(listener net.Listener) {
 	listener.Close()
 }
 
+func udpEchoServer(conn net.PacketConn) {
+	var buf [1024]byte
+	n, addr, err := conn.ReadFrom(buf[:])
+	if err != nil {
+		panic(err)
+	}
+	_, err = conn.WriteTo(buf[:n], addr)
+	if err != nil {
+		panic(err)
+	}
+	conn.Close()
+}
+
 func TestRead(t *testing.T) {
 	// backend server which we'll use SOCKS5 to connect to
 	listener, err := net.Listen("tcp", ":0")
@@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestUDP(t *testing.T) {
+	// backend UDP server which we'll use SOCKS5 to connect to
+	listener, err := net.ListenPacket("udp", ":0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
+	go udpEchoServer(listener)
+
+	// SOCKS5 server
+	socks5, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	socks5Port := socks5.Addr().(*net.TCPAddr).Port
+	go socks5Server(socks5)
+
+	// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
+	conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 1024)
+	n, err := conn.Read(buf) // server hello
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
+		t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
+	}
+
+	targetAddr := socksAddr{
+		addrType: domainName,
+		addr:     "localhost",
+		port:     uint16(backendServerPort),
+	}
+	targetAddrPkt, err := targetAddr.marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	n, err = conn.Read(buf) // server response
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
+		t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+	}
+	udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
+	if err != nil {
+		t.Fatal(err)
+	}
+	udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	udpPayload = append(udpPayload, []byte("Test")...)
+	_, err = udpConn.Write(udpPayload) // send udp package
+	if err != nil {
+		t.Fatal(err)
+	}
+	n, _, err = udpConn.ReadFrom(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(responseBody) != "Test" {
+		t.Fatalf("got: %q want: Test", responseBody)
+	}
+	err = udpConn.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}