Browse Source

DNS: Fix some bugs; Refactors; Optimizations (#4659)

patterniha 6 months ago
parent
commit
aa4134f4a6

+ 188 - 0
app/dns/cache_controller.go

@@ -0,0 +1,188 @@
+package dns
+
+import (
+	"context"
+	go_errors "errors"
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/signal/pubsub"
+	"github.com/xtls/xray-core/common/task"
+	dns_feature "github.com/xtls/xray-core/features/dns"
+	"golang.org/x/net/dns/dnsmessage"
+	"sync"
+	"time"
+)
+
+type CacheController struct {
+	sync.RWMutex
+	ips          map[string]*record
+	pub          *pubsub.Service
+	cacheCleanup *task.Periodic
+	name         string
+	disableCache bool
+}
+
+func NewCacheController(name string, disableCache bool) *CacheController {
+	c := &CacheController{
+		name:         name,
+		disableCache: disableCache,
+		ips:          make(map[string]*record),
+		pub:          pubsub.NewService(),
+	}
+
+	c.cacheCleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  c.CacheCleanup,
+	}
+	return c
+}
+
+// CacheCleanup clears expired items from cache
+func (c *CacheController) CacheCleanup() error {
+	now := time.Now()
+	c.Lock()
+	defer c.Unlock()
+
+	if len(c.ips) == 0 {
+		return errors.New("nothing to do. stopping...")
+	}
+
+	for domain, record := range c.ips {
+		if record.A != nil && record.A.Expire.Before(now) {
+			record.A = nil
+		}
+		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
+			record.AAAA = nil
+		}
+
+		if record.A == nil && record.AAAA == nil {
+			errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
+			delete(c.ips, domain)
+		} else {
+			c.ips[domain] = record
+		}
+	}
+
+	if len(c.ips) == 0 {
+		c.ips = make(map[string]*record)
+	}
+
+	return nil
+}
+
+func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
+	elapsed := time.Since(req.start)
+
+	c.Lock()
+	rec, found := c.ips[req.domain]
+	if !found {
+		rec = &record{}
+	}
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		rec.A = ipRec
+	case dnsmessage.TypeAAAA:
+		rec.AAAA = ipRec
+	}
+
+	errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
+	c.ips[req.domain] = rec
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		c.pub.Publish(req.domain+"4", nil)
+		if !c.disableCache {
+			_, _, err := rec.AAAA.getIPs()
+			if !go_errors.Is(err, errRecordNotFound) {
+				c.pub.Publish(req.domain+"6", nil)
+			}
+		}
+	case dnsmessage.TypeAAAA:
+		c.pub.Publish(req.domain+"6", nil)
+		if !c.disableCache {
+			_, _, err := rec.A.getIPs()
+			if !go_errors.Is(err, errRecordNotFound) {
+				c.pub.Publish(req.domain+"4", nil)
+			}
+		}
+	}
+
+	c.Unlock()
+	common.Must(c.cacheCleanup.Start())
+}
+
+func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
+	c.RLock()
+	record, found := c.ips[domain]
+	c.RUnlock()
+
+	if !found {
+		return nil, 0, errRecordNotFound
+	}
+
+	var errs []error
+	var allIPs []net.IP
+	var rTTL uint32 = dns_feature.DefaultTTL
+
+	mergeReq := option.IPv4Enable && option.IPv6Enable
+
+	if option.IPv4Enable {
+		ips, ttl, err := record.A.getIPs()
+		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
+			return ips, ttl, err
+		}
+		if ttl < rTTL {
+			rTTL = ttl
+		}
+		if len(ips) > 0 {
+			allIPs = append(allIPs, ips...)
+		} else {
+			errs = append(errs, err)
+		}
+	}
+
+	if option.IPv6Enable {
+		ips, ttl, err := record.AAAA.getIPs()
+		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
+			return ips, ttl, err
+		}
+		if ttl < rTTL {
+			rTTL = ttl
+		}
+		if len(ips) > 0 {
+			allIPs = append(allIPs, ips...)
+		} else {
+			errs = append(errs, err)
+		}
+	}
+
+	if len(allIPs) > 0 {
+		return allIPs, rTTL, nil
+	}
+	if go_errors.Is(errs[0], errs[1]) {
+		return nil, rTTL, errs[0]
+	}
+	return nil, rTTL, errors.Combine(errs...)
+}
+
+func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
+	// ipv4 and ipv6 belong to different subscription groups
+	if option.IPv4Enable {
+		sub4 = c.pub.Subscribe(domain + "4")
+	}
+	if option.IPv6Enable {
+		sub6 = c.pub.Subscribe(domain + "6")
+	}
+	return
+}
+
+func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
+	if sub4 != nil {
+		sub4.Close()
+	}
+	if sub6 != nil {
+		sub6.Close()
+	}
+}

+ 68 - 53
app/dns/dns.go

