Browse Source

XUDP protocol: Add Global ID & UoT Migration

The first UoT protocol that supports UoT Migration
Thank @yuhan6665 for testing
RPRX 2 years ago
parent
commit
be23d5d3b7

+ 21 - 11
app/proxyman/config.pb.go

@@ -595,6 +595,8 @@ type MultiplexingConfig struct {
 	Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
 	// Max number of concurrent connections that one Mux connection can handle.
 	Concurrency uint32 `protobuf:"varint,2,opt,name=concurrency,proto3" json:"concurrency,omitempty"`
+	// Both(0), TCP(1), UDP(2).
+	Only uint32 `protobuf:"varint,3,opt,name=only,proto3" json:"only,omitempty"`
 }
 
 func (x *MultiplexingConfig) Reset() {
@@ -643,6 +645,13 @@ func (x *MultiplexingConfig) GetConcurrency() uint32 {
 	return 0
 }
 
+func (x *MultiplexingConfig) GetOnly() uint32 {
+	if x != nil {
+		return x.Only
+	}
+	return 0
+}
+
 type AllocationStrategy_AllocationStrategyConcurrency struct {
 	state         protoimpl.MessageState
 	sizeCache     protoimpl.SizeCache
@@ -856,21 +865,22 @@ var file_app_proxyman_config_proto_rawDesc = []byte{
 	0x28, 0x0b, 0x32, 0x25, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x70, 0x72,
 	0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
 	0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x6d, 0x75, 0x6c, 0x74, 0x69,
-	0x70, 0x6c, 0x65, 0x78, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x50, 0x0a, 0x12,
+	0x70, 0x6c, 0x65, 0x78, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x64, 0x0a, 0x12,
 	0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x66,
 	0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20,
 	0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x20, 0x0a, 0x0b,
 	0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28,
-	0x0d, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x2a, 0x23,
-	0x0a, 0x0e, 0x4b, 0x6e, 0x6f, 0x77, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x73,
-	0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x4c,
-	0x53, 0x10, 0x01, 0x42, 0x55, 0x0a, 0x15, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e,
-	0x61, 0x70, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x50, 0x01, 0x5a, 0x26,
-	0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f,
-	0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x70, 0x72,
-	0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0xaa, 0x02, 0x11, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x41, 0x70,
-	0x70, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
-	0x6f, 0x33,
+	0x0d, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x12,
+	0x0a, 0x04, 0x6f, 0x6e, 0x6c, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x6f, 0x6e,
+	0x6c, 0x79, 0x2a, 0x23, 0x0a, 0x0e, 0x4b, 0x6e, 0x6f, 0x77, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f,
+	0x63, 0x6f, 0x6c, 0x73, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x07,
+	0x0a, 0x03, 0x54, 0x4c, 0x53, 0x10, 0x01, 0x42, 0x55, 0x0a, 0x15, 0x63, 0x6f, 0x6d, 0x2e, 0x78,
+	0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e,
+	0x50, 0x01, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78,
+	0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70,
+	0x70, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0xaa, 0x02, 0x11, 0x58, 0x72, 0x61,
+	0x79, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x62, 0x06,
+	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (

+ 2 - 0
app/proxyman/config.proto

@@ -98,4 +98,6 @@ message MultiplexingConfig {
   bool enabled = 1;
   // Max number of concurrent connections that one Mux connection can handle.
   uint32 concurrency = 2;
+  // Both(0), TCP(1), UDP(2).
+  uint32 only = 3;
 }

+ 5 - 2
app/proxyman/outbound/handler.go

@@ -111,7 +111,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
 			return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
 		}
 		h.mux = &mux.ClientManager{
-			Enabled: h.senderSettings.MultiplexSettings.Enabled,
+			Enabled: config.Enabled,
 			Picker: &mux.IncrementalWorkerPicker{
 				Factory: &mux.DialingWorkerFactory{
 					Proxy:  proxyHandler,
@@ -122,6 +122,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
 					},
 				},
 			},
+			Only: config.Only,
 		}
 	}
 
@@ -136,7 +137,9 @@ func (h *Handler) Tag() string {
 
 // Dispatch implements proxy.Outbound.Dispatch.
 func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
-	if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) {
+	outbound := session.OutboundFromContext(ctx)
+	if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) &&
+		(h.mux.Only == 0 || (outbound != nil && h.mux.Only == uint32(outbound.Target.Network))) {
 		if err := h.mux.Dispatch(ctx, link); err != nil {
 			err := newError("failed to process mux outbound traffic").Base(err)
 			session.SubmitOutboundErrorToOriginator(ctx, err)

+ 7 - 19
common/mux/client.go

@@ -14,6 +14,7 @@ import (
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/common/xudp"
 	"github.com/xtls/xray-core/proxy"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet"
@@ -23,6 +24,7 @@ import (
 type ClientManager struct {
 	Enabled bool // wheather mux is enabled from user config
 	Picker  WorkerPicker
+	Only    uint32
 }
 
 func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
@@ -247,22 +249,20 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 		transferType = protocol.TransferTypePacket
 	}
 	s.transferType = transferType
-	writer := NewWriter(s.ID, dest, output, transferType)
-	defer s.Close()
+	writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx))
+	defer s.Close(false)
 	defer writer.Close()
 
 	newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 	if err := writeFirstPayload(s.input, writer); err != nil {
 		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
 		writer.hasError = true
-		common.Interrupt(s.input)
 		return
 	}
 
 	if err := buf.Copy(s.input, writer); err != nil {
 		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
 		writer.hasError = true
-		common.Interrupt(s.input)
 		return
 	}
 }
@@ -335,15 +335,8 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 	err := buf.Copy(rr, s.output)
 	if err != nil && buf.IsWriteError(err) {
 		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
-
-		// Notify remote peer to close this session.
-		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
-		closingWriter.Close()
-
-		drainErr := buf.Copy(rr, buf.Discard)
-		common.Interrupt(s.input)
-		s.Close()
-		return drainErr
+		s.Close(false)
+		return buf.Copy(rr, buf.Discard)
 	}
 
 	return err
@@ -351,12 +344,7 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 
 func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
 	if s, found := m.sessionManager.Get(meta.SessionID); found {
-		if meta.Option.Has(OptionError) {
-			common.Interrupt(s.input)
-			common.Interrupt(s.output)
-		}
-		common.Interrupt(s.input)
-		s.Close()
+		s.Close(false)
 	}
 	if meta.Option.Has(OptionData) {
 		return buf.Copy(NewStreamReader(reader), buf.Discard)

+ 9 - 0
common/mux/frame.go

@@ -58,6 +58,7 @@ type FrameMetadata struct {
 	SessionID     uint16
 	Option        bitmask.Byte
 	SessionStatus SessionStatus
+	GlobalID      [8]byte
 }
 
 func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
@@ -81,6 +82,9 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 		if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
 			return err
 		}
+		if b.UDP != nil {
+			b.Write(f.GlobalID[:])
+		}
 	} else if b.UDP != nil {
 		b.WriteByte(byte(TargetNetworkUDP))
 		addrParser.WriteAddressPort(b, b.UDP.Address, b.UDP.Port)
@@ -144,5 +148,10 @@ func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
 		}
 	}
 
