浏览代码

Fix data leak between mux.cool connections (#3718)

Fix #116
mmmray 1 年之前
父节点
当前提交
3dd3bf94d4
共有 3 个文件被更改,包括 144 次插入1 次删除
  1. 4 1
      common/mux/server.go
  2. 124 0
      common/mux/server_test.go
  3. 16 0
      common/session/context.go

+ 4 - 1
common/mux/server.go

@@ -118,6 +118,9 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu
 }
 }
 
 
 func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
 func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
+	// deep-clone outbounds because it is going to be mutated concurrently
+	// (Target and OriginalTarget)
+	ctx = session.ContextCloneOutbounds(ctx)
 	errors.LogInfo(ctx, "received request for ", meta.Target)
 	errors.LogInfo(ctx, "received request for ", meta.Target)
 	{
 	{
 		msg := &log.AccessMessage{
 		msg := &log.AccessMessage{
@@ -170,7 +173,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
 				b.Release()
 				b.Release()
 				mb = nil
 				mb = nil
 			}
 			}
-			errors.LogInfoInner(ctx, err,"XUDP hit ", meta.GlobalID)
+			errors.LogInfoInner(ctx, err, "XUDP hit ", meta.GlobalID)
 		}
 		}
 		if mb != nil {
 		if mb != nil {
 			ctx = session.ContextWithTimeoutOnly(ctx, true)
 			ctx = session.ContextWithTimeoutOnly(ctx, true)

+ 124 - 0
common/mux/server_test.go

@@ -0,0 +1,124 @@
+package mux_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/mux"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/transport"
+	"github.com/xtls/xray-core/transport/pipe"
+)
+
+func newLinkPair() (*transport.Link, *transport.Link) {
+	opt := pipe.WithoutSizeLimit()
+	uplinkReader, uplinkWriter := pipe.New(opt)
+	downlinkReader, downlinkWriter := pipe.New(opt)
+
+	uplink := &transport.Link{
+		Reader: uplinkReader,
+		Writer: downlinkWriter,
+	}
+
+	downlink := &transport.Link{
+		Reader: downlinkReader,
+		Writer: uplinkWriter,
+	}
+
+	return uplink, downlink
+}
+
+type TestDispatcher struct {
+	OnDispatch func(ctx context.Context, dest net.Destination) (*transport.Link, error)
+}
+
+func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
+	return d.OnDispatch(ctx, dest)
+}
+
+func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
+	return nil
+}
+
+func (d *TestDispatcher) Start() error {
+	return nil
+}
+
+func (d *TestDispatcher) Close() error {
+	return nil
+}
+
+func (*TestDispatcher) Type() interface{} {
+	return routing.DispatcherType()
+}
+
+func TestRegressionOutboundLeak(t *testing.T) {
+	originalOutbounds := []*session.Outbound{{}}
+	serverCtx := session.ContextWithOutbounds(context.Background(), originalOutbounds)
+
+	websiteUplink, websiteDownlink := newLinkPair()
+
+	dispatcher := TestDispatcher{
+		OnDispatch: func(ctx context.Context, dest net.Destination) (*transport.Link, error) {
+			// emulate what DefaultRouter.Dispatch does, and mutate something on the context
+			ob := session.OutboundsFromContext(ctx)[0]
+			ob.Target = dest
+			return websiteDownlink, nil
+		},
+	}
+
+	muxServerUplink, muxServerDownlink := newLinkPair()
+	_, err := mux.NewServerWorker(serverCtx, &dispatcher, muxServerUplink)
+	common.Must(err)
+
+	client, err := mux.NewClientWorker(*muxServerDownlink, mux.ClientStrategy{})
+	common.Must(err)
+
+	clientCtx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
+		Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
+	}})
+
+	muxClientUplink, muxClientDownlink := newLinkPair()
+
+	ok := client.Dispatch(clientCtx, muxClientUplink)
+	if !ok {
+		t.Error("failed to dispatch")
+	}
+
+	{
+		b := buf.FromBytes([]byte("hello"))
+		common.Must(muxClientDownlink.Writer.WriteMultiBuffer(buf.MultiBuffer{b}))
+	}
+
+	resMb, err := websiteUplink.Reader.ReadMultiBuffer()
+	common.Must(err)
+	res := resMb.String()
+	if res != "hello" {
+		t.Error("upload: ", res)
+	}
+
+	{
+		b := buf.FromBytes([]byte("world"))
+		common.Must(websiteUplink.Writer.WriteMultiBuffer(buf.MultiBuffer{b}))
+	}
+
+	resMb, err = muxClientDownlink.Reader.ReadMultiBuffer()
+	common.Must(err)
+	res = resMb.String()
+	if res != "world" {
+		t.Error("download: ", res)
+	}
+
+	outbounds := session.OutboundsFromContext(serverCtx)
+	if outbounds[0] != originalOutbounds[0] {
+		t.Error("outbound got reassigned: ", outbounds[0])
+	}
+
+	if outbounds[0].Target.Address != nil {
+		t.Error("outbound target got leaked: ", outbounds[0].Target.String())
+	}
+}

+ 16 - 0
common/session/context.go

@@ -40,6 +40,22 @@ func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Co
 	return context.WithValue(ctx, outboundSessionKey, outbounds)
 	return context.WithValue(ctx, outboundSessionKey, outbounds)
 }
 }
 
 
+func ContextCloneOutbounds(ctx context.Context) context.Context {
+	outbounds := OutboundsFromContext(ctx)
+	newOutbounds := make([]*Outbound, len(outbounds))
+	for i, ob := range outbounds {
+		if ob == nil {
+			continue
+		}
+
+		// copy outbound by value
+		v := *ob
+		newOutbounds[i] = &v
+	}
+
+	return ContextWithOutbounds(ctx, newOutbounds)
+}
+
 func OutboundsFromContext(ctx context.Context) []*Outbound {
 func OutboundsFromContext(ctx context.Context) []*Outbound {
 	if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
 	if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
 		return outbounds
 		return outbounds