@@ -3,12 +3,12 @@ package dns
 
 import (
 	"context"
+	go_errors "errors"
 	"fmt"
 	"sort"
 	"strings"
 	"sync"
 
-	"github.com/xtls/xray-core/app/router"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
@@ -20,8 +20,6 @@ import (
 // DNS is a DNS rely server.
 type DNS struct {
 	sync.Mutex
-	tag                    string
-	disableCache           bool
 	disableFallback        bool
 	disableFallbackIfMatch bool
 	ipOption               *dns.IPOption
@@ -40,13 +38,6 @@ type DomainMatcherInfo struct {
 
 // New creates a new DNS server with given configuration.
 func New(ctx context.Context, config *Config) (*DNS, error) {
-	var tag string
-	if len(config.Tag) > 0 {
-		tag = config.Tag
-	} else {
-		tag = generateRandomTag()
-	}
-
 	var clientIP net.IP
 	switch len(config.ClientIp) {
 	case 0, net.IPv4len, net.IPv6len:
@@ -55,26 +46,28 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		return nil, errors.New("unexpected client IP length ", len(config.ClientIp))
 	}
 
-	var ipOption *dns.IPOption
+	var ipOption dns.IPOption
 	switch config.QueryStrategy {
 	case QueryStrategy_USE_IP:
-		ipOption = &dns.IPOption{
+		ipOption = dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: true,
 			FakeEnable: false,
 		}
 	case QueryStrategy_USE_IP4:
-		ipOption = &dns.IPOption{
+		ipOption = dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: false,
 			FakeEnable: false,
 		}
 	case QueryStrategy_USE_IP6:
-		ipOption = &dns.IPOption{
+		ipOption = dns.IPOption{
 			IPv4Enable: false,
 			IPv6Enable: true,
 			FakeEnable: false,
 		}
+	default:
+		return nil, errors.New("unexpected query strategy ", config.QueryStrategy)
 	}
 
 	hosts, err := NewStaticHosts(config.StaticHosts)
@@ -82,8 +75,14 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		return nil, errors.New("failed to create hosts").Base(err)
 	}
 
-	clients := []*Client{}
+	var clients []*Client
 	domainRuleCount := 0
+
+	var defaultTag = config.Tag
+	if len(config.Tag) == 0 {
+		defaultTag = generateRandomTag()
+	}
+
 	for _, ns := range config.NameServer {
 		domainRuleCount += len(ns.PrioritizedDomain)
 	}
@@ -91,7 +90,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
 	matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1)
 	domainMatcher := &strmatcher.MatcherGroup{}
-	geoipContainer := router.GeoIPMatcherContainer{}
 
 	for _, ns := range config.NameServer {
 		clientIdx := len(clients)
@@ -109,7 +107,18 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		case net.IPv4len, net.IPv6len:
 			myClientIP = net.IP(ns.ClientIp)
 		}
-		client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
+
+		disableCache := config.DisableCache
+
+		var tag = defaultTag
+		if len(ns.Tag) > 0 {
+			tag = ns.Tag
+		}
+		clientIPOption := ResolveIpOptionOverride(ns.QueryStrategy, ipOption)
+		if !clientIPOption.IPv4Enable && !clientIPOption.IPv6Enable {
+			return nil, errors.New("no QueryStrategy available for ", ns.Address)
+		}
+		client, err := NewClient(ctx, ns, myClientIP, disableCache, tag, clientIPOption, &matcherInfos, updateDomain)
 		if err != nil {
 			return nil, errors.New("failed to create client").Base(err)
 		}
@@ -118,18 +127,16 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 
 	// If there is no DNS client in config, add a `localhost` DNS client
 	if len(clients) == 0 {
-		clients = append(clients, NewLocalDNSClient())
+		clients = append(clients, NewLocalDNSClient(ipOption))
 	}
 
 	return &DNS{
-		tag:                    tag,
 		hosts:                  hosts,
-		ipOption:               ipOption,
+		ipOption:               &ipOption,
 		clients:                clients,
 		ctx:                    ctx,
 		domainMatcher:          domainMatcher,
 		matcherInfos:           matcherInfos,
-		disableCache:           config.DisableCache,
 		disableFallback:        config.DisableFallback,
 		disableFallbackIfMatch: config.DisableFallbackIfMatch,
 	}, nil
@@ -153,11 +160,21 @@ func (s *DNS) Close() error {
 // IsOwnLink implements proxy.dns.ownLinkVerifier
 func (s *DNS) IsOwnLink(ctx context.Context) bool {
 	inbound := session.InboundFromContext(ctx)
-	return inbound != nil && inbound.Tag == s.tag
+	if inbound == nil {
+		return false
+	}
+	for _, client := range s.clients {
+		if client.tag == inbound.Tag {
+			return true
+		}
+	}
+	return false
 }
 
 // LookupIP implements dns.Client.
 func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
+	// Normalize the FQDN form query
+	domain = strings.TrimSuffix(domain, ".")
 	if domain == "" {
 		return nil, 0, errors.New("empty domain name")
 	}
@@ -169,9 +186,6 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
 		return nil, 0, dns.ErrEmptyResponse
 	}
 
-	// Normalize the FQDN form query
-	domain = strings.TrimSuffix(domain, ".")
-
 	// Static host lookup
 	switch addrs := s.hosts.Lookup(domain, option); {
 	case addrs == nil: // Domain not recorded in static host
@@ -184,32 +198,49 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
 	default: // Successfully found ip records in static host
 		errors.LogInfo(s.ctx, "returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs)
 		ips, err := toNetIP(addrs)
-		return ips, 10, err // Hosts ttl is 10
+		if err != nil {
+			return nil, 0, err
+		}
+		return ips, 10, nil // Hosts ttl is 10
 	}
 
 	// Name servers lookup
-	errs := []error{}
-	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
+	var errs []error
 	for _, client := range s.sortClients(domain) {
 		if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
 			errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
 			continue
 		}
-		ips, ttl, err := client.QueryIP(ctx, domain, option, s.disableCache)
+
+		ips, ttl, err := client.QueryIP(s.ctx, domain, option)
+
 		if len(ips) > 0 {
+			if ttl == 0 {
+				ttl = 1
+			}
 			return ips, ttl, nil
 		}
-		if err != nil {
-			errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
-			errs = append(errs, err)
-		}
-		// 5 for RcodeRefused in miekg/dns, hardcode to reduce binary size
-		if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse && dns.RCodeFromError(err) != 5 {
-			return nil, 0, err
+
+		errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
+		if err == nil {
+			err = dns.ErrEmptyResponse
 		}
+		errs = append(errs, err)
+
 	}
 
-	return nil, 0, errors.New("returning nil for domain ", domain).Base(errors.Combine(errs...))
+	if len(errs) > 0 {
+		allErrs := errors.Combine(errs...)
+		err0 := errs[0]
+		if errors.AllEqual(err0, allErrs) {
+			if go_errors.Is(err0, dns.ErrEmptyResponse) {
+				return nil, 0, dns.ErrEmptyResponse
+			}
+			return nil, 0, errors.New("returning nil for domain ", domain).Base(err0)
+		}
+		return nil, 0, errors.New("returning nil for domain ", domain).Base(allErrs)
+	}
+	return nil, 0, dns.ErrEmptyResponse
 }
 
 // LookupHosts implements dns.HostsLookup.
@@ -228,22 +259,6 @@ func (s *DNS) LookupHosts(domain string) *net.Address {
 	return nil
 }
 
-// GetIPOption implements ClientWithIPOption.
-func (s *DNS) GetIPOption() *dns.IPOption {
-	return s.ipOption
-}
-
-// SetQueryOption implements ClientWithIPOption.
-func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
-	s.ipOption.IPv4Enable = isIPv4Enable
-	s.ipOption.IPv6Enable = isIPv6Enable
-}
-
-// SetFakeDNSOption implements ClientWithIPOption.
-func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
-	s.ipOption.FakeEnable = isFakeEnable
-}
-
 func (s *DNS) sortClients(domain string) []*Client {
 	clients := make([]*Client, 0, len(s.clients))
 	clientUsed := make([]bool, len(s.clients))

+ 3 - 1
app/dns/dns_test.go

@@ -76,6 +76,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 		case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA:
 			ans.MsgHdr.Rcode = dns.RcodeNameError
 
+		case q.Name == "notexist.google.com." && q.Qtype == dns.TypeA:
+			ans.MsgHdr.Rcode = dns.RcodeNameError
+
 		case q.Name == "hostname." && q.Qtype == dns.TypeA:
 			rr, _ := dns.NewRR("hostname. IN A 127.0.0.1")
 			ans.Answer = append(ans.Answer, rr)
@@ -117,7 +120,6 @@ func TestUDPServerSubnet(t *testing.T) {
 		Handler: &staticHandler{},
 		UDPSize: 1200,
 	}
-
 	go dnsServer.ListenAndServe()
 	time.Sleep(time.Second)
 

+ 20 - 18
app/dns/dnscommon.go

@@ -32,31 +32,30 @@ type record struct {
 // IPRecord is a cacheable item for a resolved domain
 type IPRecord struct {
 	ReqID     uint16
-	IP        []net.Address
+	IP        []net.IP
 	Expire    time.Time
 	RCode     dnsmessage.RCode
 	RawHeader *dnsmessage.Header
 }
 
-func (r *IPRecord) getIPs() ([]net.Address, uint32, error) {
-	if r == nil || r.Expire.Before(time.Now()) {
+func (r *IPRecord) getIPs() ([]net.IP, uint32, error) {
+	if r == nil {
 		return nil, 0, errRecordNotFound
 	}
-	if r.RCode != dnsmessage.RCodeSuccess {
-		return nil, 0, dns_feature.RCodeError(r.RCode)
+	untilExpire := time.Until(r.Expire)
+	if untilExpire <= 0 {
+		return nil, 0, errRecordNotFound
 	}
-	ttl := uint32(time.Until(r.Expire) / time.Second)
-	return r.IP, ttl, nil
-}
 
-func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
-	if newRec == nil {
-		return false
+	ttl := uint32(untilExpire/time.Second) + uint32(1)
+	if r.RCode != dnsmessage.RCodeSuccess {
+		return nil, ttl, dns_feature.RCodeError(r.RCode)
 	}
-	if baseRec == nil {
-		return true
+	if len(r.IP) == 0 {
+		return nil, ttl, dns_feature.ErrEmptyResponse
 	}
-	return baseRec.Expire.Before(newRec.Expire)
+
+	return r.IP, ttl, nil
 }
 
 var errRecordNotFound = errors.New("record not found")
@@ -193,7 +192,7 @@ func parseResponse(payload []byte) (*IPRecord, error) {
 	ipRecord := &IPRecord{
 		ReqID:     h.ID,
 		RCode:     h.RCode,
-		Expire:    now.Add(time.Second * 600),
+		Expire:    now.Add(time.Second * dns_feature.DefaultTTL),
 		RawHeader: &h,
 	}
 
@@ -209,7 +208,7 @@ L:
 
 		ttl := ah.TTL
 		if ttl == 0 {
-			ttl = 600
+			ttl = 1
 		}
 		expire := now.Add(time.Duration(ttl) * time.Second)
 		if ipRecord.Expire.After(expire) {
@@ -223,14 +222,17 @@ L:
 				errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
 				break L
 			}
-			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
+			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP())
 		case dnsmessage.TypeAAAA:
 			ans, err := parser.AAAAResource()
 			if err != nil {
 				errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
 				break L
 			}
-			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
+			newIP := net.IPAddress(ans.AAAA[:]).IP()
+			if len(newIP) == net.IPv6len {
+				ipRecord.IP = append(ipRecord.IP, newIP)
+			}
 		default:
 			if err := parser.SkipAnswer(); err != nil {
 				errors.LogInfoInner(context.Background(), err, "failed to skip answer")

+ 3 - 3
app/dns/dnscommon_test.go

@@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
 	}{
 		{
 			"empty",
-			&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
+			&IPRecord{0, []net.IP(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
 			false,
 		},
 		{
@@ -63,7 +63,7 @@ func Test_parseResponse(t *testing.T) {
 			"a record",
 			&IPRecord{
 				1,
-				[]net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
+				[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
 				time.Time{},
 				dnsmessage.RCodeSuccess,
 				nil,
@@ -72,7 +72,7 @@ func Test_parseResponse(t *testing.T) {
 		},
 		{
 			"aaaa record",
-			&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
+			&IPRecord{2, []net.IP{net.ParseIP("2001::123:8888"), net.ParseIP("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
 			false,
 		},
 	}

+ 54 - 45
app/dns/nameserver.go

@@ -21,25 +21,23 @@ type Server interface {
 	// Name of the Client.
 	Name() string
 	// QueryIP sends IP queries to its configured server.
-	QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error)
+	QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error)
 }
 
 // Client is the interface for DNS client.
 type Client struct {
 	server             Server
-	clientIP           net.IP
 	skipFallback       bool
 	domains            []string
 	expectedIPs        []*router.GeoIPMatcher
 	allowUnexpectedIPs bool
 	tag                string
 	timeoutMs          time.Duration
+	ipOption           *dns.IPOption
 }
 
-var errExpectedIPNonMatch = errors.New("expectedIPs not match")
-
 // NewServer creates a name server object according to the network destination url.
-func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) {
+func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) (Server, error) {
 	if address := dest.Address; address.Family().IsDomain() {
 		u, err := url.Parse(address.Domain())
 		if err != nil {
@@ -47,26 +45,29 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
 		}
 		switch {
 		case strings.EqualFold(u.String(), "localhost"):
-			return NewLocalNameServer(queryStrategy), nil
+			return NewLocalNameServer(), nil
 		case strings.EqualFold(u.Scheme, "https"): // DNS-over-HTTPS Remote mode
-			return NewDoHNameServer(u, queryStrategy, dispatcher, false), nil
+			return NewDoHNameServer(u, dispatcher, false, disableCache, clientIP), nil
 		case strings.EqualFold(u.Scheme, "h2c"): // DNS-over-HTTPS h2c Remote mode
-			return NewDoHNameServer(u, queryStrategy, dispatcher, true), nil
+			return NewDoHNameServer(u, dispatcher, true, disableCache, clientIP), nil
 		case strings.EqualFold(u.Scheme, "https+local"): // DNS-over-HTTPS Local mode
-			return NewDoHNameServer(u, queryStrategy, nil, false), nil
+			return NewDoHNameServer(u, nil, false, disableCache, clientIP), nil
 		case strings.EqualFold(u.Scheme, "h2c+local"): // DNS-over-HTTPS h2c Local mode
-			return NewDoHNameServer(u, queryStrategy, nil, true), nil
+			return NewDoHNameServer(u, nil, true, disableCache, clientIP), nil
 		case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
-			return NewQUICNameServer(u, queryStrategy)
+			return NewQUICNameServer(u, disableCache, clientIP)
 		case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
-			return NewTCPNameServer(u, dispatcher, queryStrategy)
+			return NewTCPNameServer(u, dispatcher, disableCache, clientIP)
 		case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
-			return NewTCPLocalNameServer(u, queryStrategy)
+			return NewTCPLocalNameServer(u, disableCache, clientIP)
 		case strings.EqualFold(u.String(), "fakedns"):
 			var fd dns.FakeDNSEngine
-			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
+			err = core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
 				fd = fdns
 			})
+			if err != nil {
+				return nil, err
+			}
 			return NewFakeDNSServer(fd), nil
 		}
 	}
@@ -74,7 +75,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
 		dest.Network = net.Network_UDP
 	}
 	if dest.Network == net.Network_UDP { // UDP classic DNS mode
-		return NewClassicNameServer(dest, dispatcher, queryStrategy), nil
+		return NewClassicNameServer(dest, dispatcher, disableCache, clientIP), nil
 	}
 	return nil, errors.New("No available name server could be created from ", dest).AtWarning()
 }
@@ -84,7 +85,9 @@ func NewClient(
 	ctx context.Context,
 	ns *NameServer,
 	clientIP net.IP,
-	container router.GeoIPMatcherContainer,
+	disableCache bool,
+	tag string,
+	ipOption dns.IPOption,
 	matcherInfos *[]*DomainMatcherInfo,
 	updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo) error,
 ) (*Client, error) {
@@ -92,7 +95,7 @@ func NewClient(
 
 	err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
 		// Create a new server for each client for now
-		server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy())
+		server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, clientIP)
 		if err != nil {
 			return errors.New("failed to create nameserver").Base(err).AtWarning()
 		}
@@ -149,7 +152,7 @@ func NewClient(
 		// Establish expected IPs
 		var matchers []*router.GeoIPMatcher
 		for _, geoip := range ns.Geoip {
-			matcher, err := container.Add(geoip)
+			matcher, err := router.GlobalGeoIPContainer.Add(geoip)
 			if err != nil {
 				return errors.New("failed to create ip matcher").Base(err).AtWarning()
 			}
@@ -169,15 +172,15 @@ func NewClient(
 		if ns.TimeoutMs > 0 {
 			timeoutMs = time.Duration(ns.TimeoutMs) * time.Millisecond
 		}
-		
+
 		client.server = server
-		client.clientIP = clientIP
 		client.skipFallback = ns.SkipFallback
 		client.domains = rules
 		client.expectedIPs = matchers
 		client.allowUnexpectedIPs = ns.AllowUnexpectedIPs
-		client.tag = ns.Tag
+		client.tag = tag
 		client.timeoutMs = timeoutMs
+		client.ipOption = &ipOption
 		return nil
 	})
 	return client, err
@@ -189,31 +192,43 @@ func (c *Client) Name() string {
 }
 
 // QueryIP sends DNS query to the name server with the client's IP.
-func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption, disableCache bool) ([]net.IP, uint32, error) {
-	ctx, cancel := context.WithTimeout(ctx, c.timeoutMs)
-	if len(c.tag) != 0 {
-		content := session.InboundFromContext(ctx)
-		errors.LogDebug(ctx, "DNS: client override tag from ", content.Tag, " to ", c.tag)
-		// create a new context to override the tag
-		// do not direct set *content.Tag, it might be used by other clients
-		ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag})
+func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
+	option.IPv4Enable = option.IPv4Enable && c.ipOption.IPv4Enable
+	option.IPv6Enable = option.IPv6Enable && c.ipOption.IPv6Enable
+	if !option.IPv4Enable && !option.IPv6Enable {
+		return nil, 0, dns.ErrEmptyResponse
 	}
-	ips, ttl, err := c.server.QueryIP(ctx, domain, c.clientIP, option, disableCache)
+
+	ctx, cancel := context.WithTimeout(ctx, c.timeoutMs)
+	ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: c.tag})
+	ips, ttl, err := c.server.QueryIP(ctx, domain, option)
 	cancel()
 
 	if err != nil {
-		return ips, ttl, err
+		return nil, 0, err
+	}
+
+	if len(ips) == 0 {
+		return nil, 0, dns.ErrEmptyResponse
 	}
-	netips, err := c.MatchExpectedIPs(domain, ips)
-	return netips, ttl, err
+
+	if len(c.expectedIPs) > 0 {
+		newIps := c.MatchExpectedIPs(domain, ips)
+		if len(newIps) == 0 {
+			if !c.allowUnexpectedIPs {
+				return nil, 0, dns.ErrEmptyResponse
+			}
+		} else {
+			ips = newIps
+		}
+	}
+
+	return ips, ttl, nil
 }
 
 // MatchExpectedIPs matches queried domain IPs with expected IPs and returns matched ones.
-func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error) {
-	if len(c.expectedIPs) == 0 {
-		return ips, nil
-	}
-	newIps := []net.IP{}
+func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) []net.IP {
+	var newIps []net.IP
 	for _, ip := range ips {
 		for _, matcher := range c.expectedIPs {
 			if matcher.Match(ip) {
@@ -222,14 +237,8 @@ func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error)
 			}
 		}
 	}
-	if len(newIps) == 0 {
-		if c.allowUnexpectedIPs {
-			return ips, nil
-		}
-		return nil, errExpectedIPNonMatch
-	}
 	errors.LogDebug(context.Background(), "domain ", domain, " expectedIPs ", newIps, " matched at server ", c.Name())
-	return newIps, nil
+	return newIps
 }
 
 func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption) dns.IPOption {

+ 50 - 195
app/dns/nameserver_doh.go

@@ -4,12 +4,12 @@ import (
 	"bytes"
 	"context"
 	"crypto/tls"
+	go_errors "errors"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
 	"strings"
-	"sync"
 	"time"
 
 	utls "github.com/refraction-networking/utls"
@@ -21,12 +21,9 @@ import (
 	"github.com/xtls/xray-core/common/net/cnc"
 	"github.com/xtls/xray-core/common/protocol/dns"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/pubsub"
-	"github.com/xtls/xray-core/common/task"
 	dns_feature "github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport/internet"
-	"golang.org/x/net/dns/dnsmessage"
 	"golang.org/x/net/http2"
 )
 
@@ -34,18 +31,14 @@ import (
 // which is compatible with traditional dns over udp(RFC1035),
 // thus most of the DOH implementation is copied from udpns.go
 type DoHNameServer struct {
-	sync.RWMutex
-	ips           map[string]*record
-	pub           *pubsub.Service
-	cleanup       *task.Periodic
-	httpClient    *http.Client
-	dohURL        string
-	name          string
-	queryStrategy QueryStrategy
+	cacheController *CacheController
+	httpClient      *http.Client
+	dohURL          string
+	clientIP        net.IP
 }
 
 // NewDoHNameServer creates DOH/DOHL client object for remote/local resolving.
-func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher routing.Dispatcher, h2c bool) *DoHNameServer {
+func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, h2c bool, disableCache bool, clientIP net.IP) *DoHNameServer {
 	url.Scheme = "https"
 	mode := "DOH"
 	if dispatcher == nil {
@@ -53,15 +46,9 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout
 	}
 	errors.LogInfo(context.Background(), "DNS: created ", mode, " client for ", url.String(), ", with h2c ", h2c)
 	s := &DoHNameServer{
-		ips:           make(map[string]*record),
-		pub:           pubsub.NewService(),
-		name:          mode + "//" + url.Host,
-		dohURL:        url.String(),
-		queryStrategy: queryStrategy,
-	}
-	s.cleanup = &task.Periodic{
-		Interval: time.Minute,
-		Execute:  s.Cleanup,
+		cacheController: NewCacheController(mode+"//"+url.Host, disableCache),
+		dohURL:          url.String(),
+		clientIP:        clientIP,
 	}
 	s.httpClient = &http.Client{
 		Transport: &http2.Transport{
@@ -127,101 +114,25 @@ func NewDoHNameServer(url *url.URL, queryStrategy QueryStrategy, dispatcher rout
 
 // Name implements Server.
 func (s *DoHNameServer) Name() string {
-	return s.name
-}
-
-// Cleanup clears expired items from cache
-func (s *DoHNameServer) Cleanup() error {
-	now := time.Now()
-	s.Lock()
-	defer s.Unlock()
-
-	if len(s.ips) == 0 {
-		return errors.New("nothing to do. stopping...")
-	}
-
-	for domain, record := range s.ips {
-		if record.A != nil && record.A.Expire.Before(now) {
-			record.A = nil
-		}
-		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
-			record.AAAA = nil
-		}
-
-		if record.A == nil && record.AAAA == nil {
-			errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
-			delete(s.ips, domain)
-		} else {
-			s.ips[domain] = record
-		}
-	}
-
-	if len(s.ips) == 0 {
-		s.ips = make(map[string]*record)
-	}
-
-	return nil
-}
-
-func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
-	elapsed := time.Since(req.start)
-
-	s.Lock()
-	rec, found := s.ips[req.domain]
-	if !found {
-		rec = &record{}
-	}
-	updated := false
-
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		if isNewer(rec.A, ipRec) {
-			rec.A = ipRec
-			updated = true
-		}
-	case dnsmessage.TypeAAAA:
-		addr := make([]net.Address, 0, len(ipRec.IP))
-		for _, ip := range ipRec.IP {
-			if len(ip.IP()) == net.IPv6len {
-				addr = append(addr, ip)
-			}
-		}
-		ipRec.IP = addr
-		if isNewer(rec.AAAA, ipRec) {
-			rec.AAAA = ipRec
-			updated = true
-		}
-	}
-	errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
-
-	if updated {
-		s.ips[req.domain] = rec
-	}
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		s.pub.Publish(req.domain+"4", nil)
-	case dnsmessage.TypeAAAA:
-		s.pub.Publish(req.domain+"6", nil)
-	}
-	s.Unlock()
-	common.Must(s.cleanup.Start())
+	return s.cacheController.name
 }
 
 func (s *DoHNameServer) newReqID() uint16 {
 	return 0
 }
 
-func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
-	errors.LogInfo(ctx, s.name, " querying: ", domain)
+func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
+	errors.LogInfo(ctx, s.Name(), " querying: ", domain)
 
-	if s.name+"." == "DOH//"+domain {
-		errors.LogError(ctx, s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.")
+	if s.Name()+"." == "DOH//"+domain {
+		errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
+		noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
 		return
 	}
 
 	// As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467
 	// Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, int(crypto.RandBetween(100, 300))))
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -256,19 +167,22 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP n
 			b, err := dns.PackMessage(r.msg)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain)
+				noResponseErrCh <- err
 				return
 			}
 			resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain)
+				noResponseErrCh <- err
 				return
 			}
 			rec, err := parseResponse(resp)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain)
+				noResponseErrCh <- err
 				return
 			}
-			s.updateIP(r, rec)
+			s.cacheController.updateIP(r, rec)
 		}(req)
 	}
 }
