Explorar o código

Fix buffer.UDP destination override (#2356)

dyhkwong %!s(int64=2) %!d(string=hai) anos
pai
achega
b8bd243df5

+ 28 - 79
app/dispatcher/default.go

@@ -4,7 +4,6 @@ package dispatcher
 
 import (
 	"context"
-	"fmt"
 	"strings"
 	"sync"
 	"time"
@@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
 // Close implements common.Closable.
 func (*DefaultDispatcher) Close() error { return nil }
 
-func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
-	downOpt := pipe.OptionsFromContext(ctx)
-	upOpt := downOpt
-
-	if network == net.Network_UDP {
-		var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
-		// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
-		// When target replies, server will restore the domain and send back to client.
-		// Note: this map is not global but per connection context
-		upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
-			for i, buffer := range mb {
-				if buffer.UDP == nil {
-					continue
-				}
-				addr := buffer.UDP.Address
-				if addr.Family().IsIP() {
-					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
-						domain := fkr0.GetDomainFromFakeDNS(addr)
-						if len(domain) > 0 {
-							buffer.UDP.Address = net.DomainAddress(domain)
-							newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
-						} else {
-							newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
-						}
-					}
-				} else {
-					if ip2domain == nil {
-						ip2domain = new(sync.Map)
-						newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
-					}
-					domain := addr.Domain()
-					ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
-					if err == nil {
-						for _, ip := range ips {
-							ip2domain.Store(ip.String(), domain)
-						}
-						newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
-					} else {
-						newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
-					}
-				}
-			}
-			return mb
-		}))
-		downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
-			for i, buffer := range mb {
-				if buffer.UDP == nil {
-					continue
-				}
-				addr := buffer.UDP.Address
-				if addr.Family().IsIP() {
-					if ip2domain == nil {
-						continue
-					}
-					if domain, found := ip2domain.Load(addr.IP().String()); found {
-						buffer.UDP.Address = net.DomainAddress(domain.(string))
-						newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
-					}
-				} else {
-					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
-						fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
-						buffer.UDP.Address = fakeIp[0]
-						newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
-					}
-				}
-			}
-			return mb
-		}))
-	}
-	uplinkReader, uplinkWriter := pipe.New(upOpt...)
-	downlinkReader, downlinkWriter := pipe.New(downOpt...)
+func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
+	opt := pipe.OptionsFromContext(ctx)
+	uplinkReader, uplinkWriter := pipe.New(opt...)
+	downlinkReader, downlinkWriter := pipe.New(opt...)
 
 	inboundLink := &transport.Link{
 		Reader: downlinkReader,
@@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
 		protocolString = resComp.ProtocolForDomainResult()
 	}
 	for _, p := range request.OverrideDestinationForProtocol {
-		if strings.HasPrefix(protocolString, p) {
+		if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
 			return true
 		}
 		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
@@ -287,7 +219,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 		panic("Dispatcher: Invalid destination.")
 	}
 	ob := &session.Outbound{
-		Target: destination,
+		OriginalTarget: destination,
+		Target:         destination,
 	}
 	ctx = session.ContextWithOutbound(ctx, ob)
 	content := session.ContentFromContext(ctx)
@@ -295,9 +228,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 		content = new(session.Content)
 		ctx = session.ContextWithContent(ctx, content)
 	}
-
 	sniffingRequest := content.SniffingRequest
-	inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
+	inbound, outbound := d.getLink(ctx)
 	if !sniffingRequest.Enabled {
 		go d.routedDispatch(ctx, outbound, destination)
 	} else {
@@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 				domain := result.Domain()
 				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 				destination.Address = net.ParseAddress(domain)
-				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
+				protocol := result.Protocol()
+				if resComp, ok := result.(SnifferResultComposite); ok {
+					protocol = resComp.ProtocolForDomainResult()
+				}
+				isFakeIP := false
+				if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
+					isFakeIP = true
+				}
+				if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
 					ob.RouteTarget = destination
 				} else {
 					ob.Target = destination
@@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		return newError("Dispatcher: Invalid destination.")
 	}
 	ob := &session.Outbound{
-		Target: destination,
+		OriginalTarget: destination,
+		Target:         destination,
 	}
 	ctx = session.ContextWithOutbound(ctx, ob)
 	content := session.ContentFromContext(ctx)
