Przeglądaj źródła

refactor(dns): enhance cache safety, optimize performance, and refactor query logic (#5248)

Meow 2 miesięcy temu
rodzic
commit
b40bf56e4e

+ 228 - 93
app/dns/cache_controller.go

@@ -3,24 +3,37 @@ package dns
 import (
 	"context"
 	go_errors "errors"
+	"runtime"
+	"sync"
+	"time"
+
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/signal/pubsub"
 	"github.com/xtls/xray-core/common/task"
 	dns_feature "github.com/xtls/xray-core/features/dns"
+
 	"golang.org/x/net/dns/dnsmessage"
-	"sync"
-	"time"
+	"golang.org/x/sync/singleflight"
+)
+
+const (
+	minSizeForEmptyRebuild  = 512
+	shrinkAbsoluteThreshold = 10240
+	shrinkRatioThreshold    = 0.65
+	migrationBatchSize      = 4096
 )
 
 type CacheController struct {
 	sync.RWMutex
-	ips          map[string]*record
-	pub          *pubsub.Service
-	cacheCleanup *task.Periodic
-	name         string
-	disableCache bool
+	ips           map[string]*record
+	dirtyips      map[string]*record
+	pub           *pubsub.Service
+	cacheCleanup  *task.Periodic
+	name          string
+	disableCache  bool
+	highWatermark int
+	requestGroup  singleflight.Group
 }
 
 func NewCacheController(name string, disableCache bool) *CacheController {
@@ -32,7 +45,7 @@ func NewCacheController(name string, disableCache bool) *CacheController {
 	}
 
 	c.cacheCleanup = &task.Periodic{
-		Interval: time.Minute,
+		Interval: 300 * time.Second,
 		Execute:  c.CacheCleanup,
 	}
 	return c
@@ -40,131 +53,253 @@ func NewCacheController(name string, disableCache bool) *CacheController {
 
 // CacheCleanup clears expired items from cache
 func (c *CacheController) CacheCleanup() error {
+	expiredKeys, err := c.collectExpiredKeys()
+	if err != nil {
+		return err
+	}
+	if len(expiredKeys) == 0 {
+		return nil
+	}
+	c.writeAndShrink(expiredKeys)
+	return nil
+}
+
+func (c *CacheController) collectExpiredKeys() ([]string, error) {
+	c.RLock()
+	defer c.RUnlock()
+
+	if len(c.ips) == 0 {
+		return nil, errors.New("nothing to do. stopping...")
+	}
+
+	// skip collection if a migration is in progress
+	if c.dirtyips != nil {
+		return nil, nil
+	}
+
 	now := time.Now()
+	expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate
+
+	for domain, rec := range c.ips {
+		if (rec.A != nil && rec.A.Expire.Before(now)) ||
+			(rec.AAAA != nil && rec.AAAA.Expire.Before(now)) {
+			expiredKeys = append(expiredKeys, domain)
+		}
+	}
+
+	return expiredKeys, nil
+}
+
+func (c *CacheController) writeAndShrink(expiredKeys []string) {
 	c.Lock()
 	defer c.Unlock()
 
-	if len(c.ips) == 0 {
-		return errors.New("nothing to do. stopping...")
+	// re-check under the write lock: another cleanup may have started a migration since the read lock was released
+	if c.dirtyips != nil {
+		return
+	}
+
+	lenBefore := len(c.ips)
+	if lenBefore > c.highWatermark {
+		c.highWatermark = lenBefore
 	}
 
-	for domain, record := range c.ips {
-		if record.A != nil && record.A.Expire.Before(now) {
-			record.A = nil
+	now := time.Now()
+	for _, domain := range expiredKeys {
+		rec := c.ips[domain]
+		if rec == nil {
+			continue
 		}
-		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
-			record.AAAA = nil
+		if rec.A != nil && rec.A.Expire.Before(now) {
+			rec.A = nil
 		}
-
-		if record.A == nil && record.AAAA == nil {
-			errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
+		if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
+			rec.AAAA = nil
+		}
+		if rec.A == nil && rec.AAAA == nil {
 			delete(c.ips, domain)
-		} else {
-			c.ips[domain] = record
 		}
 	}
 
-	if len(c.ips) == 0 {
-		c.ips = make(map[string]*record)
+	lenAfter := len(c.ips)
+
+	if lenAfter == 0 {
+		if c.highWatermark >= minSizeForEmptyRebuild {
+			errors.LogDebug(context.Background(), c.name,
+				" rebuilding empty cache map to reclaim memory.",
+				" size_before_cleanup=", lenBefore,
+				" peak_size_before_rebuild=", c.highWatermark,
+			)
+
+			c.ips = make(map[string]*record)
+			c.highWatermark = 0
+		}
+		return
+	}
+
+	if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
+		float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
+		errors.LogDebug(context.Background(), c.name,
+			" shrinking cache map to reclaim memory.",
+			" new_size=", lenAfter,
+			" peak_size_before_shrink=", c.highWatermark,
+			" reduction_since_peak=", reductionFromPeak,
+		)
+
+		c.dirtyips = c.ips
+		c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
+		c.highWatermark = lenAfter
+		go c.migrate()
 	}
 
-	return nil
 }
 
-func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
-	elapsed := time.Since(req.start)
+type migrationEntry struct {
+	key   string
+	value *record
+}
 
-	c.Lock()
-	rec, found := c.ips[req.domain]
-	if !found {
-		rec = &record{}
+func (c *CacheController) migrate() {
+	defer func() {
+		if r := recover(); r != nil {
+			errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
+			c.Lock()
+			c.dirtyips = nil
+			// c.ips = make(map[string]*record)
+			// c.highWatermark = 0
+			c.Unlock()
+		}
+	}()
+
+	c.RLock()
+	dirtyips := c.dirtyips
+	c.RUnlock()
+
+	// nothing to migrate if dirtyips has already been cleared
+	if dirtyips == nil {
+		return
 	}
 
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		rec.A = ipRec
-	case dnsmessage.TypeAAAA:
-		rec.AAAA = ipRec
+	errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")
+
+	batch := make([]migrationEntry, 0, migrationBatchSize)
+	for domain, recD := range dirtyips {
+		batch = append(batch, migrationEntry{domain, recD})
+
+		if len(batch) >= migrationBatchSize {
+			c.flush(batch)
+			batch = batch[:0]
+			runtime.Gosched()
+		}
+	}
+	if len(batch) > 0 {
+		c.flush(batch)
 	}
 
-	errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
-	c.ips[req.domain] = rec
+	c.Lock()
+	c.dirtyips = nil
+	c.Unlock()
 
-	switch req.reqType {
-	case dnsmessage.TypeA:
-		c.pub.Publish(req.domain+"4", nil)
-		if !c.disableCache {
-			_, _, err := rec.AAAA.getIPs()
-			if !go_errors.Is(err, errRecordNotFound) {
-				c.pub.Publish(req.domain+"6", nil)
+	errors.LogDebug(context.Background(), c.name, " cache migration completed.")
+}
+
+func (c *CacheController) flush(batch []migrationEntry) {
+	c.Lock()
+	defer c.Unlock()
+
+	for _, dirty := range batch {
+		if cur := c.ips[dirty.key]; cur != nil {
+			merge := &record{}
+			if cur.A == nil {
+				merge.A = dirty.value.A
+			} else {
+				merge.A = cur.A
 			}
-		}
-	case dnsmessage.TypeAAAA:
-		c.pub.Publish(req.domain+"6", nil)
-		if !c.disableCache {
-			_, _, err := rec.A.getIPs()
-			if !go_errors.Is(err, errRecordNotFound) {
-				c.pub.Publish(req.domain+"4", nil)
+			if cur.AAAA == nil {
+				merge.AAAA = dirty.value.AAAA
+			} else {
+				merge.AAAA = cur.AAAA
 			}
+			c.ips[dirty.key] = merge
+		} else {
+			c.ips[dirty.key] = dirty.value
 		}
 	}
-
-	c.Unlock()
-	common.Must(c.cacheCleanup.Start())
 }
 
-func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	c.RLock()
-	record, found := c.ips[domain]
-	c.RUnlock()
+func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
+	rtt := time.Since(req.start)
 
-	if !found {
-		return nil, 0, errRecordNotFound
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		c.pub.Publish(req.domain+"4", rep)
+	case dnsmessage.TypeAAAA:
+		c.pub.Publish(req.domain+"6", rep)
 	}
 
-	var errs []error
-	var allIPs []net.IP
-	var rTTL uint32 = dns_feature.DefaultTTL
+	if c.disableCache {
+		errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
+		return
+	}
 
-	mergeReq := option.IPv4Enable && option.IPv6Enable
+	c.Lock()
+	lockWait := time.Since(req.start) - rtt
 
-	if option.IPv4Enable {
-		ips, ttl, err := record.A.getIPs()
-		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
-			return ips, ttl, err
-		}
-		if ttl < rTTL {
-			rTTL = ttl
-		}
-		if len(ips) > 0 {
-			allIPs = append(allIPs, ips...)
-		} else {
-			errs = append(errs, err)
-		}
+	newRec := &record{}
+	oldRec := c.ips[req.domain]
+	var dirtyRec *record
+	if c.dirtyips != nil {
+		dirtyRec = c.dirtyips[req.domain]
 	}
 
-	if option.IPv6Enable {
-		ips, ttl, err := record.AAAA.getIPs()
-		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
-			return ips, ttl, err
-		}
-		if ttl < rTTL {
-			rTTL = ttl
+	var pubRecord *IPRecord
+	var pubSuffix string
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		newRec.A = rep
+		if oldRec != nil && oldRec.AAAA != nil {
+			newRec.AAAA = oldRec.AAAA
+			pubRecord = oldRec.AAAA
+		} else if dirtyRec != nil && dirtyRec.AAAA != nil {
+			pubRecord = dirtyRec.AAAA
 		}
-		if len(ips) > 0 {
-			allIPs = append(allIPs, ips...)
-		} else {
-			errs = append(errs, err)
+		pubSuffix = "6"
+	case dnsmessage.TypeAAAA:
+		newRec.AAAA = rep
+		if oldRec != nil && oldRec.A != nil {
+			newRec.A = oldRec.A
+			pubRecord = oldRec.A
+		} else if dirtyRec != nil && dirtyRec.A != nil {
+			pubRecord = dirtyRec.A
 		}
+		pubSuffix = "4"
 	}
 
-	if len(allIPs) > 0 {
-		return allIPs, rTTL, nil
+	c.ips[req.domain] = newRec
+	c.Unlock()
+
+	if pubRecord != nil {
+		_, _ /*ttl*/, err := pubRecord.getIPs()
+		if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) {
+			c.pub.Publish(req.domain+pubSuffix, pubRecord)
+		}
 	}
-	if go_errors.Is(errs[0], errs[1]) {
-		return nil, rTTL, errs[0]
+
+	errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)
+
+	common.Must(c.cacheCleanup.Start())
+}
+
+func (c *CacheController) findRecords(domain string) *record {
+	c.RLock()
+	defer c.RUnlock()
+
+	rec := c.ips[domain]
+	if rec == nil && c.dirtyips != nil {
+		rec = c.dirtyips[domain]
 	}
-	return nil, rTTL, errors.Combine(errs...)
+	return rec
 }
 
 func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {

+ 1 - 0
app/dns/dnscommon.go

@@ -17,6 +17,7 @@ import (
 )
 
 // Fqdn normalizes domain make sure it ends with '.'
+// case-sensitive
 func Fqdn(domain string) string {
 	if len(domain) > 0 && strings.HasSuffix(domain, ".") {
 		return domain

+ 149 - 0
app/dns/nameserver_cached.go

@@ -0,0 +1,149 @@
+package dns
+
+import (
+	"context"
+	go_errors "errors"
+	"time"
+
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/log"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/signal/pubsub"
+	"github.com/xtls/xray-core/features/dns"
+)
+
+type CachedNameserver interface {
+	getCacheController() *CacheController
+
+	sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns.IPOption)
+}
+
+// queryIP is called from dns.Server->queryIPTimeout
+func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
+	fqdn := Fqdn(domain)
+
+	cache := s.getCacheController()
+	if !cache.disableCache {
+		if rec := cache.findRecords(fqdn); rec != nil {
+			ips, ttl, err := merge(option, rec.A, rec.AAAA)
+			if !go_errors.Is(err, errRecordNotFound) {
+				// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
+				log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
+				return ips, ttl, err
+			}
+		}
+	} else {
+		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", fqdn, " at ", cache.name)
+	}
+
+	return fetch(ctx, s, fqdn, option)
+}
+
+func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) {
+	key := fqdn + "f"
+	switch {
+	case option.IPv4Enable && option.IPv6Enable:
+		key = key + "46"
+	case option.IPv4Enable:
+		key = key + "4"
+	case option.IPv6Enable:
+		key = key + "6"
+	}
+
+	v, _, _ := s.getCacheController().requestGroup.Do(key, func() (any, error) {
+		return doFetch(ctx, s, fqdn, option), nil
+	})
+	ret := v.(result)
+
+	return ret.ips, ret.ttl, ret.error
+}
+
+type result struct {
+	ips []net.IP
+	ttl uint32
+	error
+}
+
+func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) result {
+	sub4, sub6 := s.getCacheController().registerSubscribers(fqdn, option)
+	defer closeSubscribers(sub4, sub6)
+
+	noResponseErrCh := make(chan error, 2)
+	onEvent := func(sub *pubsub.Subscriber) (*IPRecord, error) {
+		if sub == nil {
+			return nil, nil
+		}
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case err := <-noResponseErrCh:
+			return nil, err
+		case msg := <-sub.Wait():
+			sub.Close()
+			return msg.(*IPRecord), nil // should panic
+		}
+	}
+
+	start := time.Now()
+	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
+
+	rec4, err4 := onEvent(sub4)
+	rec6, err6 := onEvent(sub6)
+
+	var errs []error
+	if err4 != nil {
+		errs = append(errs, err4)
+	}
+	if err6 != nil {
+		errs = append(errs, err6)
+	}
+
+	ips, ttl, err := merge(option, rec4, rec6, errs...)
+	log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
+	return result{ips, ttl, err}
+}
+
+func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, uint32, error) {
+	var allIPs []net.IP
+	var rTTL uint32 = dns.DefaultTTL
+
+	mergeReq := option.IPv4Enable && option.IPv6Enable
+
+	if option.IPv4Enable {
+		ips, ttl, err := rec4.getIPs() // it's safe
+		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
+			return ips, ttl, err
+		}
+		if ttl < rTTL {
+			rTTL = ttl
+		}
+		if len(ips) > 0 {
+			allIPs = append(allIPs, ips...)
+		} else {
+			errs = append(errs, err)
+		}
+	}
+
+	if option.IPv6Enable {
+		ips, ttl, err := rec6.getIPs() // it's safe
+		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
+			return ips, ttl, err
+		}
+		if ttl < rTTL {
+			rTTL = ttl
+		}
+		if len(ips) > 0 {
+			allIPs = append(allIPs, ips...)
+		} else {
+			errs = append(errs, err)
+		}
+	}
+
+	if len(allIPs) > 0 {
+		return allIPs, rTTL, nil
+	}
+	if len(errs) == 2 && go_errors.Is(errs[0], errs[1]) {
+		return nil, rTTL, errs[0]
+	}
+	return nil, rTTL, errors.Combine(errs...)
+}

+ 16 - 54
app/dns/nameserver_doh.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"crypto/tls"
-	go_errors "errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -121,10 +120,16 @@ func (s *DoHNameServer) newReqID() uint16 {
 	return 0
 }
 
-func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
-	errors.LogInfo(ctx, s.Name(), " querying: ", domain)
+// getCacheController implements CachedNameserver.
+func (s *DoHNameServer) getCacheController() *CacheController {
+	return s.cacheController
+}
+
+// sendQuery implements CachedNameserver.
+func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
+	errors.LogInfo(ctx, s.Name(), " querying: ", fqdn)
 
-	if s.Name()+"." == "DOH//"+domain {
+	if s.Name()+"." == "DOH//"+fqdn {
 		errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
 		noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
 		return
@@ -132,7 +137,7 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
 
 	// As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467
 	// Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
+	reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -166,23 +171,23 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
 
 			b, err := dns.PackMessage(r.msg)
 			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain)
+				errors.LogErrorInner(ctx, err, "failed to pack dns query for ", fqdn)
 				noResponseErrCh <- err
 				return
 			}
 			resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
 			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain)
+				errors.LogErrorInner(ctx, err, "failed to retrieve response for ", fqdn)
 				noResponseErrCh <- err
 				return
 			}
 			rec, err := parseResponse(resp)
 			if err != nil {
-				errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain)
+				errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", fqdn)
 				noResponseErrCh <- err
 				return
 			}
-			s.cacheController.updateIP(r, rec)
+			s.cacheController.updateRecord(r, rec)
 		}(req)
 	}
 }