@@ -301,109 +215,50 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
 	return io.ReadAll(resp.Body)
 }
 
-func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	s.RLock()
-	record, found := s.ips[domain]
-	s.RUnlock()
-
-	if !found {
-		return nil, 0, errRecordNotFound
-	}
-
-	var err4 error
-	var err6 error
-	var ips []net.Address
-	var ip6 []net.Address
-	var ttl uint32
-
-	if option.IPv4Enable {
-		ips, ttl, err4 = record.A.getIPs()
-	}
-
-	if option.IPv6Enable {
-		ip6, ttl, err6 = record.AAAA.getIPs()
-		ips = append(ips, ip6...)
-	}
-
-	if len(ips) > 0 {
-		netips, err := toNetIP(ips)
-		return netips, ttl, err
-	}
-
-	if err4 != nil {
-		return nil, 0, err4
-	}
-
-	if err6 != nil {
-		return nil, 0, err6
-	}
-
-	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
-
-	return nil, 0, errRecordNotFound
-}
-
 // QueryIP implements Server.
-func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) { // nolint: dupl
+func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl
 	fqdn := Fqdn(domain)
-	option = ResolveIpOptionOverride(s.queryStrategy, option)
-	if !option.IPv4Enable && !option.IPv6Enable {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
+	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
+	defer closeSubscribers(sub4, sub6)
 
-	if disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
+	if s.cacheController.disableCache {
+		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
 	} else {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
-			errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
+		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+		if !go_errors.Is(err, errRecordNotFound) {
+			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
+			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
 			return ips, ttl, err
 		}
 	}
 
-	// ipv4 and ipv6 belong to different subscription groups
-	var sub4, sub6 *pubsub.Subscriber
-	if option.IPv4Enable {
-		sub4 = s.pub.Subscribe(fqdn + "4")
-		defer sub4.Close()
-	}
-	if option.IPv6Enable {
-		sub6 = s.pub.Subscribe(fqdn + "6")
-		defer sub6.Close()
-	}
-	done := make(chan interface{})
-	go func() {
-		if sub4 != nil {
-			select {
-			case <-sub4.Wait():
-			case <-ctx.Done():
-			}
-		}
-		if sub6 != nil {
-			select {
-			case <-sub6.Wait():
-			case <-ctx.Done():
-			}
-		}
-		close(done)
-	}()
-	s.sendQuery(ctx, fqdn, clientIP, option)
+	noResponseErrCh := make(chan error, 2)
+	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
 	start := time.Now()
 
-	for {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err != errRecordNotFound {
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-			return ips, ttl, err
+	if sub4 != nil {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub4.Wait():
+			sub4.Close()
 		}
-
+	}
+	if sub6 != nil {
 		select {
 		case <-ctx.Done():
 			return nil, 0, ctx.Err()
-		case <-done:
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub6.Wait():
+			sub6.Close()
 		}
 	}
+
+	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
+	return ips, ttl, err
+
 }

+ 16 - 16
app/dns/nameserver_doh_test.go

@@ -17,12 +17,12 @@ func TestDOHNameServer(t *testing.T) {
 	url, err := url.Parse("https+local://1.1.1.1/dns-query")
 	common.Must(err)
 
-	s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false)
+	s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -34,12 +34,12 @@ func TestDOHNameServerWithCache(t *testing.T) {
 	url, err := url.Parse("https+local://1.1.1.1/dns-query")
 	common.Must(err)
 
-	s := NewDoHNameServer(url, QueryStrategy_USE_IP, nil, false)
+	s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -47,10 +47,10 @@ func TestDOHNameServerWithCache(t *testing.T) {
 	}
 
 	ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, true)
+	})
 	cancel()
 	common.Must(err)
 	if r := cmp.Diff(ips2, ips); r != "" {
@@ -62,12 +62,12 @@ func TestDOHNameServerWithIPv4Override(t *testing.T) {
 	url, err := url.Parse("https+local://1.1.1.1/dns-query")
 	common.Must(err)
 
-	s := NewDoHNameServer(url, QueryStrategy_USE_IP4, nil, false)
+	s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
-		IPv6Enable: true,
-	}, false)
+		IPv6Enable: false,
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -85,12 +85,12 @@ func TestDOHNameServerWithIPv6Override(t *testing.T) {
 	url, err := url.Parse("https+local://1.1.1.1/dns-query")
 	common.Must(err)
 
-	s := NewDoHNameServer(url, QueryStrategy_USE_IP6, nil, false)
+	s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
-		IPv4Enable: true,
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
+		IPv4Enable: false,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {

+ 1 - 1
app/dns/nameserver_fakedns.go

@@ -20,7 +20,7 @@ func (FakeDNSServer) Name() string {
 	return "FakeDNS"
 }
 
-func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, uint32, error) {
+func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, opt dns.IPOption) ([]net.IP, uint32, error) {
 	if f.fakeDNSEngine == nil {
 		return nil, 0, errors.New("Unable to locate a fake DNS Engine").AtError()
 	}

+ 6 - 19
app/dns/nameserver_local.go

@@ -2,7 +2,6 @@ package dns
 
 import (
 	"context"
-	"strings"
 	"time"
 
 	"github.com/xtls/xray-core/common/errors"
@@ -14,26 +13,15 @@ import (
 
 // LocalNameServer is an wrapper over local DNS feature.
 type LocalNameServer struct {
-	client        *localdns.Client
-	queryStrategy QueryStrategy
+	client *localdns.Client
 }
 
-const errEmptyResponse = "No address associated with hostname"
-
 // QueryIP implements Server.
-func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, _ net.IP, option dns.IPOption, _ bool) (ips []net.IP, ttl uint32, err error) {
-	option = ResolveIpOptionOverride(s.queryStrategy, option)
-	if !option.IPv4Enable && !option.IPv6Enable {
-		return nil, 0, dns.ErrEmptyResponse
-	}
+func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, option dns.IPOption) (ips []net.IP, ttl uint32, err error) {
 
 	start := time.Now()
 	ips, ttl, err = s.client.LookupIP(domain, option)
 
-	if err != nil && strings.HasSuffix(err.Error(), errEmptyResponse) {
-		err = dns.ErrEmptyResponse
-	}
-
 	if len(ips) > 0 {
 		errors.LogInfo(ctx, "Localhost got answer: ", domain, " -> ", ips)
 		log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
@@ -48,15 +36,14 @@ func (s *LocalNameServer) Name() string {
 }
 
 // NewLocalNameServer creates localdns server object for directly lookup in system DNS.
-func NewLocalNameServer(queryStrategy QueryStrategy) *LocalNameServer {
+func NewLocalNameServer() *LocalNameServer {
 	errors.LogInfo(context.Background(), "DNS: created localhost client")
 	return &LocalNameServer{
-		queryStrategy: queryStrategy,
-		client:        localdns.New(),
+		client: localdns.New(),
 	}
 }
 
 // NewLocalDNSClient creates localdns client object for directly lookup in system DNS.
-func NewLocalDNSClient() *Client {
-	return &Client{server: NewLocalNameServer(QueryStrategy_USE_IP)}
+func NewLocalDNSClient(ipOption dns.IPOption) *Client {
+	return &Client{server: NewLocalNameServer(), ipOption: &ipOption}
 }

+ 3 - 4
app/dns/nameserver_local_test.go

@@ -7,18 +7,17 @@ import (
 
 	. "github.com/xtls/xray-core/app/dns"
 	"github.com/xtls/xray-core/common"
-	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/features/dns"
 )
 
 func TestLocalNameServer(t *testing.T) {
-	s := NewLocalNameServer(QueryStrategy_USE_IP)
+	s := NewLocalNameServer()
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP{}, dns.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
 		FakeEnable: false,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {

+ 63 - 194
app/dns/nameserver_quic.go

@@ -4,23 +4,20 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
+	go_errors "errors"
 	"net/url"
 	"sync"
 	"time"
 
 	"github.com/quic-go/quic-go"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol/dns"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/pubsub"
-	"github.com/xtls/xray-core/common/task"
 	dns_feature "github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/transport/internet/tls"
-	"golang.org/x/net/dns/dnsmessage"
 	"golang.org/x/net/http2"
 )
 
@@ -33,17 +30,14 @@ const handshakeTimeout = time.Second * 8
 // QUICNameServer implemented DNS over QUIC
 type QUICNameServer struct {
 	sync.RWMutex
-	ips           map[string]*record
-	pub           *pubsub.Service
-	cleanup       *task.Periodic
-	name          string
-	destination   *net.Destination
-	connection    quic.Connection
-	queryStrategy QueryStrategy
+	cacheController *CacheController
+	destination     *net.Destination
+	connection      quic.Connection
+	clientIP        net.IP
 }
 
 // NewQUICNameServer creates DNS-over-QUIC client object for local resolving
-func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) {
+func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICNameServer, error) {
 	errors.LogInfo(context.Background(), "DNS: created Local DNS-over-QUIC client for ", url.String())
 
 	var err error
@@ -57,15 +51,9 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ
 	dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
 
 	s := &QUICNameServer{
-		ips:           make(map[string]*record),
-		pub:           pubsub.NewService(),
-		name:          url.String(),
-		destination:   &dest,
-		queryStrategy: queryStrategy,
-	}
-	s.cleanup = &task.Periodic{
-		Interval: time.Minute,
-		Execute:  s.Cleanup,
+		cacheController: NewCacheController(url.String(), disableCache),
+		destination:     &dest,
+		clientIP:        clientIP,
 	}
 
 	return s, nil
@@ -73,94 +61,17 @@ func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServ
 
 // Name returns client name
 func (s *QUICNameServer) Name() string {
-	return s.name
-}
-
-// Cleanup clears expired items from cache
-func (s *QUICNameServer) Cleanup() error {
-	now := time.Now()
-	s.Lock()
-	defer s.Unlock()
-
-	if len(s.ips) == 0 {
-		return errors.New("nothing to do. stopping...")
-	}
-
-	for domain, record := range s.ips {
-		if record.A != nil && record.A.Expire.Before(now) {
-			record.A = nil
-		}
-		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
-			record.AAAA = nil
-		}
-
-		if record.A == nil && record.AAAA == nil {
-			errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
-			delete(s.ips, domain)
-		} else {
-			s.ips[domain] = record
-		}
-	}
-
-	if len(s.ips) == 0 {
-		s.ips = make(map[string]*record)
-	}
-
-	return nil
-}
-
-func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
-	elapsed := time.Since(req.start)
-
-	s.Lock()
-	rec, found := s.ips[req.domain]
-	if !found {
-		rec = &record{}
-	}
-	updated := false
-
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		if isNewer(rec.A, ipRec) {
-			rec.A = ipRec
-			updated = true
-		}
-	case dnsmessage.TypeAAAA:
-		addr := make([]net.Address, 0)
-		for _, ip := range ipRec.IP {
-			if len(ip.IP()) == net.IPv6len {
-				addr = append(addr, ip)
-			}
-		}
-		ipRec.IP = addr
-		if isNewer(rec.AAAA, ipRec) {
-			rec.AAAA = ipRec
-			updated = true
-		}
-	}
-	errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
-
-	if updated {
-		s.ips[req.domain] = rec
-	}
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		s.pub.Publish(req.domain+"4", nil)
-	case dnsmessage.TypeAAAA:
-		s.pub.Publish(req.domain+"6", nil)
-	}
-	s.Unlock()
-	common.Must(s.cleanup.Start())
+	return s.cacheController.name
 }
 
 func (s *QUICNameServer) newReqID() uint16 {
 	return 0
 }
 
-func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
-	errors.LogInfo(ctx, s.name, " querying: ", domain)
+func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
+	errors.LogInfo(ctx, s.Name(), " querying: ", domain)
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -192,23 +103,36 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
 			b, err := dns.PackMessage(r.msg)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to pack dns query")
+				noResponseErrCh <- err
 				return
 			}
 
 			dnsReqBuf := buf.New()
-			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
-			dnsReqBuf.Write(b.Bytes())
+			err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
+			if err != nil {
+				errors.LogErrorInner(ctx, err, "binary write failed")
+				noResponseErrCh <- err
+				return
+			}
+			_, err = dnsReqBuf.Write(b.Bytes())
+			if err != nil {
+				errors.LogErrorInner(ctx, err, "buffer write failed")
+				noResponseErrCh <- err
+				return
+			}
 			b.Release()
 
 			conn, err := s.openStream(dnsCtx)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to open quic connection")
+				noResponseErrCh <- err
 				return
 			}
 
 			_, err = conn.Write(dnsReqBuf.Bytes())
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to send query")
+				noResponseErrCh <- err
 				return
 			}
 
@@ -219,136 +143,81 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
 			n, err := respBuf.ReadFullFrom(conn, 2)
 			if err != nil && n == 0 {
 				errors.LogErrorInner(ctx, err, "failed to read response length")
+				noResponseErrCh <- err
 				return
 			}
 			var length int16
 			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to parse response length")
+				noResponseErrCh <- err
 				return
 			}
 			respBuf.Clear()
 			n, err = respBuf.ReadFullFrom(conn, int32(length))
 			if err != nil && n == 0 {
 				errors.LogErrorInner(ctx, err, "failed to read response length")
+				noResponseErrCh <- err
 				return
 			}
 
 			rec, err := parseResponse(respBuf.Bytes())
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to handle response")
+				noResponseErrCh <- err
 				return
 			}