+	if f.SessionStatus == SessionStatusNew && f.Option.Has(OptionData) &&
+		f.Target.Network == net.Network_UDP && b.Len() >= 8 {
+		copy(f.GlobalID[:], b.Bytes())
+	}
+
 	return nil
 }

+ 3 - 3
common/mux/mux_test.go

@@ -32,13 +32,13 @@ func TestReaderWriter(t *testing.T) {
 	pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
 
 	dest := net.TCPDestination(net.DomainAddress("example.com"), 80)
-	writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream)
+	writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{})
 
 	dest2 := net.TCPDestination(net.LocalHostIP, 443)
-	writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream)
+	writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{})
 
 	dest3 := net.TCPDestination(net.LocalHostIPv6, 18374)
-	writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream)
+	writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{})
 
 	writePayload := func(writer *Writer, payload ...byte) error {
 		b := buf.New()

+ 82 - 18
common/mux/server.go

@@ -2,6 +2,7 @@ package mux
 
 import (
 	"context"
+	"fmt"
 	"io"
 
 	"github.com/xtls/xray-core/common"
@@ -11,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/xudp"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport"
@@ -99,7 +101,7 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
 	}
 
 	writer.Close()
-	s.Close()
+	s.Close(false)
 }
 
 func (w *ServerWorker) ActiveConnections() uint32 {
@@ -131,6 +133,81 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
 		}
 		ctx = log.ContextWithAccessMessage(ctx, msg)
 	}
