Browse Source

Add DispatchLink

世界 4 years ago
parent
commit
50e576081e

+ 61 - 0
app/dispatcher/default.go

@@ -271,6 +271,67 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 	return inbound, nil
 }
 
+// DispatchLink implements routing.Dispatcher.
+func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
+	if !destination.IsValid() {
+		return newError("Dispatcher: Invalid destination.")
+	}
+	ob := &session.Outbound{
+		Target: destination,
+	}
+	ctx = session.ContextWithOutbound(ctx, ob)
+	content := session.ContentFromContext(ctx)
+	if content == nil {
+		content = new(session.Content)
+		ctx = session.ContextWithContent(ctx, content)
+	}
+	sniffingRequest := content.SniffingRequest
+	switch {
+	case !sniffingRequest.Enabled:
+		go d.routedDispatch(ctx, outbound, destination)
+	case destination.Network != net.Network_TCP:
+		// Only metadata sniff will be used for non tcp connection
+		result, err := sniffer(ctx, nil, true)
+		if err == nil {
+			content.Protocol = result.Protocol()
+			if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
+				domain := result.Domain()
+				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
+				destination.Address = net.ParseAddress(domain)
+				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
+					ob.RouteTarget = destination
+				} else {
+					ob.Target = destination
+				}
+			}
+		}
+		go d.routedDispatch(ctx, outbound, destination)
+	default:
+		go func() {
+			cReader := &cachedReader{
+				reader: outbound.Reader.(*pipe.Reader),
+			}
+			outbound.Reader = cReader
+			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
+			if err == nil {
+				content.Protocol = result.Protocol()
+			}
+			if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
+				domain := result.Domain()
+				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
+				destination.Address = net.ParseAddress(domain)
+				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
+					ob.RouteTarget = destination
+				} else {
+					ob.Target = destination
+				}
+			}
+			d.routedDispatch(ctx, outbound, destination)
+		}()
+	}
+	return nil
+}
+
 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) {
 	payload := buf.New()
 	defer payload.Release()

+ 15 - 2
app/reverse/bridge.go

@@ -147,7 +147,7 @@ func (w *BridgeWorker) Connections() uint32 {
 	return w.worker.ActiveConnections()
 }
 
-func (w *BridgeWorker) handleInternalConn(link transport.Link) {
+func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
 	go func() {
 		reader := link.Reader
 		for {
@@ -181,7 +181,7 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
 	uplinkReader, uplinkWriter := pipe.New(opt...)
 	downlinkReader, downlinkWriter := pipe.New(opt...)
 
-	w.handleInternalConn(transport.Link{
+	w.handleInternalConn(&transport.Link{
 		Reader: downlinkReader,
 		Writer: uplinkWriter,
 	})
@@ -191,3 +191,16 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
 		Writer: downlinkWriter,
 	}, nil
 }
+
+func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
+	if !isInternalDomain(dest) {
+		ctx = session.ContextWithInbound(ctx, &session.Inbound{
+			Tag: w.tag,
+		})
+		return w.dispatcher.DispatchLink(ctx, dest, link)
+	}
+
+	w.handleInternalConn(link)
+
+	return nil
+}

+ 9 - 0
common/mux/server.go

@@ -56,6 +56,15 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport
 	return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
 }
 
+// DispatchLink implements routing.Dispatcher
+func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
+	if dest.Address != muxCoolAddress {
+		return s.dispatcher.DispatchLink(ctx, dest, link)
+	}
+	_, err := NewServerWorker(ctx, s.dispatcher, link)
+	return err
+}
+
 // Start implements common.Runnable.
 func (s *Server) Start() error {
 	return nil

+ 1 - 0
features/routing/dispatcher.go

@@ -17,6 +17,7 @@ type Dispatcher interface {
 
 	// Dispatch returns a Ray for transporting data for the given request.
 	Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error)
+	DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error
 }
 
 // DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType.

+ 4 - 0
transport/internet/udp/dispatcher_test.go

@@ -24,6 +24,10 @@ func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*t
 	return d.OnDispatch(ctx, dest)
 }
 
+func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
+	return nil
+}
+
 func (d *TestDispatcher) Start() error {
 	return nil
 }