瀏覽代碼

Wireguard inbound: Fix context sharing problem (#4988)

* Try fix Wireguard inbound context sharing problem

* Shallow copy inbound and content

* Fix context passing

* Add notes for source address
yuhan6665 2 月之前
父節點
當前提交
337b4b814e
共有 4 個文件被更改,包括 47 次插入61 次删除
  1. 1 3
      common/mux/server.go
  2. 11 21
      common/session/context.go
  3. 9 7
      common/session/session.go
  4. 26 30
      proxy/wireguard/server.go

+ 1 - 3
common/mux/server.go

@@ -118,9 +118,7 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu
 }
 
 func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
-	// deep-clone outbounds because it is going to be mutated concurrently
-	// (Target and OriginalTarget)
-	ctx = session.ContextCloneOutboundsAndContent(ctx)
+	ctx = session.SubContextFromMuxInbound(ctx)
 	errors.LogInfo(ctx, "received request for ", meta.Target)
 	{
 		msg := &log.AccessMessage{

+ 11 - 21
common/session/context.go

@@ -16,15 +16,15 @@ const (
 	inboundSessionKey         ctx.SessionKey = 1
 	outboundSessionKey        ctx.SessionKey = 2
 	contentSessionKey         ctx.SessionKey = 3
-	muxPreferredSessionKey    ctx.SessionKey = 4
-	sockoptSessionKey         ctx.SessionKey = 5
-	trackedConnectionErrorKey ctx.SessionKey = 6
-	dispatcherKey             ctx.SessionKey = 7
-	timeoutOnlyKey            ctx.SessionKey = 8
-	allowedNetworkKey         ctx.SessionKey = 9
-	handlerSessionKey         ctx.SessionKey = 10
-	mitmAlpn11Key             ctx.SessionKey = 11
-	mitmServerNameKey         ctx.SessionKey = 12
+	muxPreferredSessionKey    ctx.SessionKey = 4 // unused
+	sockoptSessionKey         ctx.SessionKey = 5 // used by dokodemo to only receive sockopt.Mark
+	trackedConnectionErrorKey ctx.SessionKey = 6 // used by observer to get outbound error
+	dispatcherKey             ctx.SessionKey = 7 // used by ss2022 inbounds to get dispatcher
+	timeoutOnlyKey            ctx.SessionKey = 8 // mux context's child contexts to only cancel when its own traffic times out
+	allowedNetworkKey         ctx.SessionKey = 9 // muxcool server control incoming request tcp/udp
+	handlerSessionKey         ctx.SessionKey = 10 // unused
+	mitmAlpn11Key             ctx.SessionKey = 11 // used by TLS dialer
+	mitmServerNameKey         ctx.SessionKey = 12 // used by TLS dialer
 )
 
 func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context {
@@ -42,18 +42,8 @@ func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Co
 	return context.WithValue(ctx, outboundSessionKey, outbounds)
 }
 
-func ContextCloneOutboundsAndContent(ctx context.Context) context.Context {
-	outbounds := OutboundsFromContext(ctx)
-	newOutbounds := make([]*Outbound, len(outbounds))
-	for i, ob := range outbounds {
-		if ob == nil {
-			continue
-		}
-
-		// copy outbound by value
-		v := *ob
-		newOutbounds[i] = &v
-	}
+func SubContextFromMuxInbound(ctx context.Context) context.Context {
+	newOutbounds := []*Outbound{{}}
 
 	content := ContentFromContext(ctx)
 	newContent := Content{}

+ 9 - 7
common/session/session.go

@@ -48,9 +48,9 @@ type Inbound struct {
 	User *protocol.MemoryUser
 	// VlessRoute is the user-sent VLESS UUID's last byte.
 	VlessRoute net.Port
-	// Conn is actually internet.Connection. May be nil.
+	// Used by splice copy. Conn is actually internet.Connection. May be nil.
 	Conn net.Conn
-	// Timer of the inbound buf copier. May be nil.
+	// Used by splice copy. Timer of the inbound buf copier. May be nil.
 	Timer *signal.ActivityTimer
 	// CanSpliceCopy is a property for this connection
 	// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
@@ -69,31 +69,33 @@ type Outbound struct {
 	Tag string
 	// Name of the outbound proxy that handles the connection.
 	Name string
-	// Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings
+	// Unused. Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings
 	Conn net.Conn
 	// CanSpliceCopy is a property for this connection
 	// 1 = can, 2 = after processing protocol info should be able to, 3 = cannot
 	CanSpliceCopy int
 }
 
-// SniffingRequest controls the behavior of content sniffing.
+// SniffingRequest controls the behavior of content sniffing. They are from inbound config. Read-only
 type SniffingRequest struct {
-	ExcludeForDomain               []string // read-only once set
-	OverrideDestinationForProtocol []string // read-only once set
+	ExcludeForDomain               []string
+	OverrideDestinationForProtocol []string
 	Enabled                        bool
 	MetadataOnly                   bool
 	RouteOnly                      bool
 }
 
-// Content is the metadata of the connection content.
+// Content is the metadata of the connection content. Mainly used for routing.
 type Content struct {
 	// Protocol of current content.
 	Protocol string
 
 	SniffingRequest SniffingRequest
 
+	// HTTP traffic sniffed headers
 	Attributes map[string]string
 
+	// SkipDNSResolve is set from DNS module. the DOH remote server maybe a domain name, this prevents cycle resolving dead loop
 	SkipDNSResolve bool
 }
 

+ 26 - 30
proxy/wireguard/server.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
+	c "github.com/xtls/xray-core/common/ctx"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
@@ -33,7 +34,6 @@ type routingInfo struct {
 	ctx         context.Context
 	dispatcher  routing.Dispatcher
 	inboundTag  *session.Inbound
-	outboundTag *session.Outbound
 	contentTag  *session.Content
 }
 
@@ -78,18 +78,11 @@ func (*Server) Network() []net.Network {
 
 // Process implements proxy.Inbound.
 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
-	inbound := session.InboundFromContext(ctx)
-	inbound.Name = "wireguard"
-	inbound.CanSpliceCopy = 3
-	outbounds := session.OutboundsFromContext(ctx)
-	ob := outbounds[len(outbounds)-1]
-
 	s.info = routingInfo{
-		ctx:         core.ToBackgroundDetachedContext(ctx),
-		dispatcher:  dispatcher,
-		inboundTag:  session.InboundFromContext(ctx),
-		outboundTag: ob,
-		contentTag:  session.ContentFromContext(ctx),
+		ctx:        ctx,
+		dispatcher: dispatcher,
+		inboundTag: session.InboundFromContext(ctx),
+		contentTag: session.ContentFromContext(ctx),
 	}
 
 	ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
@@ -134,6 +127,25 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 	defer conn.Close()
 
 	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
+	sid := session.NewID()
+	ctx = c.ContextWithID(ctx, sid)
+	inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
+	if s.info.inboundTag != nil {
+		inbound = *s.info.inboundTag
+	}
+	inbound.Name = "wireguard"
+	inbound.CanSpliceCopy = 3
+
+	// overwrite the source to use the tun address for each sub context.
+	// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
+	// Currently we have no way to link to the original source address
+	inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
+	ctx = session.ContextWithInbound(ctx, &inbound)
+	if s.info.contentTag != nil {
+		ctx = session.ContextWithContent(ctx, s.info.contentTag)
+	}
+	ctx = session.SubContextFromMuxInbound(ctx)
+
 	plcy := s.policyManager.ForLevel(0)
 	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
 
@@ -144,25 +156,9 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 		Reason: "",
 	})
 
-	if s.info.inboundTag != nil {
-		ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
-	}
-
-	// what's this?
-	// Session information should not be shared between different connections
-	// why reuse them in server level? This will cause incorrect destoverride and unexpected routing behavior.
-	// Disable it temporarily. Maybe s.info should be removed.
-
-	//	if s.info.outboundTag != nil {
-	//		ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag})
-	//	}
-	//  if s.info.contentTag != nil {
-	//	    ctx = session.ContextWithContent(ctx, s.info.contentTag)
-	//  }
-
 	link, err := s.info.dispatcher.Dispatch(ctx, dest)
 	if err != nil {
-		errors.LogErrorInner(s.info.ctx, err, "dispatch connection")
+		errors.LogErrorInner(ctx, err, "dispatch connection")
 	}
 	defer cancel()
 
@@ -188,7 +184,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
 		common.Interrupt(link.Reader)
 		common.Interrupt(link.Writer)
-		errors.LogDebugInner(s.info.ctx, err, "connection ends")
+		errors.LogDebugInner(ctx, err, "connection ends")
 		return
 	}
 }