+
+	if meta.GlobalID != [8]byte{} {
+		mb, err := NewPacketReader(reader, &meta.Target).ReadMultiBuffer()
+		if err != nil {
+			return err
+		}
+		XUDPManager.Lock()
+		x := XUDPManager.Map[meta.GlobalID]
+		if x == nil {
+			x = &XUDP{GlobalID: meta.GlobalID}
+			XUDPManager.Map[meta.GlobalID] = x
+			XUDPManager.Unlock()
+		} else {
+			if x.Status == Initializing { // nearly impossible
+				XUDPManager.Unlock()
+				if xudp.Show {
+					fmt.Printf("XUDP hit: %v err: conflict\n", meta.GlobalID)
+				}
+				// It's not a good idea to return an err here, so just let client wait.
+				// Client will receive an End frame after sending a Keep frame.
+				return nil
+			}
+			x.Status = Initializing
+			XUDPManager.Unlock()
+			x.Mux.Close(false) // detach from previous Mux
+			b := buf.New()
+			b.Write(mb[0].Bytes())
+			b.UDP = mb[0].UDP
+			if err = x.Mux.output.WriteMultiBuffer(mb); err != nil {
+				x.Interrupt()
+				mb = buf.MultiBuffer{b}
+			} else {
+				b.Release()
+				mb = nil
+			}
+			if xudp.Show {
+				fmt.Printf("XUDP hit: %v err: %v\n", meta.GlobalID, err)
+			}
+		}
+		if mb != nil {
+			ctx = session.ContextWithTimeoutOnly(ctx, true)
+			// Actually, it won't return an error in Xray-core's implementations.
+			link, err := w.dispatcher.Dispatch(ctx, meta.Target)
+			if err != nil {
+				err = newError("failed to dispatch request to ", meta.Target).Base(err)
+				if xudp.Show {
+					fmt.Printf("XUDP new: %v err: %v\n", meta.GlobalID, err)
+				}
+				return err // it will break the whole Mux connection
+			}
+			link.Writer.WriteMultiBuffer(mb) // it's meaningless to test a new pipe
+			x.Mux = &Session{
+				input:  link.Reader,
+				output: link.Writer,
+			}
+			if xudp.Show {
+				fmt.Printf("XUDP new: %v err: %v\n", meta.GlobalID, err)
+			}
+		}
+		x.Mux = &Session{
+			input:        x.Mux.input,
+			output:       x.Mux.output,
+			parent:       w.sessionManager,
+			ID:           meta.SessionID,
+			transferType: protocol.TransferTypePacket,
+			XUDP:         x,
+		}
+		go handle(ctx, x.Mux, w.link.Writer)
+		x.Status = Active
+		if !w.sessionManager.Add(x.Mux) {
+			x.Mux.Close(false)
+		}
+		return nil
+	}
+
 	link, err := w.dispatcher.Dispatch(ctx, meta.Target)
 	if err != nil {
 		if meta.Option.Has(OptionData) {
@@ -157,8 +234,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
 	rr := s.NewReader(reader, &meta.Target)
 	if err := buf.Copy(rr, s.output); err != nil {
 		buf.Copy(rr, buf.Discard)
-		common.Interrupt(s.input)
-		return s.Close()
+		return s.Close(false)
 	}
 	return nil
 }
@@ -182,15 +258,8 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 
 	if err != nil && buf.IsWriteError(err) {
 		newError("failed to write to downstream writer. closing session ", s.ID).Base(err).WriteToLog()
-
-		// Notify remote peer to close this session.
-		closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream)
-		closingWriter.Close()
-
-		drainErr := buf.Copy(rr, buf.Discard)
-		common.Interrupt(s.input)
-		s.Close()
-		return drainErr
+		s.Close(false)
+		return buf.Copy(rr, buf.Discard)
 	}
 
 	return err
