Browse Source

Refactor: VLESS & VMess & Mux UDP FullCone NAT

https://t.me/projectXray/242770
RPRX 4 years ago
parent
commit
1174ff3090

+ 1 - 1
common/mux/client.go

@@ -330,7 +330,7 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 		return buf.Copy(NewStreamReader(reader), buf.Discard)
 	}
 
-	rr := s.NewReader(reader)
+	rr := s.NewReader(reader, &meta.Target)
 	err := buf.Copy(rr, s.output)
 	if err != nil && buf.IsWriteError(err) {
 		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()

+ 4 - 1
common/mux/frame.go

@@ -81,6 +81,9 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 		if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
 			return err
 		}
+	} else if b.UDP != nil {
+		b.WriteByte(byte(TargetNetworkUDP))
+		addrParser.WriteAddressPort(b, b.UDP.Address, b.UDP.Port)
 	}
 
 	len1 := b.Len()
@@ -119,7 +122,7 @@ func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
 	f.Option = bitmask.Byte(b.Byte(3))
 	f.Target.Network = net.Network_Unknown
 
-	if f.SessionStatus == SessionStatusNew {
+	if f.SessionStatus == SessionStatusNew || (f.SessionStatus == SessionStatusKeep && b.Len() != 4) {
 		if b.Len() < 8 {
 			return newError("insufficient buffer: ", b.Len())
 		}

+ 7 - 1
common/mux/reader.go

@@ -5,6 +5,7 @@ import (
 
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/crypto"
+	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/serial"
 )
 
@@ -12,13 +13,15 @@ import (
 type PacketReader struct {
 	reader io.Reader
 	eof    bool
+	dest   *net.Destination
 }
 
 // NewPacketReader creates a new PacketReader.
-func NewPacketReader(reader io.Reader) *PacketReader {
+func NewPacketReader(reader io.Reader, dest *net.Destination) *PacketReader {
 	return &PacketReader{
 		reader: reader,
 		eof:    false,
+		dest:   dest,
 	}
 }
 
@@ -43,6 +46,9 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		return nil, err
 	}
 	r.eof = true
+	if r.dest != nil && r.dest.Network == net.Network_UDP {
+		b.UDP = r.dest
+	}
 	return buf.MultiBuffer{b}, nil
 }
 

+ 2 - 2
common/mux/server.go

@@ -145,7 +145,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
 		return nil
 	}
 
-	rr := s.NewReader(reader)
+	rr := s.NewReader(reader, &meta.Target)
 	if err := buf.Copy(rr, s.output); err != nil {
 		buf.Copy(rr, buf.Discard)
 		common.Interrupt(s.input)
@@ -168,7 +168,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 		return buf.Copy(NewStreamReader(reader), buf.Discard)
 	}
 
-	rr := s.NewReader(reader)
+	rr := s.NewReader(reader, &meta.Target)
 	err := buf.Copy(rr, s.output)
 
 	if err != nil && buf.IsWriteError(err) {

+ 3 - 2
common/mux/session.go

@@ -5,6 +5,7 @@ import (
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 )
 
@@ -152,9 +153,9 @@ func (s *Session) Close() error {
 }
 
 // NewReader creates a buf.Reader based on the transfer type of this Session.
-func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
+func (s *Session) NewReader(reader *buf.BufferedReader, dest *net.Destination) buf.Reader {
 	if s.transferType == protocol.TransferTypeStream {
 		return NewStreamReader(reader)
 	}
-	return NewPacketReader(reader)
+	return NewPacketReader(reader, dest)
 }

+ 3 - 0
common/mux/writer.go

@@ -63,6 +63,9 @@ func (w *Writer) writeMetaOnly() error {
 
 func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuffer) error {
 	frame := buf.New()
+	if len(data) == 1 {
+		frame.UDP = data[0].UDP
+	}
 	if err := meta.WriteTo(frame); err != nil {
 		return err
 	}

+ 137 - 0
common/vudp/vudp.go

@@ -0,0 +1,137 @@
+package vudp
+
+import (
+	"io"
+
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+)
+
+var addrParser = protocol.NewAddressParser(
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
+	protocol.PortThenAddress(),
+)
+
+func NewPacketWriter(writer buf.Writer, dest net.Destination) *PacketWriter {
+	return &PacketWriter{
+		Writer: writer,
+		Dest:   dest,
+	}
+}
+
+type PacketWriter struct {
+	Writer buf.Writer
+	Dest   net.Destination
+}
+
+func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	defer buf.ReleaseMulti(mb)
+	mb2Write := make(buf.MultiBuffer, 0, len(mb))
+	for _, b := range mb {
+		length := b.Len()
+		if length == 0 || length+666 > buf.Size {
+			continue
+		}
+
+		eb := buf.New()
+		eb.Write([]byte{0, 0, 0, 0})
+		if w.Dest.Network == net.Network_UDP {
+			eb.WriteByte(1) // New
+			eb.WriteByte(1) // Opt
+			eb.WriteByte(2) // UDP
+			addrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
+			w.Dest.Network = net.Network_Unknown
+		} else {
+			eb.WriteByte(2) // Keep
+			eb.WriteByte(1)
+			if b.UDP != nil {
+				eb.WriteByte(2)
+				addrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
+			}
+		}
+		l := eb.Len() - 2
+		eb.SetByte(0, byte(l>>8))
+		eb.SetByte(1, byte(l))
+		eb.WriteByte(byte(length >> 8))
+		eb.WriteByte(byte(length))
+		eb.Write(b.Bytes())
+
+		mb2Write = append(mb2Write, eb)
+	}
+	if mb2Write.IsEmpty() {
+		return nil
+	}
+	return w.Writer.WriteMultiBuffer(mb2Write)
+}
+
+func NewPacketReader(reader io.Reader) *PacketReader {
+	return &PacketReader{
+		Reader: reader,
+		cache:  make([]byte, 2),
+	}
+}
+
+type PacketReader struct {
+	Reader io.Reader
+	cache  []byte
+}
+
+func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	for {
+		if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
+			return nil, err
+		}
+		l := int32(r.cache[0])<<8 | int32(r.cache[1])
+		if l < 4 {
+			return nil, io.EOF
+		}
+		b := buf.New()
+		if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
+			b.Release()
+			return nil, err
+		}
+		discard := false
+		switch b.Byte(2) {
+		case 2:
+			if l != 4 {
+				b.Advance(5)
+				addr, port, err := addrParser.ReadAddressPort(nil, b)
+				if err != nil {
+					b.Release()
+					return nil, err
+				}
+				b.UDP = &net.Destination{
+					Network: net.Network_UDP,
+					Address: addr,
+					Port:    port,
+				}
+			}
+		case 4:
+			discard = true
+		default:
+			b.Release()
+			return nil, io.EOF
+		}
+		if b.Byte(3) == 1 {
+			if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
+				b.Release()
+				return nil, err
+			}
+			length := int32(r.cache[0])<<8 | int32(r.cache[1])
+			if length > 0 {
+				b.Clear()
+				if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
+					b.Release()
+					return nil, err
+				}
+				if !discard {
+					return buf.MultiBuffer{b}, nil
+				}
+			}
+		}
+		b.Release()
+	}
+}