@@ -216,49 +221,6 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
 }
 
 // QueryIP implements Server.
-func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl
-	fqdn := Fqdn(domain)
-	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
-	defer closeSubscribers(sub4, sub6)
-
-	if s.cacheController.disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
-	} else {
-		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-		if !go_errors.Is(err, errRecordNotFound) {
-			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
-			return ips, ttl, err
-		}
-	}
-
-	noResponseErrCh := make(chan error, 2)
-	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
-	start := time.Now()
-
-	if sub4 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub4.Wait():
-			sub4.Close()
-		}
-	}
-	if sub6 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub6.Wait():
-			sub6.Close()
-		}
-	}
-
-	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-	return ips, ttl, err
-
+func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
+	return queryIP(ctx, s, domain, option)
 }

+ 11 - 51
app/dns/nameserver_quic.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
-	go_errors "errors"
 	"net/url"
 	"sync"
 	"time"
@@ -59,7 +58,7 @@ func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICN
 	return s, nil
 }
 
-// Name returns client name
+// Name implements Server.
 func (s *QUICNameServer) Name() string {
 	return s.cacheController.name
 }
@@ -68,10 +67,14 @@ func (s *QUICNameServer) newReqID() uint16 {
 	return 0
 }
 
-func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
-	errors.LogInfo(ctx, s.Name(), " querying: ", domain)
+// getCacheController implements CachedNameserver.
+func (s *QUICNameServer) getCacheController() *CacheController { return s.cacheController }
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
+// sendQuery implements CachedNameserver.
+func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
+	errors.LogInfo(ctx, s.Name(), " querying: ", fqdn)
+
+	reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -167,57 +170,14 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- e
 				noResponseErrCh <- err
 				return
 			}
