|
|
@@ -13,8 +13,10 @@
|
|
|
package socks5
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"context"
|
|
|
"encoding/binary"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log"
|
|
|
@@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
|
|
|
}
|
|
|
go func() {
|
|
|
defer c.Close()
|
|
|
- conn := &Conn{clientConn: c, srv: s}
|
|
|
+ conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
|
|
|
err := conn.Run()
|
|
|
if err != nil {
|
|
|
s.logf("client connection failed: %v", err)
|
|
|
@@ -136,9 +138,12 @@ type Conn struct {
|
|
|
// The struct is filled by each of the internal
|
|
|
// methods in turn as the transaction progresses.
|
|
|
|
|
|
+ logf logger.Logf
|
|
|
srv *Server
|
|
|
clientConn net.Conn
|
|
|
request *request
|
|
|
+
|
|
|
+ udpClientAddr net.Addr
|
|
|
}
|
|
|
|
|
|
// Run starts the new connection.
|
|
|
@@ -172,58 +177,59 @@ func (c *Conn) Run() error {
|
|
|
func (c *Conn) handleRequest() error {
|
|
|
req, err := parseClientRequest(c.clientConn)
|
|
|
if err != nil {
|
|
|
- res := &response{reply: generalFailure}
|
|
|
+ res := errorResponse(generalFailure)
|
|
|
buf, _ := res.marshal()
|
|
|
c.clientConn.Write(buf)
|
|
|
return err
|
|
|
}
|
|
|
- if req.command != connect {
|
|
|
- res := &response{reply: commandNotSupported}
|
|
|
+
|
|
|
+ c.request = req
|
|
|
+ switch req.command {
|
|
|
+ case connect:
|
|
|
+ return c.handleTCP()
|
|
|
+ case udpAssociate:
|
|
|
+ return c.handleUDP()
|
|
|
+ default:
|
|
|
+ res := errorResponse(commandNotSupported)
|
|
|
buf, _ := res.marshal()
|
|
|
c.clientConn.Write(buf)
|
|
|
return fmt.Errorf("unsupported command %v", req.command)
|
|
|
}
|
|
|
- c.request = req
|
|
|
+}
|
|
|
|
|
|
+func (c *Conn) handleTCP() error {
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
defer cancel()
|
|
|
srv, err := c.srv.dial(
|
|
|
ctx,
|
|
|
"tcp",
|
|
|
- net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
|
|
|
+ c.request.destination.hostPort(),
|
|
|
)
|
|
|
if err != nil {
|
|
|
- res := &response{reply: generalFailure}
|
|
|
+ res := errorResponse(generalFailure)
|
|
|
buf, _ := res.marshal()
|
|
|
c.clientConn.Write(buf)
|
|
|
return err
|
|
|
}
|
|
|
defer srv.Close()
|
|
|
- serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
|
|
|
+
|
|
|
+ localAddr := srv.LocalAddr().String()
|
|
|
+ serverAddr, serverPort, err := splitHostPort(localAddr)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- serverPort, _ := strconv.Atoi(serverPortStr)
|
|
|
|
|
|
- var bindAddrType addrType
|
|
|
- if ip := net.ParseIP(serverAddr); ip != nil {
|
|
|
- if ip.To4() != nil {
|
|
|
- bindAddrType = ipv4
|
|
|
- } else {
|
|
|
- bindAddrType = ipv6
|
|
|
- }
|
|
|
- } else {
|
|
|
- bindAddrType = domainName
|
|
|
- }
|
|
|
res := &response{
|
|
|
- reply: success,
|
|
|
- bindAddrType: bindAddrType,
|
|
|
- bindAddr: serverAddr,
|
|
|
- bindPort: uint16(serverPort),
|
|
|
+ reply: success,
|
|
|
+ bindAddr: socksAddr{
|
|
|
+ addrType: getAddrType(serverAddr),
|
|
|
+ addr: serverAddr,
|
|
|
+ port: serverPort,
|
|
|
+ },
|
|
|
}
|
|
|
buf, err := res.marshal()
|
|
|
if err != nil {
|
|
|
- res = &response{reply: generalFailure}
|
|
|
+ res = errorResponse(generalFailure)
|
|
|
buf, _ = res.marshal()
|
|
|
}
|
|
|
c.clientConn.Write(buf)
|
|
|
@@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
|
|
|
return <-errc
|
|
|
}
|
|
|
|
|
|
+func (c *Conn) handleUDP() error {
|
|
|
+ // The DST.ADDR and DST.PORT fields contain the address and port that
|
|
|
+ // the client expects to use to send UDP datagrams on for the
|
|
|
+ // association. The server MAY use this information to limit access
|
|
|
+ // to the association.
|
|
|
+ // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
|
|
|
+ //
|
|
|
+ // We do NOT limit the access from the client currently in this implementation.
|
|
|
+ _ = c.request.destination
|
|
|
+
|
|
|
+ addr := c.clientConn.LocalAddr()
|
|
|
+ host, _, err := net.SplitHostPort(addr.String())
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
|
|
|
+ if err != nil {
|
|
|
+ res := errorResponse(generalFailure)
|
|
|
+ buf, _ := res.marshal()
|
|
|
+ c.clientConn.Write(buf)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer clientUDPConn.Close()
|
|
|
+
|
|
|
+ serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
|
|
|
+ if err != nil {
|
|
|
+ res := errorResponse(generalFailure)
|
|
|
+ buf, _ := res.marshal()
|
|
|
+ c.clientConn.Write(buf)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer serverUDPConn.Close()
|
|
|
+
|
|
|
+ bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ res := &response{
|
|
|
+ reply: success,
|
|
|
+ bindAddr: socksAddr{
|
|
|
+ addrType: getAddrType(bindAddr),
|
|
|
+ addr: bindAddr,
|
|
|
+ port: bindPort,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ buf, err := res.marshal()
|
|
|
+ if err != nil {
|
|
|
+ res = errorResponse(generalFailure)
|
|
|
+ buf, _ = res.marshal()
|
|
|
+ }
|
|
|
+ c.clientConn.Write(buf)
|
|
|
+
|
|
|
+ return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ defer cancel()
|
|
|
+ const bufferSize = 8 * 1024
|
|
|
+ const readTimeout = 5 * time.Second
|
|
|
+
|
|
|
+ // client -> target
|
|
|
+ go func() {
|
|
|
+ defer cancel()
|
|
|
+ buf := make([]byte, bufferSize)
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
|
|
|
+ if err != nil {
|
|
|
+ if isTimeout(err) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if errors.Is(err, net.ErrClosed) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.logf("udp transfer: handle udp request fail: %v", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // target -> client
|
|
|
+ go func() {
|
|
|
+ defer cancel()
|
|
|
+ buf := make([]byte, bufferSize)
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
|
|
|
+ if err != nil {
|
|
|
+ if isTimeout(err) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if errors.Is(err, net.ErrClosed) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.logf("udp transfer: handle udp response fail: %v", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // A UDP association terminates when the TCP connection that the UDP
|
|
|
+ // ASSOCIATE request arrived on terminates. RFC1928
|
|
|
+ _, err := io.Copy(io.Discard, associatedTCP)
|
|
|
+ if err != nil {
|
|
|
+ err = fmt.Errorf("udp associated tcp conn: %w", err)
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Conn) handleUDPRequest(
|
|
|
+ clientConn net.PacketConn,
|
|
|
+ targetConn net.PacketConn,
|
|
|
+ buf []byte,
|
|
|
+ readTimeout time.Duration,
|
|
|
+) error {
|
|
|
+ // add a deadline for the read to avoid blocking forever
|
|
|
+ _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
|
+ n, addr, err := clientConn.ReadFrom(buf)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("read from client: %w", err)
|
|
|
+ }
|
|
|
+ c.udpClientAddr = addr
|
|
|
+ req, data, err := parseUDPRequest(buf[:n])
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("parse udp request: %w", err)
|
|
|
+ }
|
|
|
+ targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
|
|
|
+ if err != nil {
|
|
|
+ c.logf("resolve target addr fail: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ nn, err := targetConn.WriteTo(data, targetAddr)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
|
|
|
+ }
|
|
|
+ if nn != len(data) {
|
|
|
+ return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Conn) handleUDPResponse(
|
|
|
+ targetConn net.PacketConn,
|
|
|
+ clientConn net.PacketConn,
|
|
|
+ buf []byte,
|
|
|
+ readTimeout time.Duration,
|
|
|
+) error {
|
|
|
+ // add a deadline for the read to avoid blocking forever
|
|
|
+ _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
|
+ n, addr, err := targetConn.ReadFrom(buf)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("read from target: %w", err)
|
|
|
+ }
|
|
|
+ host, port, err := splitHostPort(addr.String())
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("split host port: %w", err)
|
|
|
+ }
|
|
|
+ hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
|
|
|
+ pkt, err := hdr.marshal()
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("marshal udp request: %w", err)
|
|
|
+ }
|
|
|
+ data := append(pkt, buf[:n]...)
|
|
|
+ // use addr from client to send back
|
|
|
+ nn, err := clientConn.WriteTo(data, c.udpClientAddr)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("write to client: %w", err)
|
|
|
+ }
|
|
|
+ if nn != len(data) {
|
|
|
+ return fmt.Errorf("write to client: %w", io.ErrShortWrite)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func isTimeout(err error) bool {
|
|
|
+ terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
|
|
|
+ return ok && terr.Timeout()
|
|
|
+}
|
|
|
+
|
|
|
+func splitHostPort(hostport string) (host string, port uint16, err error) {
|
|
|
+ host, portStr, err := net.SplitHostPort(hostport)
|
|
|
+ if err != nil {
|
|
|
+ return "", 0, err
|
|
|
+ }
|
|
|
+ portInt, err := strconv.Atoi(portStr)
|
|
|
+ if err != nil {
|
|
|
+ return "", 0, err
|
|
|
+ }
|
|
|
+ if portInt < 0 || portInt > 65535 {
|
|
|
+ return "", 0, fmt.Errorf("invalid port number %d", portInt)
|
|
|
+ }
|
|
|
+ return host, uint16(portInt), nil
|
|
|
+}
|
|
|
+
|
|
|
// parseClientGreeting parses a request initiation packet.
|
|
|
func parseClientGreeting(r io.Reader, authMethod byte) error {
|
|
|
var hdr [2]byte
|
|
|
@@ -295,123 +503,205 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
|
|
|
return string(usrBytes), string(pwdBytes), nil
|
|
|
}
|
|
|
|
|
|
+func getAddrType(addr string) addrType {
|
|
|
+ if ip := net.ParseIP(addr); ip != nil {
|
|
|
+ if ip.To4() != nil {
|
|
|
+ return ipv4
|
|
|
+ }
|
|
|
+ return ipv6
|
|
|
+ }
|
|
|
+ return domainName
|
|
|
+}
|
|
|
+
|
|
|
// request represents data contained within a SOCKS5
|
|
|
// connection request packet.
|
|
|
type request struct {
|
|
|
- command commandType
|
|
|
- destination string
|
|
|
- port uint16
|
|
|
- destAddrType addrType
|
|
|
+ command commandType
|
|
|
+ destination socksAddr
|
|
|
}
|
|
|
|
|
|
// parseClientRequest converts raw packet bytes into a
|
|
|
// SOCKS5Request struct.
|
|
|
func parseClientRequest(r io.Reader) (*request, error) {
|
|
|
- var hdr [4]byte
|
|
|
+ var hdr [3]byte
|
|
|
_, err := io.ReadFull(r, hdr[:])
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("could not read packet header")
|
|
|
}
|
|
|
cmd := hdr[1]
|
|
|
- destAddrType := addrType(hdr[3])
|
|
|
|
|
|
- var destination string
|
|
|
- var port uint16
|
|
|
+ destination, err := parseSocksAddr(r)
|
|
|
+ return &request{
|
|
|
+ command: commandType(cmd),
|
|
|
+ destination: destination,
|
|
|
+ }, err
|
|
|
+}
|
|
|
+
|
|
|
+type socksAddr struct {
|
|
|
+ addrType addrType
|
|
|
+ addr string
|
|
|
+ port uint16
|
|
|
+}
|
|
|
+
|
|
|
+var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
|
|
|
+
|
|
|
+func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
|
|
|
+ var addrTypeData [1]byte
|
|
|
+ _, err = io.ReadFull(r, addrTypeData[:])
|
|
|
+ if err != nil {
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read address type")
|
|
|
+ }
|
|
|
|
|
|
- if destAddrType == ipv4 {
|
|
|
+ dstAddrType := addrType(addrTypeData[0])
|
|
|
+ var destination string
|
|
|
+ switch dstAddrType {
|
|
|
+ case ipv4:
|
|
|
var ip [4]byte
|
|
|
_, err = io.ReadFull(r, ip[:])
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("could not read IPv4 address")
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read IPv4 address")
|
|
|
}
|
|
|
destination = net.IP(ip[:]).String()
|
|
|
- } else if destAddrType == domainName {
|
|
|
+ case domainName:
|
|
|
var dstSizeByte [1]byte
|
|
|
_, err = io.ReadFull(r, dstSizeByte[:])
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("could not read domain name size")
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read domain name size")
|
|
|
}
|
|
|
dstSize := int(dstSizeByte[0])
|
|
|
domainName := make([]byte, dstSize)
|
|
|
_, err = io.ReadFull(r, domainName)
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("could not read domain name")
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read domain name")
|
|
|
}
|
|
|
destination = string(domainName)
|
|
|
- } else if destAddrType == ipv6 {
|
|
|
+ case ipv6:
|
|
|
var ip [16]byte
|
|
|
_, err = io.ReadFull(r, ip[:])
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("could not read IPv6 address")
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read IPv6 address")
|
|
|
}
|
|
|
destination = net.IP(ip[:]).String()
|
|
|
- } else {
|
|
|
- return nil, fmt.Errorf("unsupported address type")
|
|
|
+ default:
|
|
|
+ return socksAddr{}, fmt.Errorf("unsupported address type")
|
|
|
}
|
|
|
var portBytes [2]byte
|
|
|
_, err = io.ReadFull(r, portBytes[:])
|
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("could not read port")
|
|
|
+ return socksAddr{}, fmt.Errorf("could not read port")
|
|
|
}
|
|
|
- port = binary.BigEndian.Uint16(portBytes[:])
|
|
|
-
|
|
|
- return &request{
|
|
|
- command: commandType(cmd),
|
|
|
- destination: destination,
|
|
|
- port: port,
|
|
|
- destAddrType: destAddrType,
|
|
|
+ port := binary.BigEndian.Uint16(portBytes[:])
|
|
|
+ return socksAddr{
|
|
|
+ addrType: dstAddrType,
|
|
|
+ addr: destination,
|
|
|
+ port: port,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
+func (s socksAddr) marshal() ([]byte, error) {
|
|
|
+ var addr []byte
|
|
|
+ switch s.addrType {
|
|
|
+ case ipv4:
|
|
|
+ addr = net.ParseIP(s.addr).To4()
|
|
|
+ if addr == nil {
|
|
|
+ return nil, fmt.Errorf("invalid IPv4 address for binding")
|
|
|
+ }
|
|
|
+ case domainName:
|
|
|
+ if len(s.addr) > 255 {
|
|
|
+ return nil, fmt.Errorf("invalid domain name for binding")
|
|
|
+ }
|
|
|
+ addr = make([]byte, 0, len(s.addr)+1)
|
|
|
+ addr = append(addr, byte(len(s.addr)))
|
|
|
+ addr = append(addr, []byte(s.addr)...)
|
|
|
+ case ipv6:
|
|
|
+ addr = net.ParseIP(s.addr).To16()
|
|
|
+ if addr == nil {
|
|
|
+ return nil, fmt.Errorf("invalid IPv6 address for binding")
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return nil, fmt.Errorf("unsupported address type")
|
|
|
+ }
|
|
|
+
|
|
|
+ pkt := []byte{byte(s.addrType)}
|
|
|
+ pkt = append(pkt, addr...)
|
|
|
+ pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
|
|
+ return pkt, nil
|
|
|
+}
|
|
|
+func (s socksAddr) hostPort() string {
|
|
|
+ return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
|
|
+}
|
|
|
+
|
|
|
// response contains the contents of
|
|
|
// a response packet sent from the proxy
|
|
|
// to the client.
|
|
|
type response struct {
|
|
|
- reply replyCode
|
|
|
- bindAddrType addrType
|
|
|
- bindAddr string
|
|
|
- bindPort uint16
|
|
|
+ reply replyCode
|
|
|
+ bindAddr socksAddr
|
|
|
+}
|
|
|
+
|
|
|
+func errorResponse(code replyCode) *response {
|
|
|
+ return &response{reply: code, bindAddr: zeroSocksAddr}
|
|
|
}
|
|
|
|
|
|
// marshal converts a SOCKS5Response struct into
|
|
|
// a packet. If res.reply == Success, it may throw an error on
|
|
|
// receiving an invalid bind address. Otherwise, it will not throw.
|
|
|
func (res *response) marshal() ([]byte, error) {
|
|
|
- pkt := make([]byte, 4)
|
|
|
+ pkt := make([]byte, 3)
|
|
|
pkt[0] = socks5Version
|
|
|
pkt[1] = byte(res.reply)
|
|
|
pkt[2] = 0 // null reserved byte
|
|
|
- pkt[3] = byte(res.bindAddrType)
|
|
|
|
|
|
- if res.reply != success {
|
|
|
- return pkt, nil
|
|
|
+ addrPkt, err := res.bindAddr.marshal()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- var addr []byte
|
|
|
- switch res.bindAddrType {
|
|
|
- case ipv4:
|
|
|
- addr = net.ParseIP(res.bindAddr).To4()
|
|
|
- if addr == nil {
|
|
|
- return nil, fmt.Errorf("invalid IPv4 address for binding")
|
|
|
- }
|
|
|
- case domainName:
|
|
|
- if len(res.bindAddr) > 255 {
|
|
|
- return nil, fmt.Errorf("invalid domain name for binding")
|
|
|
- }
|
|
|
- addr = make([]byte, 0, len(res.bindAddr)+1)
|
|
|
- addr = append(addr, byte(len(res.bindAddr)))
|
|
|
- addr = append(addr, []byte(res.bindAddr)...)
|
|
|
- case ipv6:
|
|
|
- addr = net.ParseIP(res.bindAddr).To16()
|
|
|
- if addr == nil {
|
|
|
- return nil, fmt.Errorf("invalid IPv6 address for binding")
|
|
|
- }
|
|
|
- default:
|
|
|
- return nil, fmt.Errorf("unsupported address type")
|
|
|
+ return append(pkt, addrPkt...), nil
|
|
|
+}
|
|
|
+
|
|
|
+type udpRequest struct {
|
|
|
+ frag byte
|
|
|
+ addr socksAddr
|
|
|
+}
|
|
|
+
|
|
|
+// +----+------+------+----------+----------+----------+
|
|
|
+// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
|
|
+// +----+------+------+----------+----------+----------+
|
|
|
+// | 2 | 1 | 1 | Variable | 2 | Variable |
|
|
|
+// +----+------+------+----------+----------+----------+
|
|
|
+func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
|
|
|
+ if len(data) < 4 {
|
|
|
+ return nil, nil, fmt.Errorf("invalid packet length")
|
|
|
}
|
|
|
|
|
|
- pkt = append(pkt, addr...)
|
|
|
- pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
|
|
|
+ // reserved bytes
|
|
|
+ if !(data[0] == 0 && data[1] == 0) {
|
|
|
+ return nil, nil, fmt.Errorf("invalid udp request header")
|
|
|
+ }
|
|
|
|
|
|
- return pkt, nil
|
|
|
+ frag := data[2]
|
|
|
+
|
|
|
+ reader := bytes.NewReader(data[3:])
|
|
|
+ addr, err := parseSocksAddr(reader)
|
|
|
+ bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
|
|
|
+ body := data[len(data)-bodyLen:]
|
|
|
+ return &udpRequest{
|
|
|
+ frag: frag,
|
|
|
+ addr: addr,
|
|
|
+ }, body, err
|
|
|
+}
|
|
|
+
|
|
|
+func (u *udpRequest) marshal() ([]byte, error) {
|
|
|
+ pkt := make([]byte, 3)
|
|
|
+ pkt[0] = 0
|
|
|
+ pkt[1] = 0
|
|
|
+ pkt[2] = u.frag
|
|
|
+
|
|
|
+ addrPkt, err := u.addr.marshal()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return append(pkt, addrPkt...), nil
|
|
|
}
|