Browse Source

Refactor WrapLink logic (#5288)

https://github.com/XTLS/Xray-core/pull/5133
https://github.com/XTLS/Xray-core/pull/5286
风扇滑翔翼 2 weeks ago
parent
commit
f9dd3aef72
4 changed files with 22 additions and 8 deletions
  1. 3 2
      app/reverse/bridge.go
  2. 3 2
      common/mux/server.go
  3. 6 0
      features/routing/dispatcher.go
  4. 10 4
      proxy/vless/inbound/inbound.go

+ 3 - 2
app/reverse/bridge.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"time"
 
-	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
@@ -231,7 +230,9 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l
 		return w.Dispatcher.DispatchLink(ctx, dest, link)
 	}
 
-	link = w.Dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
+	if d, ok := w.Dispatcher.(routing.WrapLinkDispatcher); ok {
+		link = d.WrapLink(ctx, link)
+	}
 	w.handleInternalConn(link)
 
 	return nil

+ 3 - 2
common/mux/server.go

@@ -5,7 +5,6 @@ import (
 	"io"
 	"time"
 
-	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
@@ -64,7 +63,9 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
 	if dest.Address != muxCoolAddress {
 		return s.dispatcher.DispatchLink(ctx, dest, link)
 	}
-	link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
+	if d, ok := s.dispatcher.(routing.WrapLinkDispatcher); ok {
+		link = d.WrapLink(ctx, link)
+	}
 	worker, err := NewServerWorker(ctx, s.dispatcher, link)
 	if err != nil {
 		return err

+ 6 - 0
features/routing/dispatcher.go

@@ -26,3 +26,9 @@ type Dispatcher interface {
 func DispatcherType() interface{} {
 	return (*Dispatcher)(nil)
 }
+
+// Just for type assertion
+type WrapLinkDispatcher interface {
+	Dispatcher
+	WrapLink(ctx context.Context, link *transport.Link) *transport.Link
+}

+ 10 - 4
proxy/vless/inbound/inbound.go

@@ -12,7 +12,6 @@ import (
 	"time"
 	"unsafe"
 
-	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/app/reverse"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -76,7 +75,7 @@ type Handler struct {
 	validator              vless.Validator
 	decryption             *encryption.ServerInstance
 	outboundHandlerManager outbound.Manager
-	defaultDispatcher      *dispatcher.DefaultDispatcher
+	wrapLink               func(ctx context.Context, link *transport.Link) *transport.Link
 	ctx                    context.Context
 	fallbacks              map[string]map[string]map[string]*Fallback // or nil
 	// regexps               map[string]*regexp.Regexp       // or nil
@@ -85,12 +84,16 @@ type Handler struct {
 // New creates a new VLess inbound handler.
 func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
 	v := core.MustFromContext(ctx)
+	var wrapLinkFunc func(ctx context.Context, link *transport.Link) *transport.Link
+	if dispatcher, ok := v.GetFeature(routing.DispatcherType()).(routing.WrapLinkDispatcher); ok {
+		wrapLinkFunc = dispatcher.WrapLink
+	}
 	handler := &Handler{
 		inboundHandlerManager:  v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
 		policyManager:          v.GetFeature(policy.ManagerType()).(policy.Manager),
 		validator:              validator,
 		outboundHandlerManager: v.GetFeature(outbound.ManagerType()).(outbound.Manager),
-		defaultDispatcher:      v.GetFeature(routing.DispatcherType()).(*dispatcher.DefaultDispatcher),
+		wrapLink:               wrapLinkFunc,
 		ctx:                    ctx,
 	}
 
@@ -619,7 +622,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 		if err != nil {
 			return err
 		}
-		return r.NewMux(ctx, h.defaultDispatcher.WrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter}))
+		if h.wrapLink == nil {
+			return errors.New("VLESS reverse must have a dispatcher that implemented routing.WrapLinkDispatcher")
+		}
+		return r.NewMux(ctx, h.wrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter}))
 	}
 
 	if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{