-			s.cacheController.updateIP(r, rec)
+			s.cacheController.updateRecord(r, rec)
 		}(req)
 	}
 }
 
-// QueryIP is called from dns.Server->queryIPTimeout
+// QueryIP implements Server.
 func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	fqdn := Fqdn(domain)
-	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
-	defer closeSubscribers(sub4, sub6)
-
-	if s.cacheController.disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
-	} else {
-		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-		if !go_errors.Is(err, errRecordNotFound) {
-			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
-			return ips, ttl, err
-		}
-	}
-
-	noResponseErrCh := make(chan error, 2)
-	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
-	start := time.Now()
-
-	if sub4 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub4.Wait():
-			sub4.Close()
-		}
-	}
-	if sub6 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub6.Wait():
-			sub6.Close()
-		}
-	}
-
-	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-	return ips, ttl, err
-
+	return queryIP(ctx, s, domain, option)
 }
 
 func isActive(s *quic.Conn) bool {

+ 11 - 50
app/dns/nameserver_tcp.go

@@ -4,14 +4,12 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
-	go_errors "errors"
 	"net/url"
 	"sync/atomic"
 	"time"
 
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net/cnc"
 	"github.com/xtls/xray-core/common/protocol/dns"
@@ -99,10 +97,16 @@ func (s *TCPNameServer) newReqID() uint16 {
 	return uint16(atomic.AddUint32(&s.reqID, 1))
 }
 
-func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
-	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
+// getCacheController implements CachedNameserver.
+func (s *TCPNameServer) getCacheController() *CacheController {
+	return s.cacheController
+}
+
+// sendQuery implements CachedNameserver.
+func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
+	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn)
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
+	reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	var deadline time.Time
 	if d, ok := ctx.Deadline(); ok {
@@ -195,55 +199,12 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
 				return
 			}
 
