Browse Source

XTLS Vision: Use separate uplink/downlink flag for direct copy (#4329)

Fixes https://github.com/XTLS/Xray-core/issues/4033
yuhan6665 8 months ago
parent
commit
03131c72db

+ 24 - 8
proxy/proxy.go

@@ -110,7 +110,8 @@ type TrafficState struct {
 
 	// reader link state
 	WithinPaddingBuffers     bool
-	ReaderSwitchToDirectCopy bool
+	DownlinkReaderDirectCopy bool
+	UplinkReaderDirectCopy   bool
 	RemainingCommand         int32
 	RemainingContent         int32
 	RemainingPadding         int32
@@ -118,7 +119,8 @@ type TrafficState struct {
 
 	// write link state
 	IsPadding                bool
-	WriterSwitchToDirectCopy bool
+	DownlinkWriterDirectCopy bool
+	UplinkWriterDirectCopy   bool
 }
 
 func NewTrafficState(userUUID []byte) *TrafficState {
@@ -131,13 +133,15 @@ func NewTrafficState(userUUID []byte) *TrafficState {
 		Cipher:                   0,
 		RemainingServerHello:     -1,
 		WithinPaddingBuffers:     true,
-		ReaderSwitchToDirectCopy: false,
+		DownlinkReaderDirectCopy: false,
+		UplinkReaderDirectCopy:   false,
 		RemainingCommand:         -1,
 		RemainingContent:         -1,
 		RemainingPadding:         -1,
 		CurrentCommand:           0,
 		IsPadding:                true,
-		WriterSwitchToDirectCopy: false,
+		DownlinkWriterDirectCopy: false,
+		UplinkWriterDirectCopy:   false,
 	}
 }
 
@@ -147,13 +151,15 @@ type VisionReader struct {
 	buf.Reader
 	trafficState *TrafficState
 	ctx          context.Context
+	isUplink     bool
 }
 
-func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader {
+func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, context context.Context) *VisionReader {
 	return &VisionReader{
 		Reader:       reader,
 		trafficState: state,
 		ctx:          context,
+		isUplink:     isUplink,
 	}
 }
 
@@ -175,7 +181,11 @@ func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 				w.trafficState.WithinPaddingBuffers = false
 			} else if w.trafficState.CurrentCommand == 2 {
 				w.trafficState.WithinPaddingBuffers = false
-				w.trafficState.ReaderSwitchToDirectCopy = true
+				if w.isUplink {
+					w.trafficState.UplinkReaderDirectCopy = true
+				} else {
+					w.trafficState.DownlinkReaderDirectCopy = true
+				}
 			} else {
 				errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len())
 			}
@@ -194,9 +204,10 @@ type VisionWriter struct {
 	trafficState      *TrafficState
 	ctx               context.Context
 	writeOnceUserUUID []byte
+	isUplink          bool
 }
 
-func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter {
+func NewVisionWriter(writer buf.Writer, state *TrafficState, isUplink bool, context context.Context) *VisionWriter {
 	w := make([]byte, len(state.UserUUID))
 	copy(w, state.UserUUID)
 	return &VisionWriter{
@@ -204,6 +215,7 @@ func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Con
 		trafficState:      state,
 		ctx:               context,
 		writeOnceUserUUID: w,
+		isUplink:          isUplink,
 	}
 }
 
@@ -221,7 +233,11 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		for i, b := range mb {
 			if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
 				if w.trafficState.EnableXtls {
-					w.trafficState.WriterSwitchToDirectCopy = true
+					if w.isUplink {
+						w.trafficState.UplinkWriterDirectCopy = true
+					} else {
+						w.trafficState.DownlinkWriterDirectCopy = true
+					}
 				}
 				var command byte = CommandPaddingContinue
 				if i == len(mb)-1 {

+ 2 - 2
proxy/vless/encoding/addons.go

@@ -61,13 +61,13 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
 }
 
 // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
-func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, context context.Context) buf.Writer {
+func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context) buf.Writer {
 	if request.Command == protocol.RequestCommandUDP {
 		return NewMultiLengthPacketWriter(writer.(buf.Writer))
 	}
 	w := buf.NewWriter(writer)
 	if requestAddons.Flow == vless.XRV {
-		w = proxy.NewVisionWriter(w, state, context)
+		w = proxy.NewVisionWriter(w, state, isUplink, context)
 	}
 	return w
 }

+ 10 - 6
proxy/vless/encoding/encoding.go

@@ -172,10 +172,10 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
 }
 
 // XtlsRead filter and read xtls protocol
