Forráskód Böngészése

Refactor: Shadowsocks & Trojan UDP FullCone NAT

https://t.me/projectXray/95704
RPRX 4 éve
szülő
commit
8f8f7dd66f

+ 2 - 0
common/buf/buffer.go

@@ -2,6 +2,7 @@ package buf
 
 import (
 	"io"
+	"net"
 
 	"github.com/xtls/xray-core/common/bytespool"
 )
@@ -20,6 +21,7 @@ type Buffer struct {
 	v     []byte
 	start int32
 	end   int32
+	UDP   *net.UDPAddr
 }
 
 // New creates a Buffer with 0 length and 2K capacity.

+ 93 - 2
proxy/freedom/freedom.go

@@ -17,6 +17,7 @@ import (
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/policy"
+	"github.com/xtls/xray-core/features/stats"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet"
 )
@@ -148,7 +149,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		if destination.Network == net.Network_TCP {
 			writer = buf.NewWriter(conn)
 		} else {
-			writer = &buf.SequentialWriter{Writer: conn}
+			writer = NewPacketWriter(conn)
 		}
 
 		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
@@ -165,7 +166,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		if destination.Network == net.Network_TCP {
 			reader = buf.NewReader(conn)
 		} else {
-			reader = buf.NewPacketReader(conn)
+			reader = NewPacketReader(conn)
 		}
 		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to process response").Base(err)
@@ -180,3 +181,93 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 
 	return nil
 }
+
+func NewPacketReader(conn net.Conn) buf.Reader {
+	iConn := conn
+	statConn, ok := iConn.(*internet.StatCouterConnection)
+	if ok {
+		iConn = statConn.Connection
+	}
+	var counter stats.Counter
+	if statConn != nil {
+		counter = statConn.ReadCounter
+	}
+	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
+		return &PacketReader{
+			PacketConnWrapper: c,
+			Counter:           counter,
+		}
+	}
+	return &buf.PacketReader{Reader: conn}
+}
+
+type PacketReader struct {
+	*internet.PacketConnWrapper
+	stats.Counter
+}
+
+func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	b := buf.New()
+	b.Resize(0, buf.Size)
+	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
+	if err != nil {
+		b.Release()
+		return nil, err
+	}
+	b.Resize(0, int32(n))
+	b.UDP = d.(*net.UDPAddr)
+	if r.Counter != nil {
+		r.Counter.Add(int64(n))
+	}
+	return buf.MultiBuffer{b}, nil
+}
+
+func NewPacketWriter(conn net.Conn) buf.Writer {
+	iConn := conn
+	statConn, ok := iConn.(*internet.StatCouterConnection)
+	if ok {
+		iConn = statConn.Connection
+	}
+	var counter stats.Counter
+	if statConn != nil {
+		counter = statConn.WriteCounter
+	}
+	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
+		return &PacketWriter{
+			PacketConnWrapper: c,
+			Counter:           counter,
+		}
+	}
+	return &buf.SequentialWriter{Writer: conn}
+}
+
+type PacketWriter struct {
+	*internet.PacketConnWrapper
+	stats.Counter
+}
+
+func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for {
+		mb2, b := buf.SplitFirst(mb)
+		mb = mb2
+		if b == nil {
+			break
+		}
+		var n int
+		var err error
+		if b.UDP != nil {
+			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), b.UDP)
+		} else {
+			n, err = w.PacketConnWrapper.Write(b.Bytes())
+		}
+		b.Release()
+		if err != nil {
+			buf.ReleaseMulti(mb)
+			return err
+		}
+		if w.Counter != nil {
+			w.Counter.Add(int64(n))
+		}
+	}
+	return nil
+}

+ 5 - 4
proxy/shadowsocks/client.go

@@ -134,14 +134,15 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 	}
 
 	if request.Command == protocol.RequestCommandUDP {
-		writer := &buf.SequentialWriter{Writer: &UDPWriter{
-			Writer:  conn,
-			Request: request,
-		}}
 
 		requestDone := func() error {
 			defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
+			writer := &UDPWriter{
+				Writer:  conn,
+				Request: request,
+			}
+
 			if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
 				return newError("failed to transport all UDP request").Base(err)
 			}

+ 36 - 9
proxy/shadowsocks/protocol.go

@@ -230,11 +230,15 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		buffer.Release()
 		return nil, err
 	}
-	_, payload, err := DecodeUDPPacket(v.User, buffer)
+	u, payload, err := DecodeUDPPacket(v.User, buffer)
 	if err != nil {
 		buffer.Release()
 		return nil, err
 	}
+	payload.UDP = &net.UDPAddr{
+		IP:   u.Address.IP(),
+		Port: int(u.Port),
+	}
 	return buf.MultiBuffer{payload}, nil
 }
 
@@ -243,13 +247,36 @@ type UDPWriter struct {
 	Request *protocol.RequestHeader
 }
 