-			s.cacheController.updateIP(r, rec)
+			s.cacheController.updateRecord(r, rec)
 		}(req)
 	}
 }
 
 // QueryIP implements Server.
 func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	fqdn := Fqdn(domain)
-	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
-	defer closeSubscribers(sub4, sub6)
-
-	if s.cacheController.disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
-	} else {
-		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-		if !go_errors.Is(err, errRecordNotFound) {
-			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
-			return ips, ttl, err
-		}
-	}
-
-	noResponseErrCh := make(chan error, 2)
-	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
-	start := time.Now()
-
-	if sub4 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub4.Wait():
-			sub4.Close()
-		}
-	}
-	if sub6 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub6.Wait():
-			sub6.Close()
-		}
-	}
-
-	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-	return ips, ttl, err
-
+	return queryIP(ctx, s, domain, option)
 }

+ 11 - 50
app/dns/nameserver_udp.go

@@ -2,7 +2,6 @@ package dns
 
 import (
 	"context"
-	go_errors "errors"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -10,7 +9,6 @@ import (
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol/dns"
 	udp_proto "github.com/xtls/xray-core/common/protocol/udp"
@@ -134,7 +132,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		}
 	}
 
-	s.cacheController.updateIP(&req.dnsRequest, ipRec)
+	s.cacheController.updateRecord(&req.dnsRequest, ipRec)
 }
 
 func (s *ClassicNameServer) newReqID() uint16 {
@@ -150,10 +148,16 @@ func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
 	common.Must(s.requestsCleanup.Start())
 }
 