-			s.updateIP(r, rec)
+			s.cacheController.updateIP(r, rec)
 		}(req)
 	}
 }
 
-func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	s.RLock()
-	record, found := s.ips[domain]
-	s.RUnlock()
-
-	if !found {
-		return nil, 0, errRecordNotFound
-	}
-
-	var err4 error
-	var err6 error
-	var ips []net.Address
-	var ip6 []net.Address
-	var ttl uint32
-
-	if option.IPv4Enable {
-		ips, ttl, err4 = record.A.getIPs()
-	}
-
-	if option.IPv6Enable {
-		ip6, ttl, err6 = record.AAAA.getIPs()
-		ips = append(ips, ip6...)
-	}
-
-	if len(ips) > 0 {
-		netips, err := toNetIP(ips)
-		return netips, ttl, err
-	}
-
-	if err4 != nil {
-		return nil, 0, err4
-	}
-
-	if err6 != nil {
-		return nil, 0, err6
-	}
-
-	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
-
-	return nil, 0, errRecordNotFound
-}
-
 // QueryIP is called from dns.Server->queryIPTimeout
-func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
+func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
 	fqdn := Fqdn(domain)
-	option = ResolveIpOptionOverride(s.queryStrategy, option)
-	if !option.IPv4Enable && !option.IPv6Enable {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
+	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
+	defer closeSubscribers(sub4, sub6)
 
-	if disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
+	if s.cacheController.disableCache {
+		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
 	} else {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
-			errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
+		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+		if !go_errors.Is(err, errRecordNotFound) {
+			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
+			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
 			return ips, ttl, err
 		}
 	}
 
-	// ipv4 and ipv6 belong to different subscription groups
-	var sub4, sub6 *pubsub.Subscriber
-	if option.IPv4Enable {
-		sub4 = s.pub.Subscribe(fqdn + "4")
-		defer sub4.Close()
-	}
-	if option.IPv6Enable {
-		sub6 = s.pub.Subscribe(fqdn + "6")
-		defer sub6.Close()
-	}
-	done := make(chan interface{})
-	go func() {
-		if sub4 != nil {
-			select {
-			case <-sub4.Wait():
-			case <-ctx.Done():
-			}
-		}
-		if sub6 != nil {
-			select {
-			case <-sub6.Wait():
-			case <-ctx.Done():
-			}
-		}
-		close(done)
-	}()
-	s.sendQuery(ctx, fqdn, clientIP, option)
+	noResponseErrCh := make(chan error, 2)
+	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
 	start := time.Now()
 
-	for {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err != errRecordNotFound {
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-			return ips, ttl, err
+	if sub4 != nil {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub4.Wait():
+			sub4.Close()
 		}
-
+	}
+	if sub6 != nil {
 		select {
 		case <-ctx.Done():
 			return nil, 0, ctx.Err()
-		case <-done:
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub6.Wait():
+			sub6.Close()
 		}
 	}
+
+	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
+	return ips, ttl, err
+
 }
 
 func isActive(s quic.Connection) bool {

+ 13 - 14
app/dns/nameserver_quic_test.go

@@ -16,24 +16,23 @@ import (
 func TestQUICNameServer(t *testing.T) {
 	url, err := url.Parse("quic://dns.adguard-dns.com")
 	common.Must(err)
-	s, err := NewQUICNameServer(url, QueryStrategy_USE_IP)
+	s, err := NewQUICNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
 		t.Error("expect some ips, but got 0")
 	}
-
 	ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns.IPOption{
+	ips2, _, err := s.QueryIP(ctx2, "google.com", dns.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, true)
+	})
 	cancel()
 	common.Must(err)
 	if r := cmp.Diff(ips2, ips); r != "" {
@@ -44,13 +43,13 @@ func TestQUICNameServer(t *testing.T) {
 func TestQUICNameServerWithIPv4Override(t *testing.T) {
 	url, err := url.Parse("quic://dns.adguard-dns.com")
 	common.Must(err)
-	s, err := NewQUICNameServer(url, QueryStrategy_USE_IP4)
+	s, err := NewQUICNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
 		IPv4Enable: true,
-		IPv6Enable: true,
-	}, false)
+		IPv6Enable: false,
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -67,13 +66,13 @@ func TestQUICNameServerWithIPv4Override(t *testing.T) {
 func TestQUICNameServerWithIPv6Override(t *testing.T) {
 	url, err := url.Parse("quic://dns.adguard-dns.com")
 	common.Must(err)
-	s, err := NewQUICNameServer(url, QueryStrategy_USE_IP6)
+	s, err := NewQUICNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
-		IPv4Enable: true,
+	ips, _, err := s.QueryIP(ctx, "google.com", dns.IPOption{
+		IPv4Enable: false,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {

+ 69 - 197
app/dns/nameserver_tcp.go

@@ -4,12 +4,11 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
+	go_errors "errors"
 	"net/url"
-	"sync"
 	"sync/atomic"
 	"time"
 
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/log"
@@ -17,34 +16,28 @@ import (
 	"github.com/xtls/xray-core/common/net/cnc"
 	"github.com/xtls/xray-core/common/protocol/dns"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/pubsub"
-	"github.com/xtls/xray-core/common/task"
 	dns_feature "github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport/internet"
-	"golang.org/x/net/dns/dnsmessage"
 )
 
 // TCPNameServer implemented DNS over TCP (RFC7766).
 type TCPNameServer struct {
-	sync.RWMutex
-	name          string
-	destination   *net.Destination
-	ips           map[string]*record
-	pub           *pubsub.Service
-	cleanup       *task.Periodic
-	reqID         uint32
-	dial          func(context.Context) (net.Conn, error)
-	queryStrategy QueryStrategy
+	cacheController *CacheController
+	destination     *net.Destination
+	reqID           uint32
+	dial            func(context.Context) (net.Conn, error)
+	clientIP        net.IP
 }
 
 // NewTCPNameServer creates DNS over TCP server object for remote resolving.
 func NewTCPNameServer(
 	url *url.URL,
 	dispatcher routing.Dispatcher,
-	queryStrategy QueryStrategy,
+	disableCache bool,
+	clientIP net.IP,
 ) (*TCPNameServer, error) {
-	s, err := baseTCPNameServer(url, "TCP", queryStrategy)
+	s, err := baseTCPNameServer(url, "TCP", disableCache, clientIP)
 	if err != nil {
 		return nil, err
 	}
@@ -65,8 +58,8 @@ func NewTCPNameServer(
 }
 
 // NewTCPLocalNameServer creates DNS over TCP client object for local resolving
-func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameServer, error) {
-	s, err := baseTCPNameServer(url, "TCPL", queryStrategy)
+func NewTCPLocalNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*TCPNameServer, error) {
+	s, err := baseTCPNameServer(url, "TCPL", disableCache, clientIP)
 	if err != nil {
 		return nil, err
 	}
@@ -78,7 +71,7 @@ func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameS
 	return s, nil
 }
 
-func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) (*TCPNameServer, error) {
+func baseTCPNameServer(url *url.URL, prefix string, disableCache bool, clientIP net.IP) (*TCPNameServer, error) {
 	port := net.Port(53)
 	if url.Port() != "" {
 		var err error
@@ -89,15 +82,9 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy)
 	dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
 
 	s := &TCPNameServer{
-		destination:   &dest,
-		ips:           make(map[string]*record),
-		pub:           pubsub.NewService(),
-		name:          prefix + "//" + dest.NetAddr(),
-		queryStrategy: queryStrategy,
-	}
-	s.cleanup = &task.Periodic{
-		Interval: time.Minute,
-		Execute:  s.Cleanup,
+		cacheController: NewCacheController(prefix+"//"+dest.NetAddr(), disableCache),
+		destination:     &dest,
+		clientIP:        clientIP,
 	}
 
 	return s, nil
@@ -105,94 +92,17 @@ func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy)
 
 // Name implements Server.
 func (s *TCPNameServer) Name() string {
-	return s.name
-}
-
-// Cleanup clears expired items from cache
-func (s *TCPNameServer) Cleanup() error {
-	now := time.Now()
-	s.Lock()
-	defer s.Unlock()
-
-	if len(s.ips) == 0 {
-		return errors.New("nothing to do. stopping...")
-	}
-
-	for domain, record := range s.ips {
-		if record.A != nil && record.A.Expire.Before(now) {
-			record.A = nil
-		}
-		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
-			record.AAAA = nil
-		}
-
-		if record.A == nil && record.AAAA == nil {
-			errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
-			delete(s.ips, domain)
-		} else {
-			s.ips[domain] = record
-		}
-	}
-
-	if len(s.ips) == 0 {
-		s.ips = make(map[string]*record)
-	}
-
-	return nil
-}
-
-func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
-	elapsed := time.Since(req.start)
-
-	s.Lock()
-	rec, found := s.ips[req.domain]
-	if !found {
-		rec = &record{}
-	}
-	updated := false
-
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		if isNewer(rec.A, ipRec) {
-			rec.A = ipRec
-			updated = true
-		}
-	case dnsmessage.TypeAAAA:
-		addr := make([]net.Address, 0)
-		for _, ip := range ipRec.IP {
-			if len(ip.IP()) == net.IPv6len {
-				addr = append(addr, ip)
-			}
-		}
-		ipRec.IP = addr
-		if isNewer(rec.AAAA, ipRec) {
-			rec.AAAA = ipRec
-			updated = true
-		}
-	}
-	errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
-
-	if updated {
-		s.ips[req.domain] = rec
-	}
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		s.pub.Publish(req.domain+"4", nil)
-	case dnsmessage.TypeAAAA:
-		s.pub.Publish(req.domain+"6", nil)
-	}
-	s.Unlock()
-	common.Must(s.cleanup.Start())
+	return s.cacheController.name
 }
 
 func (s *TCPNameServer) newReqID() uint16 {
 	return uint16(atomic.AddUint32(&s.reqID, 1))
 }
 
-func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
-	errors.LogDebug(ctx, s.name, " querying DNS for: ", domain)
+func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
+	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -221,23 +131,36 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n
 			b, err := dns.PackMessage(r.msg)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to pack dns query")
+				noResponseErrCh <- err
 				return
 			}
 
 			conn, err := s.dial(dnsCtx)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to dial namesever")
+				noResponseErrCh <- err
 				return
 			}
 			defer conn.Close()
 			dnsReqBuf := buf.New()
-			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
-			dnsReqBuf.Write(b.Bytes())
+			err = binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
+			if err != nil {
+				errors.LogErrorInner(ctx, err, "binary write failed")
+				noResponseErrCh <- err
+				return
+			}
+			_, err = dnsReqBuf.Write(b.Bytes())
+			if err != nil {
+				errors.LogErrorInner(ctx, err, "buffer write failed")
+				noResponseErrCh <- err
+				return
+			}
 			b.Release()
 
 			_, err = conn.Write(dnsReqBuf.Bytes())
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to send query")
+				noResponseErrCh <- err
 				return
 			}
 			dnsReqBuf.Release()
@@ -247,131 +170,80 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP n
 			n, err := respBuf.ReadFullFrom(conn, 2)
 			if err != nil && n == 0 {
 				errors.LogErrorInner(ctx, err, "failed to read response length")
+				noResponseErrCh <- err
 				return
 			}
 			var length int16
 			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to parse response length")
+				noResponseErrCh <- err
 				return
 			}
 			respBuf.Clear()
 			n, err = respBuf.ReadFullFrom(conn, int32(length))
 			if err != nil && n == 0 {
 				errors.LogErrorInner(ctx, err, "failed to read response length")
+				noResponseErrCh <- err
 				return
 			}
 
 			rec, err := parseResponse(respBuf.Bytes())
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to parse DNS over TCP response")
+				noResponseErrCh <- err
 				return
 			}
 
