nameserver_cached.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "time"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/common/log"
  8. "github.com/xtls/xray-core/common/net"
  9. "github.com/xtls/xray-core/common/signal/pubsub"
  10. "github.com/xtls/xray-core/features/dns"
  11. )
// CachedNameserver is a nameserver whose lookups go through a CacheController;
// queryIP, fetch and doFetch drive it through the two methods below.
type CachedNameserver interface {
	// getCacheController returns the controller that holds this nameserver's
	// cache, its cache settings and the request group used by fetch.
	getCacheController() *CacheController
	// sendQuery issues the query for fqdn according to option. doFetch stops
	// waiting for a record as soon as an error arrives on noResponseErrCh, so
	// implementations presumably report there any failure for which no
	// response will be published (TODO confirm against implementations).
	sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns.IPOption)
}
  16. // queryIP is called from dns.Server->queryIPTimeout
  17. func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
  18. fqdn := Fqdn(domain)
  19. cache := s.getCacheController()
  20. if !cache.disableCache {
  21. if rec := cache.findRecords(fqdn); rec != nil {
  22. ips, ttl, err := merge(option, rec.A, rec.AAAA)
  23. if !go_errors.Is(err, errRecordNotFound) {
  24. if ttl > 0 {
  25. errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
  26. log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
  27. return ips, uint32(ttl), err
  28. }
  29. if cache.serveStale && (cache.serveExpiredTTL == 0 || cache.serveExpiredTTL < ttl) {
  30. errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
  31. log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheOptimiste, Elapsed: 0, Error: err})
  32. go pull(ctx, s, fqdn, option)
  33. return ips, 1, err
  34. }
  35. }
  36. }
  37. } else {
  38. errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", fqdn, " at ", cache.name)
  39. }
  40. return fetch(ctx, s, fqdn, option)
  41. }
  42. func pull(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) {
  43. nctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 8*time.Second)
  44. defer cancel()
  45. fetch(nctx, s, fqdn, option)
  46. }
  47. func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) {
  48. key := fqdn
  49. switch {
  50. case option.IPv4Enable && option.IPv6Enable:
  51. key = key + "46"
  52. case option.IPv4Enable:
  53. key = key + "4"
  54. case option.IPv6Enable:
  55. key = key + "6"
  56. }
  57. v, _, _ := s.getCacheController().requestGroup.Do(key, func() (any, error) {
  58. return doFetch(ctx, s, fqdn, option), nil
  59. })
  60. ret := v.(result)
  61. return ret.ips, ret.ttl, ret.error
  62. }
// result bundles doFetch's three return values so they can travel through
// requestGroup.Do as a single value; fetch unpacks it again. doFetch builds it
// with a positional literal, so field order matters.
type result struct {
	ips []net.IP
	ttl uint32 // TTL derived in doFetch; presumably seconds (TODO confirm)
	error       // embedded; read back as res.error by fetch, nil on success
}
  68. func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) result {
  69. sub4, sub6 := s.getCacheController().registerSubscribers(fqdn, option)
  70. defer closeSubscribers(sub4, sub6)
  71. noResponseErrCh := make(chan error, 2)
  72. onEvent := func(sub *pubsub.Subscriber) (*IPRecord, error) {
  73. if sub == nil {
  74. return nil, nil
  75. }
  76. select {
  77. case <-ctx.Done():
  78. return nil, ctx.Err()
  79. case err := <-noResponseErrCh:
  80. return nil, err
  81. case msg := <-sub.Wait():
  82. sub.Close()
  83. return msg.(*IPRecord), nil // should panic
  84. }
  85. }
  86. start := time.Now()
  87. s.sendQuery(ctx, noResponseErrCh, fqdn, option)
  88. rec4, err4 := onEvent(sub4)
  89. rec6, err6 := onEvent(sub6)
  90. var errs []error
  91. if err4 != nil {
  92. errs = append(errs, err4)
  93. }
  94. if err6 != nil {
  95. errs = append(errs, err6)
  96. }
  97. ips, ttl, err := merge(option, rec4, rec6, errs...)
  98. var rTTL uint32
  99. if ttl > 0 {
  100. rTTL = uint32(ttl)
  101. } else if ttl == 0 && go_errors.Is(err, errRecordNotFound) {
  102. rTTL = 0
  103. } else { // edge case: where a fast rep's ttl expires during the rtt of a slower, parallel query
  104. rTTL = 1
  105. }
  106. log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
  107. return result{ips, rTTL, err}
  108. }
  109. func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, int32, error) {
  110. var allIPs []net.IP
  111. var rTTL int32 = dns.DefaultTTL
  112. mergeReq := option.IPv4Enable && option.IPv6Enable
  113. if option.IPv4Enable {
  114. ips, ttl, err := rec4.getIPs() // it's safe
  115. if !mergeReq || go_errors.Is(err, errRecordNotFound) {
  116. return ips, ttl, err
  117. }
  118. if ttl < rTTL {
  119. rTTL = ttl
  120. }
  121. if len(ips) > 0 {
  122. allIPs = append(allIPs, ips...)
  123. } else {
  124. errs = append(errs, err)
  125. }
  126. }
  127. if option.IPv6Enable {
  128. ips, ttl, err := rec6.getIPs() // it's safe
  129. if !mergeReq || go_errors.Is(err, errRecordNotFound) {
  130. return ips, ttl, err
  131. }
  132. if ttl < rTTL {
  133. rTTL = ttl
  134. }
  135. if len(ips) > 0 {
  136. allIPs = append(allIPs, ips...)
  137. } else {
  138. errs = append(errs, err)
  139. }
  140. }
  141. if len(allIPs) > 0 {
  142. return allIPs, rTTL, nil
  143. }
  144. if len(errs) == 2 && go_errors.Is(errs[0], errs[1]) {
  145. return nil, rTTL, errs[0]
  146. }
  147. return nil, rTTL, errors.Combine(errs...)
  148. }