-func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
-	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
+// getCacheController implements CachedNameserver.
+func (s *ClassicNameServer) getCacheController() *CacheController {
+	return s.cacheController
+}
+
+// sendQuery implements CachedNameserver.
+func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, fqdn string, option dns_feature.IPOption) {
+	errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn)
 
-	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
+	reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
 
 	for _, req := range reqs {
 		udpReq := &udpDnsRequest{
@@ -170,48 +174,5 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
 
 // QueryIP implements Server.
 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
-	fqdn := Fqdn(domain)
-	sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
-	defer closeSubscribers(sub4, sub6)
-
-	if s.cacheController.disableCache {
-		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
-	} else {
-		ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-		if !go_errors.Is(err, errRecordNotFound) {
-			errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
-			log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
-			return ips, ttl, err
-		}
-	}
-
-	noResponseErrCh := make(chan error, 2)
-	s.sendQuery(ctx, noResponseErrCh, fqdn, option)
-	start := time.Now()
-
-	if sub4 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub4.Wait():
-			sub4.Close()
-		}
-	}
-	if sub6 != nil {
-		select {
-		case <-ctx.Done():
-			return nil, 0, ctx.Err()
-		case err := <-noResponseErrCh:
-			return nil, 0, err
-		case <-sub6.Wait():
-			sub6.Close()
-		}
-	}
-
-	ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
-	log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
-	return ips, ttl, err
-
+	return queryIP(ctx, s, domain, option)
 }