-func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
+func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
 	err := func() error {
 		for {
-			if trafficState.ReaderSwitchToDirectCopy {
+			if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
 				var writerConn net.Conn
 				var inTimer *signal.ActivityTimer
 				if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
@@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
 			buffer, err := reader.ReadMultiBuffer()
 			if !buffer.IsEmpty() {
 				timer.Update()
-				if trafficState.ReaderSwitchToDirectCopy {
+				if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
 					// XTLS Vision processes struct TLS Conn's input and rawInput
 					if inputBuffer, err := buf.ReadFrom(input); err == nil {
 						if !inputBuffer.IsEmpty() {
@@ -222,12 +222,12 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
 }
 
 // XtlsWrite filter and write xtls protocol
-func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, ctx context.Context) error {
+func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
 	err := func() error {
 		var ct stats.Counter
 		for {
 			buffer, err := reader.ReadMultiBuffer()
-			if trafficState.WriterSwitchToDirectCopy {
+			if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy {
 				if inbound := session.InboundFromContext(ctx); inbound != nil {
 					if inbound.CanSpliceCopy == 2 {
 						inbound.CanSpliceCopy = 1
@@ -239,7 +239,11 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 				rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
 				writer = buf.NewWriter(rawConn)
 				ct = writerCounter
-				trafficState.WriterSwitchToDirectCopy = false
+				if isUplink {
+					trafficState.UplinkWriterDirectCopy = false
+				} else {
+					trafficState.DownlinkWriterDirectCopy = false
+				}
 			}
 			if !buffer.IsEmpty() {
 				if ct != nil {

+ 4 - 4
proxy/vless/inbound/inbound.go

@@ -538,8 +538,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 
 		if requestAddons.Flow == vless.XRV {
 			ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
-			clientReader = proxy.NewVisionReader(clientReader, trafficState, ctx1)
-			err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, ctx1)
+			clientReader = proxy.NewVisionReader(clientReader, trafficState, true, ctx1)
+			err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, trafficState, nil, true, ctx1)
 		} else {
 			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
 			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@@ -561,7 +561,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 		}
 
 		// default: clientWriter := bufferWriter
-		clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
+		clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, false, ctx)
 		multiBuffer, err1 := serverReader.ReadMultiBuffer()
 		if err1 != nil {
 			return err1 // ...
@@ -576,7 +576,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 
 		var err error
 		if requestAddons.Flow == vless.XRV {
-			err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, ctx)
+			err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, trafficState, nil, false, ctx)
 		} else {
 			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
 			err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))

+ 4 - 4
proxy/vless/outbound/outbound.go

@@ -194,7 +194,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 
 		// default: serverWriter := bufferWriter
-		serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, ctx)
+		serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons, trafficState, true, ctx)
 		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
 			serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
 		}
@@ -234,7 +234,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 				}
 			}
 			ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice
-			err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, ctx1)
+			err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, trafficState, ob, true, ctx1)
 		} else {
 			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
 			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
@@ -261,7 +261,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		// default: serverReader := buf.NewReader(conn)
 		serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons)
 		if requestAddons.Flow == vless.XRV {
-			serverReader = proxy.NewVisionReader(serverReader, trafficState, ctx)
+			serverReader = proxy.NewVisionReader(serverReader, trafficState, false, ctx)
 		}
 		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
 			if requestAddons.Flow == vless.XRV {
@@ -272,7 +272,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 
 		if requestAddons.Flow == vless.XRV {
-			err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, ctx)
+			err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, trafficState, ob, false, ctx)
 		} else {
 			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
 			err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))