-			s.updateIP(r, rec)
+			s.cacheController.updateIP(r, rec)
 		}(req)
 	}
 }
 
-func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	s.RLock()
-	record, found := s.ips[domain]
-	s.RUnlock()
-
-	if !found {
-		return nil, 0, errRecordNotFound
-	}
-
-	var err4 error
-	var err6 error
-	var ips []net.Address
-	var ip6 []net.Address
-	var ttl uint32
-
-	if option.IPv4Enable {
-		ips, ttl, err4 = record.A.getIPs()
-	}
-
-	if option.IPv6Enable {
-		ip6, ttl, err6 = record.AAAA.getIPs()
-		ips = append(ips, ip6...)
-	}
-
-	if len(ips) > 0 {
-		netips, err := toNetIP(ips)
-		return netips, ttl, err
-	}
-
-	if err4 != nil {
-		return nil, 0, err4
-	}
-
-	if err6 != nil {
-		return nil, 0, err6
-	}
-
-	return nil, 0, dns_feature.ErrEmptyResponse
-}
-
 // QueryIP implements Server.
-func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
+func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
 	fqdn := Fqdn(domain)
-	option = ResolveIpOptionOverride(s.queryStrategy, option)
-	if !option.IPv4Enable && !option.IPv6Enable {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
+	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
+	defer closeSubscribers(sub4, sub6)
 
-	if disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
+	if s.cacheController.disableCache {
+		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
 	} else {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
-			errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
+		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+		if !go_errors.Is(err, errRecordNotFound) {
+			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
+			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
 			return ips, ttl, err
 		}
 	}
 
-	// ipv4 and ipv6 belong to different subscription groups
-	var sub4, sub6 *pubsub.Subscriber
-	if option.IPv4Enable {
-		sub4 = s.pub.Subscribe(fqdn + "4")
-		defer sub4.Close()
-	}
-	if option.IPv6Enable {
-		sub6 = s.pub.Subscribe(fqdn + "6")
-		defer sub6.Close()
-	}
-	done := make(chan interface{})
-	go func() {
-		if sub4 != nil {
-			select {
-			case <-sub4.Wait():
-			case <-ctx.Done():
-			}
-		}
-		if sub6 != nil {
-			select {
-			case <-sub6.Wait():
-			case <-ctx.Done():
-			}
-		}
-		close(done)
-	}()
-	s.sendQuery(ctx, fqdn, clientIP, option)
+	noResponseErrCh := make(chan error, 2)
+	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
 	start := time.Now()
 
-	for {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err != errRecordNotFound {
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-			return ips, ttl, err
+	if sub4 != nil {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub4.Wait():
+			sub4.Close()
 		}
-
+	}
+	if sub6 != nil {
 		select {
 		case <-ctx.Done():
 			return nil, 0, ctx.Err()
-		case <-done:
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub6.Wait():
+			sub6.Close()
 		}
 	}
+
+	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
+	return ips, ttl, err
+
 }

+ 16 - 16
app/dns/nameserver_tcp_test.go

@@ -16,13 +16,13 @@ import (
 func TestTCPLocalNameServer(t *testing.T) {
 	url, err := url.Parse("tcp+local://8.8.8.8")
 	common.Must(err)
-	s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP)
+	s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -33,13 +33,13 @@ func TestTCPLocalNameServer(t *testing.T) {
 func TestTCPLocalNameServerWithCache(t *testing.T) {
 	url, err := url.Parse("tcp+local://8.8.8.8")
 	common.Must(err)
-	s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP)
+	s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 	if len(ips) == 0 {
@@ -47,10 +47,10 @@ func TestTCPLocalNameServerWithCache(t *testing.T) {
 	}
 
 	ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips2, _, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips2, _, err := s.QueryIP(ctx2, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
 		IPv6Enable: true,
-	}, true)
+	})
 	cancel()
 	common.Must(err)
 	if r := cmp.Diff(ips2, ips); r != "" {
@@ -61,13 +61,13 @@ func TestTCPLocalNameServerWithCache(t *testing.T) {
 func TestTCPLocalNameServerWithIPv4Override(t *testing.T) {
 	url, err := url.Parse("tcp+local://8.8.8.8")
 	common.Must(err)
-	s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP4)
+	s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
 		IPv4Enable: true,
-		IPv6Enable: true,
-	}, false)
+		IPv6Enable: false,
+	})
 	cancel()
 	common.Must(err)
 