@@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 			domain := result.Domain()
 			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 			destination.Address = net.ParseAddress(domain)
-			if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
+			protocol := result.Protocol()
+			if resComp, ok := result.(SnifferResultComposite); ok {
+				protocol = resComp.ProtocolForDomainResult()
+			}
+			isFakeIP := false
+			if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
+				isFakeIP = true
+			}
+			if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
 				ob.RouteTarget = destination
 			} else {
 				ob.Target = destination

+ 6 - 1
app/proxyman/outbound/handler.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/xtls/xray-core/app/proxyman"
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net/cnc"
@@ -166,6 +167,11 @@ func (h *Handler) Tag() string {
 
 // Dispatch implements proxy.Outbound.Dispatch.
 func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
+	outbound := session.OutboundFromContext(ctx)
+	if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
+		link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
+		link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
+	}
 	if h.mux != nil {
 		test := func(err error) {
 			if err != nil {
@@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
 				common.Interrupt(link.Writer)
 			}
 		}
-		outbound := session.OutboundFromContext(ctx)
 		if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
 			switch h.udp443 {
 			case "reject":

+ 38 - 0
common/buf/override.go

@@ -0,0 +1,38 @@
+package buf
+
+import (
+	"github.com/xtls/xray-core/common/net"
+)
+
+type EndpointOverrideReader struct {
+	Reader
+	Dest         net.Address
+	OriginalDest net.Address
+}
+
+func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
+	mb, err := r.Reader.ReadMultiBuffer()
+	if err == nil {
+		for _, b := range mb {
+			if b.UDP != nil && b.UDP.Address == r.OriginalDest {
+				b.UDP.Address = r.Dest
+			}
+		}
+	}
+	return mb, err
+}
+
+type EndpointOverrideWriter struct {
+	Writer
+	Dest         net.Address
+	OriginalDest net.Address
+}
+
+func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
+	for _, b := range mb {
+		if b.UDP != nil && b.UDP.Address == w.Dest {
+			b.UDP.Address = w.OriginalDest
+		}
+	}
+	return w.Writer.WriteMultiBuffer(mb)
+}

+ 3 - 2
common/session/session.go

@@ -55,8 +55,9 @@ type Inbound struct {
 // Outbound is the metadata of an outbound connection.
 type Outbound struct {
 	// Target address of the outbound connection.
-	Target      net.Destination
-	RouteTarget net.Destination
+	OriginalTarget net.Destination
+	Target         net.Destination
+	RouteTarget    net.Destination
 	// Gateway address
 	Gateway net.Address
 }

+ 0 - 5
transport/pipe/impl.go

@@ -24,7 +24,6 @@ const (
 type pipeOption struct {
 	limit           int32 // maximum buffer size in bytes
 	discardOverflow bool
-	onTransmission  func(buffer buf.MultiBuffer) buf.MultiBuffer
 }
 
 func (o *pipeOption) isFull(curSize int32) bool {
@@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		return nil
 	}
 
-	if p.option.onTransmission != nil {
-		mb = p.option.onTransmission(mb)
-	}
-
 	for {
 		err := p.writeMultiBufferInternal(mb)
 		if err == nil {

+ 0 - 7
transport/pipe/pipe.go

@@ -3,7 +3,6 @@ package pipe
 import (
 	"context"
 
-	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/features/policy"
@@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
 	}
 }
 
-func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
-	return func(option *pipeOption) {
-		option.onTransmission = hook
-	}
-}
-
 // DiscardOverflow returns an Option for Pipe to discard writes if full.
 func DiscardOverflow() Option {
 	return func(opt *pipeOption) {