-// Write implements io.Writer.
-func (w *UDPWriter) Write(payload []byte) (int, error) {
-	packet, err := EncodeUDPPacket(w.Request, payload)
-	if err != nil {
-		return 0, err
+func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for {
+		mb2, b := buf.SplitFirst(mb)
+		mb = mb2
+		if b == nil {
+			break
+		}
+		var packet *buf.Buffer
+		var err error
+		if b.UDP != nil {
+			request := &protocol.RequestHeader{
+				User:    w.Request.User,
+				Address: net.IPAddress(b.UDP.IP),
+				Port:    net.Port(b.UDP.Port),
+			}
+			packet, err = EncodeUDPPacket(request, b.Bytes())
+		} else {
+			packet, err = EncodeUDPPacket(w.Request, b.Bytes())
+		}
+		b.Release()
+		if err != nil {
+			buf.ReleaseMulti(mb)
+			return err
+		}
+		_, err = w.Writer.Write(packet.Bytes())
+		packet.Release()
+		if err != nil {
+			buf.ReleaseMulti(mb)
+			return err
+		}
 	}
-	_, err = w.Writer.Write(packet.Bytes())
-	packet.Release()
-	return len(payload), err
+	return nil
 }

+ 2 - 2
proxy/shadowsocks/protocol_test.go

@@ -145,7 +145,7 @@ func TestUDPReaderWriter(t *testing.T) {
 	cache := buf.New()
 	defer cache.Release()
 
-	writer := &buf.SequentialWriter{Writer: &UDPWriter{
+	writer := &UDPWriter{
 		Writer: cache,
 		Request: &protocol.RequestHeader{
 			Version: Version,
@@ -153,7 +153,7 @@ func TestUDPReaderWriter(t *testing.T) {
 			Port:    123,
 			User:    user,
 		},
-	}}
+	}
 
 	reader := &UDPReader{
 		Reader: cache,

+ 22 - 3
proxy/shadowsocks/server.go

@@ -77,6 +77,15 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 		}
 
 		payload := packet.Payload
+
+		if payload.UDP != nil {
+			request = &protocol.RequestHeader{
+				User:    request.User,
+				Address: net.IPAddress(payload.UDP.IP),
+				Port:    net.Port(payload.UDP.Port),
+			}
+		}
+
 		data, err := EncodeUDPPacket(request, payload.Bytes())
 		payload.Release()
 		if err != nil {
@@ -94,6 +103,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 	}
 	inbound.User = s.user
 
+	var dest net.Destination
+
 	reader := buf.NewPacketReader(conn)
 	for {
 		mpayload, err := reader.ReadMultiBuffer()
@@ -118,17 +129,25 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 			}
 
 			currentPacketCtx := ctx
-			dest := request.Destination()
 			if inbound.Source.IsValid() {
 				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 					From:   inbound.Source,
-					To:     dest,
+					To:     request.Destination(),
 					Status: log.AccessAccepted,
 					Reason: "",
 					Email:  request.User.Email,
 				})
 			}
-			newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(currentPacketCtx))
+			newError("tunnelling request to ", request.Destination()).WriteToLog(session.ExportIDToError(currentPacketCtx))
+
+			data.UDP = &net.UDPAddr{
+				IP:   request.Address.IP(),
+				Port: int(request.Port),
+			}
+
+			if dest.Network == 0 {
+				dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
+			}
 
 			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
 			udpServer.Dispatch(currentPacketCtx, dest, data)

+ 21 - 1
proxy/socks/server.go

@@ -196,6 +196,15 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 		if request == nil {
 			return
 		}
+
+		if payload.UDP != nil {
+			request = &protocol.RequestHeader{
+				User:    request.User,
+				Address: net.IPAddress(payload.UDP.IP),
+				Port:    net.Port(payload.UDP.Port),
+			}
+		}
+
 		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
 		payload.Release()
 
@@ -211,6 +220,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
 	}
 
+	var dest net.Destination
+
 	reader := buf.NewPacketReader(conn)
 	for {
 		mpayload, err := reader.ReadMultiBuffer()
@@ -242,8 +253,17 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 				})
 			}
 
+			payload.UDP = &net.UDPAddr{
+				IP:   request.Address.IP(),
+				Port: int(request.Port),
+			}
+
+			if dest.Network == 0 {
+				dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
+			}
+
 			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
-			udpServer.Dispatch(currentPacketCtx, request.Destination(), payload)
+			udpServer.Dispatch(currentPacketCtx, dest, payload)
 		}
 	}
 }

+ 28 - 12
proxy/trojan/protocol.go