@@ -198,12 +267,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
 
 func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
 	if s, found := w.sessionManager.Get(meta.SessionID); found {
-		if meta.Option.Has(OptionError) {
-			common.Interrupt(s.input)
-			common.Interrupt(s.output)
-		}
-		common.Interrupt(s.input)
-		s.Close()
+		s.Close(false)
 	}
 	if meta.Option.Has(OptionData) {
 		return buf.Copy(NewStreamReader(reader), buf.Discard)

+ 98 - 14
common/mux/session.go

@@ -1,12 +1,18 @@
 package mux
 
 import (
+	"fmt"
+	"io"
+	"runtime"
 	"sync"
+	"time"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/xudp"
+	"github.com/xtls/xray-core/transport/pipe"
 )
 
 type SessionManager struct {
@@ -61,21 +67,25 @@ func (m *SessionManager) Allocate() *Session {
 	return s
 }
 
-func (m *SessionManager) Add(s *Session) {
+func (m *SessionManager) Add(s *Session) bool {
 	m.Lock()
 	defer m.Unlock()
 
 	if m.closed {
-		return
+		return false
 	}
 
 	m.count++
 	m.sessions[s.ID] = s
+	return true
 }
 
-func (m *SessionManager) Remove(id uint16) {
-	m.Lock()
-	defer m.Unlock()
+func (m *SessionManager) Remove(locked bool, id uint16) {
+	if !locked {
+		m.Lock()
+		defer m.Unlock()
+	}
+	locked = true
 
 	if m.closed {
 		return
@@ -83,9 +93,11 @@ func (m *SessionManager) Remove(id uint16) {
 
 	delete(m.sessions, id)
 
-	if len(m.sessions) == 0 {
-		m.sessions = make(map[uint16]*Session, 16)
-	}
+	/*
+		if len(m.sessions) == 0 {
+			m.sessions = make(map[uint16]*Session, 16)
+		}
+	*/
 }
 
 func (m *SessionManager) Get(id uint16) (*Session, bool) {
@@ -127,8 +139,7 @@ func (m *SessionManager) Close() error {
 	m.closed = true
 
 	for _, s := range m.sessions {
-		common.Close(s.input)
-		common.Close(s.output)
+		s.Close(true)
 	}
 
 	m.sessions = nil
@@ -142,13 +153,42 @@ type Session struct {
 	parent       *SessionManager
 	ID           uint16
 	transferType protocol.TransferType
+	closed       bool
+	XUDP         *XUDP
 }
 
 // Close closes all resources associated with this session.
-func (s *Session) Close() error {
-	common.Close(s.output)
-	common.Close(s.input)
-	s.parent.Remove(s.ID)
+func (s *Session) Close(locked bool) error {
+	if !locked {
+		s.parent.Lock()
+		defer s.parent.Unlock()
+	}
+	locked = true
+	if s.closed {
+		return nil
+	}
+	s.closed = true
+	if s.XUDP == nil {
+		common.Interrupt(s.input)
+		common.Close(s.output)
+	} else {
+		// Stop existing handle(), then trigger writer.Close().
+		// Note that s.output may be dispatcher.SizeStatWriter.
+		s.input.(*pipe.Reader).ReturnAnError(io.EOF)
+		runtime.Gosched()
+		// If the error set by ReturnAnError still exists, clear it.
+		s.input.(*pipe.Reader).Recover()
+		XUDPManager.Lock()
+		if s.XUDP.Status == Active {
+			s.XUDP.Expire = time.Now().Add(time.Minute)
+			s.XUDP.Status = Expiring
+			if xudp.Show {
+				fmt.Printf("XUDP put: %v\n", s.XUDP.GlobalID)
+			}
+		}
+		XUDPManager.Unlock()
+	}
+	s.parent.Remove(locked, s.ID)
 	return nil
 }
 
@@ -159,3 +199,47 @@ func (s *Session) NewReader(reader *buf.BufferedReader, dest *net.Destination) b
 	}
 	return NewPacketReader(reader, dest)
 }
+
+const (
+	Initializing = 0
+	Active       = 1
+	Expiring     = 2
+)
+
+type XUDP struct {
+	GlobalID [8]byte
+	Status   uint64
+	Expire   time.Time
+	Mux      *Session
+}
+
+func (x *XUDP) Interrupt() {
+	common.Interrupt(x.Mux.input)
+	common.Close(x.Mux.output)
+}
+
+var XUDPManager struct {
+	sync.Mutex
+	Map map[[8]byte]*XUDP
+}
+
+func init() {
+	XUDPManager.Map = make(map[[8]byte]*XUDP)
+	go func() {
+		for {
+			time.Sleep(time.Minute)
+			now := time.Now()
+			XUDPManager.Lock()
+			for id, x := range XUDPManager.Map {
+				if x.Status == Expiring && now.After(x.Expire) {
+					x.Interrupt()
+					delete(XUDPManager.Map, id)
+					if xudp.Show {
+						fmt.Printf("XUDP del: %v\n", id)
+					}
+				}
+			}
+			XUDPManager.Unlock()
+		}
+	}()
+}

+ 1 - 1
common/mux/session_test.go

@@ -44,7 +44,7 @@ func TestSessionManagerClose(t *testing.T) {
 	if m.CloseIfNoSession() {
 		t.Error("able to close")
 	}
-	m.Remove(s.ID)
+	m.Remove(false, s.ID)
 	if !m.CloseIfNoSession() {
 		t.Error("not able to close")
 	}

+ 4 - 1
common/mux/writer.go

@@ -15,15 +15,17 @@ type Writer struct {
 	followup     bool
 	hasError     bool
 	transferType protocol.TransferType
+	globalID     [8]byte
 }
 
-func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer {
+func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte) *Writer {
 	return &Writer{
 		id:           id,
 		dest:         dest,
 		writer:       writer,
 		followup:     false,
 		transferType: transferType,
+		globalID:     globalID,
 	}
 }
 
@@ -40,6 +42,7 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
 	meta := FrameMetadata{
 		SessionID: w.id,
 		Target:    w.dest,
+		GlobalID:  w.globalID,
 	}
 
 	if w.followup {

+ 16 - 0
common/session/context.go

@@ -2,10 +2,14 @@ package session
 
 import (
 	"context"
+	_ "unsafe"
 
 	"github.com/xtls/xray-core/features/routing"
 )
 
+//go:linkname IndependentCancelCtx context.newCancelCtx
+func IndependentCancelCtx(parent context.Context) context.Context
+
 type sessionKey int
 
 const (
@@ -17,6 +21,7 @@ const (
 	sockoptSessionKey
 	trackedConnectionErrorKey
 	dispatcherKey
+	timeoutOnlyKey
 )
 
 // ContextWithID returns a new context with the given ID.
@@ -131,3 +136,14 @@ func DispatcherFromContext(ctx context.Context) routing.Dispatcher {
 	}
 	return nil
 }
+
+func ContextWithTimeoutOnly(ctx context.Context, only bool) context.Context {
+	return context.WithValue(ctx, timeoutOnlyKey, only)
+}
+
+func TimeoutOnlyFromContext(ctx context.Context) bool {
+	if val, ok := ctx.Value(timeoutOnlyKey).(bool); ok {
+		return val
+	}
+	return false
+}

+ 2 - 0
common/session/session.go

@@ -42,6 +42,8 @@ type Inbound struct {
 	Gateway net.Destination
 	// Tag of the inbound proxy that handles the connection.
 	Tag string
+	// Name of the inbound proxy that handles the connection.
+	Name string
 	// User is the user that authencates for the inbound. May be nil if the protocol allows anounymous traffic.
 	User *protocol.MemoryUser
 	// Conn is actually internet.Connection. May be nil.

+ 12 - 0
common/task/task.go

@@ -38,6 +38,12 @@ func Run(ctx context.Context, tasks ...func() error) error {
 		}(task)
 	}
 
+	/*
+		if altctx := ctx.Value("altctx"); altctx != nil {
+			ctx = altctx.(context.Context)
+		}
+	*/
+
 	for i := 0; i < n; i++ {
 		select {
 		case err := <-done:
@@ -48,5 +54,11 @@ func Run(ctx context.Context, tasks ...func() error) error {
 		}
 	}
 
+	/*
+		if cancel := ctx.Value("cancel"); cancel != nil {
+			cancel.(context.CancelFunc)()
+		}
+	*/
+
 	return nil
 }

+ 58 - 9
common/xudp/xudp.go

@@ -1,30 +1,76 @@
 package xudp
 
 import (
+	"context"
+	"crypto/rand"
+	"encoding/base64"
+	"fmt"
 	"io"
+	"os"
+	"strings"
 
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
+	"lukechampine.com/blake3"
 )
 
-var addrParser = protocol.NewAddressParser(
+var AddrParser = protocol.NewAddressParser(
 	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
 	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
 	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
 	protocol.PortThenAddress(),
 )
 
-func NewPacketWriter(writer buf.Writer, dest net.Destination) *PacketWriter {
+var (
+	Show    bool
+	BaseKey [32]byte
+)
+
+const (
+	EnvShow    = "XRAY_XUDP_SHOW"
+	EnvBaseKey = "XRAY_XUDP_BASEKEY"
+)
+
+func init() {
+	if strings.ToLower(os.Getenv(EnvShow)) == "true" {
+		Show = true
+	}
+	if raw := os.Getenv(EnvBaseKey); raw != "" {
+		if key, _ := base64.RawURLEncoding.DecodeString(raw); len(key) == len(BaseKey) {
+			copy(BaseKey[:], key)
+			return
+		} else {
+			panic(EnvBaseKey + ": invalid value: " + raw)
+		}
+	}
+	rand.Read(BaseKey[:])
+}
+
+func GetGlobalID(ctx context.Context) (globalID [8]byte) {
+	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
+		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
+		h := blake3.New(8, BaseKey[:])
+		h.Write([]byte(inbound.Source.String()))
+		copy(globalID[:], h.Sum(nil))
+		fmt.Printf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)
+	}
+	return
+}
+
+func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
 	return &PacketWriter{
-		Writer: writer,
-		Dest:   dest,
+		Writer:   writer,
+		Dest:     dest,
+		GlobalID: globalID,
 	}
 }
 
 type PacketWriter struct {
-	Writer buf.Writer
-	Dest   net.Destination
+	Writer   buf.Writer
+	Dest     net.Destination
+	GlobalID [8]byte
 }
 
 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
@@ -42,14 +88,17 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 			eb.WriteByte(1) // New
 			eb.WriteByte(1) // Opt
 			eb.WriteByte(2) // UDP
-			addrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
+			AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
+			if b.UDP != nil { // make sure it's user's proxy request
+				eb.Write(w.GlobalID[:])
+			}
 			w.Dest.Network = net.Network_Unknown
 		} else {
 			eb.WriteByte(2) // Keep
 			eb.WriteByte(1)
 			if b.UDP != nil {
 				eb.WriteByte(2)
-				addrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
+				AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
 			}
 		}
 		l := eb.Len() - 2
@@ -98,7 +147,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		case 2:
 			if l != 4 {
 				b.Advance(5)
-				addr, port, err := addrParser.ReadAddressPort(nil, b)
+				addr, port, err := AddrParser.ReadAddressPort(nil, b)
 				if err != nil {
 					b.Release()
 					return nil, err

+ 1 - 1
go.mod

@@ -29,6 +29,7 @@ require (
 	google.golang.org/protobuf v1.30.0
 	gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c
 	h12.io/socks v1.0.3
+	lukechampine.com/blake3 v1.1.7
 )
 
 require (
@@ -55,5 +56,4 @@ require (
 	google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
-	lukechampine.com/blake3 v1.1.7 // indirect
 )

+ 17 - 8
infra/conf/xray.go

@@ -10,6 +10,7 @@ import (
 	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/app/proxyman"
 	"github.com/xtls/xray-core/app/stats"
+	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/serial"
 	core "github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/transport/internet"
@@ -107,8 +108,9 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
 }
 
 type MuxConfig struct {
-	Enabled     bool  `json:"enabled"`
-	Concurrency int16 `json:"concurrency"`
+	Enabled     bool   `json:"enabled"`
+	Concurrency int16  `json:"concurrency"`
+	Only        string `json:"only"`
 }
 
 // Build creates MultiplexingConfig, Concurrency < 0 completely disables mux.
@@ -116,16 +118,23 @@ func (m *MuxConfig) Build() *proxyman.MultiplexingConfig {
 	if m.Concurrency < 0 {
 		return nil
 	}
-
-	var con uint32 = 8
-	if m.Concurrency > 0 {
-		con = uint32(m.Concurrency)
+	if m.Concurrency == 0 {
+		m.Concurrency = 8
 	}
 
-	return &proxyman.MultiplexingConfig{
+	config := &proxyman.MultiplexingConfig{
 		Enabled:     m.Enabled,
-		Concurrency: con,
+		Concurrency: uint32(m.Concurrency),
+	}
+
+	switch strings.ToLower(m.Only) {
+	case "tcp":
+		config.Only = uint32(net.Network_TCP)
+	case "udp":
+		config.Only = uint32(net.Network_UDP)
 	}
+
+	return config
 }
 
 type InboundDetourAllocationConfig struct {

+ 4 - 0
proxy/dns/dns.go

@@ -148,6 +148,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 		}
 	}
 
+	if session.TimeoutOnlyFromContext(ctx) {
+		ctx, _ = context.WithCancel(context.Background())
+	}
+
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
 

+ 1 - 0
proxy/dokodemo/dokodemo.go

@@ -103,6 +103,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 
 	inbound := session.InboundFromContext(ctx)
 	if inbound != nil {
+		inbound.Name = "dokodemo-door"
 		inbound.User = &protocol.MemoryUser{
 			Level: d.config.UserLevel,
 		}

+ 16 - 1
proxy/freedom/freedom.go

@@ -149,9 +149,20 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	}
 	defer conn.Close()
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	plcy := h.policy()
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, plcy.Timeouts.ConnectionIdle)
 
 	requestDone := func() error {
 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
@@ -186,6 +197,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		return nil
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
 		return newError("connection ends").Base(err)
 	}

+ 16 - 1
proxy/http/client.go

@@ -128,8 +128,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		p = c.policyManager.ForLevel(user.Level)
 	}
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, p.Timeouts.ConnectionIdle)
 
 	requestFunc := func() error {
 		defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
@@ -140,6 +151,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
 	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)

+ 1 - 0
proxy/http/server.go

@@ -85,6 +85,7 @@ type readerOnly struct {
 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
 	inbound := session.InboundFromContext(ctx)
 	if inbound != nil {
+		inbound.Name = "http"
 		inbound.User = &protocol.MemoryUser{
 			Level: s.config.UserLevel,
 		}

+ 16 - 1
proxy/shadowsocks/client.go

@@ -96,9 +96,24 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 	}
 	request.User = user
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	sessionPolicy := c.policyManager.ForLevel(user.Level)
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, sessionPolicy.Timeouts.ConnectionIdle)
+
+	if newCtx != nil {
+		ctx = newCtx
+	}
 
 	if request.Command == protocol.RequestCommandTCP {
 		requestDone := func() error {

+ 1 - 0
proxy/shadowsocks/server.go

@@ -113,6 +113,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 	if inbound == nil {
 		panic("no inbound metadata")
 	}
+	inbound.Name = "shadowsocks"
 
 	var dest *net.Destination
 

+ 2 - 1
proxy/shadowsocks_2022/inbound.go

@@ -3,7 +3,7 @@ package shadowsocks_2022
 import (
 	"context"
 
-	"github.com/sagernet/sing-shadowsocks"
+	shadowsocks "github.com/sagernet/sing-shadowsocks"
 	"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
 	C "github.com/sagernet/sing/common"
 	B "github.com/sagernet/sing/common/buf"
@@ -64,6 +64,7 @@ func (i *Inbound) Network() []net.Network {
 
 func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
 	inbound := session.InboundFromContext(ctx)
+	inbound.Name = "shadowsocks-2022"
 
 	var metadata M.Metadata
 	if inbound.Source.IsValid() {

+ 1 - 0
proxy/shadowsocks_2022/inbound_multi.go

@@ -153,6 +153,7 @@ func (i *MultiUserInbound) Network() []net.Network {
 
 func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
 	inbound := session.InboundFromContext(ctx)
+	inbound.Name = "shadowsocks-2022-multi"
 
 	var metadata M.Metadata
 	if inbound.Source.IsValid() {

+ 1 - 0
proxy/shadowsocks_2022/inbound_relay.go

@@ -85,6 +85,7 @@ func (i *RelayInbound) Network() []net.Network {
 
 func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
 	inbound := session.InboundFromContext(ctx)
+	inbound.Name = "shadowsocks-2022-relay"
 
 	var metadata M.Metadata
 	if inbound.Source.IsValid() {

+ 5 - 1
proxy/shadowsocks_2022/outbound.go

@@ -6,7 +6,7 @@ import (
 	"runtime"
 	"time"
 
-	"github.com/sagernet/sing-shadowsocks"
+	shadowsocks "github.com/sagernet/sing-shadowsocks"
 	"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
 	C "github.com/sagernet/sing/common"
 	B "github.com/sagernet/sing/common/buf"
@@ -88,6 +88,10 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int
 		return newError("failed to connect to server").Base(err)
 	}
 
+	if session.TimeoutOnlyFromContext(ctx) {
+		ctx, _ = context.WithCancel(context.Background())
+	}
+
 	if network == net.Network_TCP {
 		serverConn := o.method.DialEarlyConn(connection, toSocksaddr(destination))
 		var handshake bool

+ 16 - 1
proxy/socks/client.go

@@ -151,8 +151,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		newError("failed to clear deadline after handshake").Base(err).WriteToLog(session.ExportIDToError(ctx))
 	}
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, p.Timeouts.ConnectionIdle)
 
 	var requestFunc func() error
 	var responseFunc func() error
@@ -183,6 +194,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		}
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
 	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)

+ 1 - 0
proxy/socks/server.go

@@ -64,6 +64,7 @@ func (s *Server) Network() []net.Network {
 // Process implements proxy.Inbound.
 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
 	if inbound := session.InboundFromContext(ctx); inbound != nil {
+		inbound.Name = "socks"
 		inbound.User = &protocol.MemoryUser{
 			Level: s.config.UserLevel,
 		}

+ 16 - 1
proxy/trojan/client.go

@@ -93,9 +93,20 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		Flow: account.Flow,
 	}
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	sessionPolicy := c.policyManager.ForLevel(user.Level)
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, sessionPolicy.Timeouts.ConnectionIdle)
 
 	postRequest := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
@@ -149,6 +160,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	responseDoneAndCloseWriter := task.OnSuccess(getResponse, task.Close(link.Writer))
 	if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil {
 		return newError("connection ends").Base(err)

+ 1 - 0
proxy/trojan/server.go

@@ -217,6 +217,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 	if inbound == nil {
 		panic("no inbound metadata")
 	}
+	inbound.Name = "trojan"
 	inbound.User = user
 	sessionPolicy = s.policyManager.ForLevel(user.Level)
 

+ 1 - 0
proxy/vless/inbound/inbound.go

@@ -438,6 +438,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 	if inbound == nil {
 		panic("no inbound metadata")
 	}
+	inbound.Name = "vless"
 	inbound.User = request.User
 
 	account := request.User.Account.(*vless.MemoryAccount)

+ 17 - 2
proxy/vless/outbound/outbound.go

@@ -170,9 +170,20 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 	}
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	sessionPolicy := h.policyManager.ForLevel(request.User.Level)
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, sessionPolicy.Timeouts.ConnectionIdle)
 
 	clientReader := link.Reader // .(*pipe.Reader)
 	clientWriter := link.Writer // .(*pipe.Writer)
@@ -200,7 +211,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		// default: serverWriter := bufferWriter
 		serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons)
 		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
-			serverWriter = xudp.NewPacketWriter(serverWriter, target)
+			serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
 		}
 		userUUID := account.ID.Bytes()
 		timeoutReader, ok := clientReader.(buf.TimeoutReader)
@@ -300,6 +311,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		return nil
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	if err := task.Run(ctx, postRequest, task.OnSuccess(getResponse, task.Close(clientWriter))); err != nil {
 		return newError("connection ends").Base(err).AtInfo()
 	}

+ 1 - 0
proxy/vmess/inbound/inbound.go

@@ -287,6 +287,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 	if inbound == nil {
 		panic("no inbound metadata")
 	}
+	inbound.Name = "vmess"
 	inbound.User = request.User
 
 	sessionPolicy = h.policyManager.ForLevel(request.User.Level)

+ 17 - 2
proxy/vmess/outbound/outbound.go

@@ -138,11 +138,22 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 
 	behaviorSeed := crc64.Checksum(hashkdf.Sum(nil), crc64.MakeTable(crc64.ISO))
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	session := encoding.NewClientSession(ctx, isAEAD, protocol.DefaultIDHash, int64(behaviorSeed))
 	sessionPolicy := h.policyManager.ForLevel(request.User.Level)
 
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, sessionPolicy.Timeouts.ConnectionIdle)
 
 	if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 {
 		request.Command = protocol.RequestCommandMux
@@ -164,7 +175,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 		bodyWriter2 := bodyWriter
 		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
-			bodyWriter = xudp.NewPacketWriter(bodyWriter, target)
+			bodyWriter = xudp.NewPacketWriter(bodyWriter, target, xudp.GetGlobalID(ctx))
 		}
 		if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
 			return newError("failed to write first payload").Base(err)
@@ -208,6 +219,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	responseDonePost := task.OnSuccess(responseDone, task.Close(output))
 	if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)

+ 16 - 1
proxy/wireguard/wireguard.go

@@ -127,10 +127,21 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		addr = net.IPAddress(ips[0])
 	}
 
+	var newCtx context.Context
+	var newCancel context.CancelFunc
+	if session.TimeoutOnlyFromContext(ctx) {
+		newCtx, newCancel = context.WithCancel(context.Background())
+	}
+
 	p := h.policyManager.ForLevel(0)
 
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, func() {
+		cancel()
+		if newCancel != nil {
+			newCancel()
+		}
+	}, p.Timeouts.ConnectionIdle)
 	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
 
 	var requestFunc func() error
@@ -166,6 +177,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 	}
 
+	if newCtx != nil {
+		ctx = newCtx
+	}
+
 	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
 	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)

+ 3 - 0
transport/pipe/impl.go

@@ -37,6 +37,7 @@ type pipe struct {
 	readSignal  *signal.Notifier
 	writeSignal *signal.Notifier
 	done        *done.Instance
+	errChan     chan error
 	option      pipeOption
 	state       state
 }
@@ -92,6 +93,8 @@ func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		select {
 		case <-p.readSignal.Wait():
 		case <-p.done.Wait():
+		case err = <-p.errChan:
+			return nil, err
 		}
 	}
 }

+ 1 - 0
transport/pipe/pipe.go

@@ -59,6 +59,7 @@ func New(opts ...Option) (*Reader, *Writer) {
 		readSignal:  signal.NewNotifier(),
 		writeSignal: signal.NewNotifier(),
 		done:        done.New(),
+		errChan:     make(chan error, 1),
 		option: pipeOption{
 			limit: -1,
 		},

+ 14 - 0
transport/pipe/reader.go

@@ -25,3 +25,17 @@ func (r *Reader) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error
 func (r *Reader) Interrupt() {
 	r.pipe.Interrupt()
 }
+
+// ReturnAnError makes ReadMultiBuffer return an error, only once.
+func (r *Reader) ReturnAnError(err error) {
+	r.pipe.errChan <- err
+}
+
+// Recover catches an error set by ReturnAnError, if exists.
+func (r *Reader) Recover() (err error) {
+	select {
+	case err = <-r.pipe.errChan:
+	default:
+	}
+	return
+}