Browse Source

Fix ping domain

世界 4 months ago
parent
commit
49056b5060
2 changed files with 63 additions and 9 deletions
  1. 61 7
      route/route.go
  2. 2 2
      transport/wireguard/device_nat.go

+ 61 - 7
route/route.go

@@ -17,6 +17,7 @@ import (
 	R "github.com/sagernet/sing-box/route/rule"
 	"github.com/sagernet/sing-mux"
 	"github.com/sagernet/sing-tun"
+	"github.com/sagernet/sing-tun/ping"
 	"github.com/sagernet/sing-vmess"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
@@ -271,6 +272,7 @@ func (r *Router) PreMatch(metadata adapter.InboundContext, routeContext tun.Dire
 	if err != nil {
 		return nil, err
 	}
+	var directRouteOutbound adapter.DirectRouteOutbound
 	if selectedRule != nil {
 		switch action := selectedRule.Action().(type) {
 		case *R.RuleActionReject:
@@ -296,17 +298,69 @@ func (r *Router) PreMatch(metadata adapter.InboundContext, routeContext tun.Dire
 			if !common.Contains(outbound.Network(), metadata.Network) {
 				return nil, E.New(metadata.Network, " is not supported by outbound: ", action.Outbound)
 			}
-			return outbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout)
+			directRouteOutbound = outbound.(adapter.DirectRouteOutbound)
 		}
 	}
-	if selectedRule != nil || metadata.Network != N.NetworkICMP {
-		return nil, nil
+	if directRouteOutbound == nil {
+		if selectedRule != nil || metadata.Network != N.NetworkICMP {
+			return nil, nil
+		}
+		defaultOutbound := r.outbound.Default()
+		if !common.Contains(defaultOutbound.Network(), metadata.Network) {
+			return nil, E.New(metadata.Network, " is not supported by default outbound: ", defaultOutbound.Tag())
+		}
+		directRouteOutbound = defaultOutbound.(adapter.DirectRouteOutbound)
 	}
-	defaultOutbound := r.outbound.Default()
-	if !common.Contains(defaultOutbound.Network(), metadata.Network) {
-		return nil, E.New(metadata.Network, " is not supported by default outbound: ", defaultOutbound.Tag())
+	if metadata.Destination.IsFqdn() {
+		if len(metadata.DestinationAddresses) == 0 {
+			var strategy C.DomainStrategy
+			if metadata.Source.IsIPv4() {
+				strategy = C.DomainStrategyIPv4Only
+			} else {
+				strategy = C.DomainStrategyIPv6Only
+			}
+			err = r.actionResolve(r.ctx, &metadata, &R.RuleActionResolve{
+				Strategy: strategy,
+			})
+			if err != nil {
+				return nil, err
+			}
+		}
+		var newDestination netip.Addr
+		if metadata.Source.IsIPv4() {
+			for _, address := range metadata.DestinationAddresses {
+				if address.Is4() {
+					newDestination = address
+					break
+				}
+			}
+		} else {
+			for _, address := range metadata.DestinationAddresses {
+				if address.Is6() {
+					newDestination = address
+					break
+				}
+			}
+		}
+		if !newDestination.IsValid() {
+			if metadata.Source.IsIPv4() {
+				return nil, E.New("no IPv4 address found for domain: ", metadata.Destination.Fqdn)
+			} else {
+				return nil, E.New("no IPv6 address found for domain: ", metadata.Destination.Fqdn)
+			}
+		}
+		metadata.Destination = M.Socksaddr{
+			Addr: newDestination,
+		}
+		routeContext = ping.NewContextDestinationWriter(routeContext, metadata.OriginDestination.Addr)
+		var routeDestination tun.DirectRouteDestination
+		routeDestination, err = directRouteOutbound.NewDirectRouteConnection(metadata, routeContext, timeout)
+		if err != nil {
+			return nil, err
+		}
+		return ping.NewDestinationWriter(routeDestination, newDestination), nil
 	}
-	return defaultOutbound.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout)
+	return directRouteOutbound.NewDirectRouteConnection(metadata, routeContext, timeout)
 }
 
 func (r *Router) matchRule(

+ 2 - 2
transport/wireguard/device_nat.go

@@ -20,7 +20,7 @@ type natDeviceWrapper struct {
 	ctx            context.Context
 	logger         logger.ContextLogger
 	packetOutbound chan *buf.Buffer
-	rewriter       *ping.Rewriter
+	rewriter       *ping.SourceRewriter
 	buffer         [][]byte
 }
 
@@ -30,7 +30,7 @@ func NewNATDevice(ctx context.Context, logger logger.ContextLogger, upstream Dev
 		ctx:            ctx,
 		logger:         logger,
 		packetOutbound: make(chan *buf.Buffer, 256),
-		rewriter:       ping.NewRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()),
+		rewriter:       ping.NewSourceRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()),
 	}
 	return wrapper
 }