@@ -128,31 +128,43 @@ type PacketWriter struct {
 
 // WriteMultiBuffer implements buf.Writer
 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	b := make([]byte, maxLength)
-	for !mb.IsEmpty() {
-		var length int
-		mb, length = buf.SplitBytes(mb, b)
-		if _, err := w.writePacket(b[:length], w.Target); err != nil {
+	for {
+		mb2, b := buf.SplitFirst(mb)
+		mb = mb2
+		if b == nil {
+			break
+		}
+		target := w.Target
+		if b.UDP != nil {
+			target.Address = net.IPAddress(b.UDP.IP)
+			target.Port = net.Port(b.UDP.Port)
+		}
+		if _, err := w.writePacket(b.Bytes(), target); err != nil {
 			buf.ReleaseMulti(mb)
 			return err
 		}
 	}
-
 	return nil
 }
 
 // WriteMultiBufferWithMetadata writes udp packet with destination specified
 func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
-	b := make([]byte, maxLength)
-	for !mb.IsEmpty() {
-		var length int
-		mb, length = buf.SplitBytes(mb, b)
-		if _, err := w.writePacket(b[:length], dest); err != nil {
+	for {
+		mb2, b := buf.SplitFirst(mb)
+		mb = mb2
+		if b == nil {
+			break
+		}
+		source := dest
+		if b.UDP != nil {
+			source.Address = net.IPAddress(b.UDP.IP)
+			source.Port = net.Port(b.UDP.Port)
+		}
+		if _, err := w.writePacket(b.Bytes(), source); err != nil {
 			buf.ReleaseMulti(mb)
 			return err
 		}
 	}
-
 	return nil
 }
 
@@ -300,6 +312,10 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 		}
 
 		b := buf.New()
+		b.UDP = &net.UDPAddr{
+			IP:   addr.IP(),
+			Port: int(port.Value()),
+		}
 		mb = append(mb, b)
 		n, err := b.ReadFullFrom(r, int32(length))
 		if err != nil {

+ 7 - 1
proxy/trojan/server.go

@@ -256,6 +256,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 	inbound := session.InboundFromContext(ctx)
 	user := inbound.User
 
+	var dest net.Destination
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -278,8 +280,12 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 			})
 			newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))
 
+			if dest.Network == 0 {
+				dest = p.Target // JUST FOLLOW THE FIREST PACKET
+			}
+
 			for _, b := range p.Buffer {
-				udpServer.Dispatch(ctx, p.Target, b)
+				udpServer.Dispatch(ctx, dest, b)
 			}
 		}
 	}

+ 18 - 10
transport/internet/system_dialer.go

@@ -60,7 +60,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
 		if err != nil {
 			return nil, err
 		}
-		return &packetConnWrapper{
+		return &PacketConnWrapper{
 			conn: packetConn,
 			dest: destAddr,
 		}, nil
@@ -98,41 +98,49 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
 	return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
 }
 
-type packetConnWrapper struct {
+type PacketConnWrapper struct {
 	conn net.PacketConn
 	dest net.Addr
 }
 
-func (c *packetConnWrapper) Close() error {
+func (c *PacketConnWrapper) Close() error {
 	return c.conn.Close()
 }
 
-func (c *packetConnWrapper) LocalAddr() net.Addr {
+func (c *PacketConnWrapper) LocalAddr() net.Addr {
 	return c.conn.LocalAddr()
 }
 
-func (c *packetConnWrapper) RemoteAddr() net.Addr {
+func (c *PacketConnWrapper) RemoteAddr() net.Addr {
 	return c.dest
 }
 
-func (c *packetConnWrapper) Write(p []byte) (int, error) {
+func (c *PacketConnWrapper) Write(p []byte) (int, error) {
 	return c.conn.WriteTo(p, c.dest)
 }
 
-func (c *packetConnWrapper) Read(p []byte) (int, error) {
+func (c *PacketConnWrapper) Read(p []byte) (int, error) {
 	n, _, err := c.conn.ReadFrom(p)
 	return n, err
 }
 
-func (c *packetConnWrapper) SetDeadline(t time.Time) error {
+func (c *PacketConnWrapper) WriteTo(p []byte, d net.Addr) (int, error) {
+	return c.conn.WriteTo(p, d)
+}
+
+func (c *PacketConnWrapper) ReadFrom(p []byte) (int, net.Addr, error) {
+	return c.conn.ReadFrom(p)
+}
+
+func (c *PacketConnWrapper) SetDeadline(t time.Time) error {
 	return c.conn.SetDeadline(t)
 }
 
-func (c *packetConnWrapper) SetReadDeadline(t time.Time) error {
+func (c *PacketConnWrapper) SetReadDeadline(t time.Time) error {
 	return c.conn.SetReadDeadline(t)
 }
 
-func (c *packetConnWrapper) SetWriteDeadline(t time.Time) error {
+func (c *PacketConnWrapper) SetWriteDeadline(t time.Time) error {
 	return c.conn.SetWriteDeadline(t)
 }