浏览代码

Implement resolve(server)

世界 1 年之前
父节点
当前提交
b456aff4ac
共有 3 个文件被更改,包括 62 次插入41 次删除
  1. 8 6
      adapter/inbound.go
  2. 1 1
      route/route.go
  3. 53 34
      route/route_dns.go

+ 8 - 6
adapter/inbound.go

@@ -50,12 +50,14 @@ type InboundContext struct {
 	// Deprecated
 	InboundOptions            option.InboundOptions
 	UDPDisableDomainUnmapping bool
-	DestinationAddresses      []netip.Addr
-	SourceGeoIPCode           string
-	GeoIPCode                 string
-	ProcessInfo               *process.Info
-	QueryType                 uint16
-	FakeIP                    bool
+	DNSServer                 string
+
+	DestinationAddresses []netip.Addr
+	SourceGeoIPCode      string
+	GeoIPCode            string
+	ProcessInfo          *process.Info
+	QueryType            uint16
+	FakeIP               bool
 
 	// rule cache
 

+ 1 - 1
route/route.go

@@ -584,7 +584,7 @@ func (r *Router) actionSniff(
 
 func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error {
 	if metadata.Destination.IsFqdn() {
-		// TODO: check if WithContext is necessary
+		metadata.DNSServer = action.Server
 		addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy)
 		if err != nil {
 			return err

+ 53 - 34
route/route_dns.go

@@ -185,6 +185,20 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
 		cached        bool
 		err           error
 	)
+	printResult := func() {
+		if err != nil {
+			if errors.Is(err, dns.ErrResponseRejectedCached) {
+				r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
+			} else if errors.Is(err, dns.ErrResponseRejected) {
+				r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
+			} else {
+				r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
+			}
+		} else if len(responseAddrs) == 0 {
+			r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
+			err = dns.RCodeNameError
+		}
+	}
 	responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
 	if cached {
 		if len(responseAddrs) == 0 {
@@ -196,46 +210,51 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
 	ctx, metadata := adapter.ExtendContext(ctx)
 	metadata.Destination = M.Socksaddr{}
 	metadata.Domain = domain
-	var (
-		transport dns.Transport
-		options   dns.QueryOptions
-		rule      adapter.DNSRule
-		ruleIndex int
-	)
-	ruleIndex = -1
-	for {
-		dnsCtx := adapter.OverrideContext(ctx)
-		var addressLimit bool
-		transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
-		if strategy != dns.DomainStrategyAsIS {
-			options.Strategy = strategy
-		}
-		if rule != nil && rule.WithAddressLimit() {
-			addressLimit = true
-			responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
-				metadata.DestinationAddresses = responseAddrs
-				return rule.MatchAddressLimit(metadata)
-			})
-		} else {
-			addressLimit = false
-			responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
+	if metadata.DNSServer != "" {
+		transport, loaded := r.transportMap[metadata.DNSServer]
+		if !loaded {
+			return nil, E.New("transport not found: ", metadata.DNSServer)
 		}
-		if err != nil {
-			if errors.Is(err, dns.ErrResponseRejectedCached) {
-				r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
-			} else if errors.Is(err, dns.ErrResponseRejected) {
-				r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
+		if strategy == dns.DomainStrategyAsIS {
+			if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded {
+				strategy = transportDomainStrategy
 			} else {
-				r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
+				strategy = r.defaultDomainStrategy
 			}
-		} else if len(responseAddrs) == 0 {
-			r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
-			err = dns.RCodeNameError
 		}
-		if !addressLimit || err == nil {
-			break
+		responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy})
+	} else {
+		var (
+			transport dns.Transport
+			options   dns.QueryOptions
+			rule      adapter.DNSRule
+			ruleIndex int
+		)
+		ruleIndex = -1
+		for {
+			dnsCtx := adapter.OverrideContext(ctx)
+			var addressLimit bool
+			transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
+			if strategy != dns.DomainStrategyAsIS {
+				options.Strategy = strategy
+			}
+			if rule != nil && rule.WithAddressLimit() {
+				addressLimit = true
+				responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
+					metadata.DestinationAddresses = responseAddrs
+					return rule.MatchAddressLimit(metadata)
+				})
+			} else {
+				addressLimit = false
+				responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
+			}
+			if !addressLimit || err == nil {
+				break
+			}
+			printResult()
 		}
 	}
+	printResult()
 	if len(responseAddrs) > 0 {
 		r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
 	}