@@ -85,13 +85,13 @@ func TestTCPLocalNameServerWithIPv4Override(t *testing.T) {
 func TestTCPLocalNameServerWithIPv6Override(t *testing.T) {
 	url, err := url.Parse("tcp+local://8.8.8.8")
 	common.Must(err)
-	s, err := NewTCPLocalNameServer(url, QueryStrategy_USE_IP6)
+	s, err := NewTCPLocalNameServer(url, false, net.IP(nil))
 	common.Must(err)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	ips, _, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
-		IPv4Enable: true,
+	ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
+		IPv4Enable: false,
 		IPv6Enable: true,
-	}, false)
+	})
 	cancel()
 	common.Must(err)
 

+ 59 - 182
app/dns/nameserver_udp.go

@@ -2,6 +2,7 @@ package dns
 
 import (
 	"context"
+	go_errors "errors"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -13,7 +14,6 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol/dns"
 	udp_proto "github.com/xtls/xray-core/common/protocol/udp"
-	"github.com/xtls/xray-core/common/signal/pubsub"
 	"github.com/xtls/xray-core/common/task"
 	dns_feature "github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/routing"
@@ -24,15 +24,13 @@ import (
 // ClassicNameServer implemented traditional UDP DNS.
 type ClassicNameServer struct {
 	sync.RWMutex
-	name          string
-	address       *net.Destination
-	ips           map[string]*record
-	requests      map[uint16]*udpDnsRequest
-	pub           *pubsub.Service
-	udpServer     *udp.Dispatcher
-	cleanup       *task.Periodic
-	reqID         uint32
-	queryStrategy QueryStrategy
+	cacheController *CacheController
+	address         *net.Destination
+	requests        map[uint16]*udpDnsRequest
+	udpServer       *udp.Dispatcher
+	requestsCleanup *task.Periodic
+	reqID           uint32
+	clientIP        net.IP
 }
 
 type udpDnsRequest struct {
@@ -41,23 +39,21 @@ type udpDnsRequest struct {
 }
 
 // NewClassicNameServer creates udp server object for remote resolving.
-func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer {
+func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) *ClassicNameServer {
 	// default to 53 if unspecific
 	if address.Port == 0 {
 		address.Port = net.Port(53)
 	}
 
 	s := &ClassicNameServer{
-		address:       &address,
-		ips:           make(map[string]*record),
-		requests:      make(map[uint16]*udpDnsRequest),
-		pub:           pubsub.NewService(),
-		name:          strings.ToUpper(address.String()),
-		queryStrategy: queryStrategy,
+		cacheController: NewCacheController(strings.ToUpper(address.String()), disableCache),
+		address:         &address,
+		requests:        make(map[uint16]*udpDnsRequest),
+		clientIP:        clientIP,
 	}
-	s.cleanup = &task.Periodic{
+	s.requestsCleanup = &task.Periodic{
 		Interval: time.Minute,
-		Execute:  s.Cleanup,
+		Execute:  s.RequestsCleanup,
 	}
 	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
 	errors.LogInfo(context.Background(), "DNS: created UDP client initialized for ", address.NetAddr())
@@ -66,37 +62,17 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 
 // Name implements Server.
 func (s *ClassicNameServer) Name() string {
-	return s.name
+	return s.cacheController.name
 }
 
-// Cleanup clears expired items from cache
-func (s *ClassicNameServer) Cleanup() error {
+// RequestsCleanup clears expired items from cache
+func (s *ClassicNameServer) RequestsCleanup() error {
 	now := time.Now()
 	s.Lock()
 	defer s.Unlock()
 
-	if len(s.ips) == 0 && len(s.requests) == 0 {
-		return errors.New(s.name, " nothing to do. stopping...")
-	}
-
-	for domain, record := range s.ips {
-		if record.A != nil && record.A.Expire.Before(now) {
-			record.A = nil
-		}
-		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
-			record.AAAA = nil
-		}
-
-		if record.A == nil && record.AAAA == nil {
-			errors.LogDebug(context.Background(), s.name, " cleanup ", domain)
-			delete(s.ips, domain)
-		} else {
-			s.ips[domain] = record
-		}
-	}
-
-	if len(s.ips) == 0 {
-		s.ips = make(map[string]*record)
+	if len(s.requests) == 0 {
+		return errors.New(s.Name(), " nothing to do. stopping...")
 	}
 
 	for id, req := range s.requests {
@@ -116,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error {
 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
 	ipRec, err := parseResponse(packet.Payload.Bytes())
 	if err != nil {
-		errors.LogError(ctx, s.name, " fail to parse responded DNS udp")
+		errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
 		return
 	}
 
@@ -129,14 +105,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 	}
 	s.Unlock()
 	if !ok {
-		errors.LogError(ctx, s.name, " cannot find the pending request")
+		errors.LogError(ctx, s.Name(), " cannot find the pending request")
 		return
 	}
 
 	// if truncated, retry with EDNS0 option(udp payload size: 1350)
 	if ipRec.RawHeader.Truncated {
 		// if already has EDNS0 option, no need to retry
-		if ok && len(req.msg.Additionals) == 0 {
+		if len(req.msg.Additionals) == 0 {
 			// copy necessary meta data from original request
 			// and add EDNS0 option
 			opt := new(dnsmessage.Resource)
@@ -154,51 +130,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		}
 	}
 
-	var rec record
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		rec.A = ipRec
-	case dnsmessage.TypeAAAA:
-		rec.AAAA = ipRec
-	}
-
-	elapsed := time.Since(req.start)
-	errors.LogInfo(ctx, s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
-	if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
-		s.updateIP(req.domain, &rec)
-	}
-}
-
-func (s *ClassicNameServer) updateIP(domain string, newRec *record) {
-	s.Lock()
-
-	rec, found := s.ips[domain]
-	if !found {
-		rec = &record{}
-	}
-
-	updated := false
-	if isNewer(rec.A, newRec.A) {
-		rec.A = newRec.A
-		updated = true
-	}
-	if isNewer(rec.AAAA, newRec.AAAA) {
-		rec.AAAA = newRec.AAAA
-		updated = true
-	}
-
-	if updated {
-		errors.LogDebug(context.Background(), s.name, " updating IP records for domain:", domain)
-		s.ips[domain] = rec
-	}
-	if newRec.A != nil {
-		s.pub.Publish(domain+"4", nil)
-	}
-	if newRec.AAAA != nil {
-		s.pub.Publish(domain+"6", nil)
-	}
-	s.Unlock()
-	common.Must(s.cleanup.Start())
+	s.cacheController.updateIP(&req.dnsRequest, ipRec)
 }
 
 func (s *ClassicNameServer) newReqID() uint16 {
@@ -207,17 +139,17 @@ func (s *ClassicNameServer) newReqID() uint16 {
 
 func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
 	s.Lock()
-	defer s.Unlock()
-
 	id := req.msg.ID
 	req.expire = time.Now().Add(time.Second * 8)
 	s.requests[id] = req
+	s.Unlock()
+	common.Must(s.requestsCleanup.Start())
 }
 
-func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
-	errors.LogDebug(ctx, s.name, " querying DNS for: ", domain)
+func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
+	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP, 0))
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	for _, req := range reqs {
 		udpReq := &udpDnsRequest{
@@ -230,105 +162,50 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
 	}
 }
 
-func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	s.RLock()
-	record, found := s.ips[domain]
-	s.RUnlock()
-
-	if !found {
-		return nil, 0, errRecordNotFound
-	}
-
-	var err4 error
-	var err6 error
-	var ips []net.Address
-	var ip6 []net.Address
-	var ttl uint32
-
-	if option.IPv4Enable {
-		ips, ttl, err4 = record.A.getIPs()
-	}
-
-	if option.IPv6Enable {
-		ip6, ttl, err6 = record.AAAA.getIPs()
-		ips = append(ips, ip6...)
-	}
-
-	if len(ips) > 0 {
-		netips, err := toNetIP(ips)
-		return netips, ttl, err
-	}
-
-	if err4 != nil {
-		return nil, 0, err4
-	}
-
-	if err6 != nil {
-		return nil, 0, err6
-	}
-
-	return nil, 0, dns_feature.ErrEmptyResponse
-}
-
 // QueryIP implements Server.
-func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, uint32, error) {
+func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
 	fqdn := Fqdn(domain)
-	option = ResolveIpOptionOverride(s.queryStrategy, option)
-	if !option.IPv4Enable && !option.IPv6Enable {
-		return nil, 0, dns_feature.ErrEmptyResponse
-	}
+	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
+	defer closeSubscribers(sub4, sub6)
 
-	if disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)
+	if s.cacheController.disableCache {
+		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
 	} else {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 {
-			errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
+		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+		if !go_errors.Is(err, errRecordNotFound) {
+			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
+			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
 			return ips, ttl, err
 		}
 	}
 
-	// ipv4 and ipv6 belong to different subscription groups
-	var sub4, sub6 *pubsub.Subscriber
-	if option.IPv4Enable {
-		sub4 = s.pub.Subscribe(fqdn + "4")
-		defer sub4.Close()
-	}
-	if option.IPv6Enable {
-		sub6 = s.pub.Subscribe(fqdn + "6")
-		defer sub6.Close()
-	}
-	done := make(chan interface{})
-	go func() {
-		if sub4 != nil {
-			select {
-			case <-sub4.Wait():
-			case <-ctx.Done():
-			}
-		}
-		if sub6 != nil {
-			select {
-			case <-sub6.Wait():
-			case <-ctx.Done():
-			}
-		}
-		close(done)
-	}()
-	s.sendQuery(ctx, fqdn, clientIP, option)
+	noResponseErrCh := make(chan error, 2)
+	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
 	start := time.Now()
 
-	for {
-		ips, ttl, err := s.findIPsForDomain(fqdn, option)
-		if err != errRecordNotFound {
-			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-			return ips, ttl, err
+	if sub4 != nil {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub4.Wait():
+			sub4.Close()
 		}
-
+	}
+	if sub6 != nil {
 		select {
 		case <-ctx.Done():
 			return nil, 0, ctx.Err()
-		case <-done:
+		case err := <-noResponseErrCh:
+			return nil, 0, err
+		case <-sub6.Wait():
+			sub6.Close()
 		}
 	}
+
+	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
+	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
+	return ips, ttl, err
+
 }

+ 1 - 1
app/router/condition.go

@@ -119,7 +119,7 @@ type MultiGeoIPMatcher struct {
 func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
 	var matchers []*GeoIPMatcher
 	for _, geoip := range geoips {
-		matcher, err := globalGeoIPContainer.Add(geoip)
+		matcher, err := GlobalGeoIPContainer.Add(geoip)
 		if err != nil {
 			return nil, err
 		}

+ 1 - 1
app/router/condition_geoip.go

@@ -115,4 +115,4 @@ func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) {
 	return m, nil
 }
 
-var globalGeoIPContainer GeoIPMatcherContainer
+var GlobalGeoIPContainer GeoIPMatcherContainer

+ 3 - 2
common/errors/multi_error.go

@@ -1,6 +1,7 @@
 package errors
 
 import (
+	"errors"
 	"strings"
 )
 
@@ -36,12 +37,12 @@ func AllEqual(expected error, actual error) bool {
 			return false
 		}
 		for _, err := range errs {
-			if err != expected {
+			if !errors.Is(err, expected) {
 				return false
 			}
 		}
 		return true
 	default:
-		return errs == expected
+		return errors.Is(errs, expected)
 	}
 }

+ 2 - 0
features/dns/client.go

@@ -38,6 +38,8 @@ func ClientType() interface{} {
 // ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
 var ErrEmptyResponse = errors.New("empty response")
 
+const DefaultTTL = 300
+
 type RCodeError uint16
 
 func (e RCodeError) Error() string {

+ 13 - 11
features/dns/localdns/client.go

@@ -30,29 +30,31 @@ func (*Client) LookupIP(host string, option dns.IPOption) ([]net.IP, uint32, err
 	ipv6 := make([]net.IP, 0, len(ips))
 	for _, ip := range ips {
 		parsed := net.IPAddress(ip)
-		if parsed != nil {
-			parsedIPs = append(parsedIPs, parsed.IP())
+		if parsed == nil {
+			continue
 		}
-		if len(ip) == net.IPv4len {
-			ipv4 = append(ipv4, ip)
-		}
-		if len(ip) == net.IPv6len {
-			ipv6 = append(ipv6, ip)
+		parsedIP := parsed.IP()
+		parsedIPs = append(parsedIPs, parsedIP)
+
+		if len(parsedIP) == net.IPv4len {
+			ipv4 = append(ipv4, parsedIP)
+		} else {
+			ipv6 = append(ipv6, parsedIP)
 		}
 	}
-	// Local DNS ttl is 600
+
 	switch {
 	case option.IPv4Enable && option.IPv6Enable:
 		if len(parsedIPs) > 0 {
-			return parsedIPs, 600, nil
+			return parsedIPs, dns.DefaultTTL, nil
 		}
 	case option.IPv4Enable:
 		if len(ipv4) > 0 {
-			return ipv4, 600, nil
+			return ipv4, dns.DefaultTTL, nil
 		}
 	case option.IPv6Enable:
 		if len(ipv6) > 0 {
-			return ipv6, 600, nil
+			return ipv6, dns.DefaultTTL, nil
 		}
 	}
 	return nil, 0, dns.ErrEmptyResponse

+ 2 - 1
proxy/dns/dns.go

@@ -2,6 +2,7 @@ package dns
 
 import (
 	"context"
+	go_errors "errors"
 	"io"
 	"sync"
 	"time"
@@ -255,7 +256,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 	}
 
 	rcode := dns.RCodeFromError(err)
-	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
+	if rcode == 0 && len(ips) == 0 && !go_errors.Is(err, dns.ErrEmptyResponse) {
 		errors.LogInfoInner(context.Background(), err, "ip query")
 		return
 	}