+ 1 - 31
core/xray.go

@@ -3,13 +3,8 @@ package core
 import (
 	"context"
 	"reflect"
-	"runtime/debug"
-	"strings"
 	"sync"
 
-	"github.com/golang/protobuf/proto"
-
-	"github.com/xtls/xray-core/app/proxyman"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/serial"
 	"github.com/xtls/xray-core/features"
@@ -184,32 +179,7 @@ func NewWithContext(ctx context.Context, config *Config) (*Instance, error) {
 }
 
 func initInstanceWithConfig(config *Config, server *Instance) (bool, error) {
-	cone := true
-	v, t := false, false
-	for _, outbound := range config.Outbound {
-		s := strings.ToLower(outbound.ProxySettings.Type)
-		l := len(s)
-		if l >= 16 && s[11:16] == "vless" || l >= 16 && s[11:16] == "vmess" {
-			v = true
-			continue
-		}
-		if l >= 17 && s[11:17] == "trojan" || l >= 22 && s[11:22] == "shadowsocks" {
-			t = true
-			if outbound.SenderSettings != nil {
-				var m proxyman.SenderConfig
-				proto.Unmarshal(outbound.SenderSettings.Value, &m)
-				if m.MultiplexSettings != nil && m.MultiplexSettings.Enabled {
-					cone = false
-					break
-				}
-			}
-		}
-	}
-	if v && !t {
-		cone = false
-	}
-	server.ctx = context.WithValue(server.ctx, "cone", cone)
-	defer debug.FreeOSMemory()
+	server.ctx = context.WithValue(server.ctx, "cone", true)
 
 	if config.Transport != nil {
 		features.PrintDeprecatedFeatureWarning("global transport settings")

+ 13 - 0
proxy/vless/outbound/outbound.go

@@ -16,6 +16,7 @@ import (
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/common/vudp"
 	core "github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/stats"
@@ -175,6 +176,12 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	clientReader := link.Reader // .(*pipe.Reader)
 	clientWriter := link.Writer // .(*pipe.Writer)
 
+	if request.Command == protocol.RequestCommandUDP {
+		request.Command = protocol.RequestCommandMux
+		request.Address = net.DomainAddress("v1.mux.cool")
+		request.Port = net.Port(666)
+	}
+
 	postRequest := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
@@ -185,6 +192,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 
 		// default: serverWriter := bufferWriter
 		serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons)
+		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
+			serverWriter = vudp.NewPacketWriter(serverWriter, target)
+		}
 		if err := buf.CopyOnceTimeout(clientReader, serverWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
 			return err // ...
 		}
@@ -216,6 +226,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 
 		// default: serverReader := buf.NewReader(conn)
 		serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons)
+		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
+			serverReader = vudp.NewPacketReader(conn)
+		}
 
 		if rawConn != nil {
 			var counter stats.Counter

+ 15 - 1
proxy/vmess/outbound/outbound.go

@@ -15,6 +15,7 @@ import (
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/common/vudp"
 	core "github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/proxy/vmess"
@@ -122,6 +123,12 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 
+	if request.Command == protocol.RequestCommandUDP {
+		request.Command = protocol.RequestCommandMux
+		request.Address = net.DomainAddress("v1.mux.cool")
+		request.Port = net.Port(666)
+	}
+
 	requestDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
@@ -131,6 +138,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 
 		bodyWriter := session.EncodeRequestBody(request, writer)
+		bodyWriter2 := bodyWriter
+		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
+			bodyWriter = vudp.NewPacketWriter(bodyWriter, target)
+		}
 		if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
 			return newError("failed to write first payload").Base(err)
 		}
@@ -144,7 +155,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
+			if err := bodyWriter2.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
 				return err
 			}
 		}
@@ -163,6 +174,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		h.handleCommand(rec.Destination(), header.Command)
 
 		bodyReader := session.DecodeResponseBody(request, reader)
+		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
+			bodyReader = vudp.NewPacketReader(&buf.BufferedReader{Reader: bodyReader})
+		}
 
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 	}