cache_controller.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/common/net"
  8. "github.com/xtls/xray-core/common/signal/pubsub"
  9. "github.com/xtls/xray-core/common/task"
  10. dns_feature "github.com/xtls/xray-core/features/dns"
  11. "golang.org/x/net/dns/dnsmessage"
  12. "sync"
  13. "time"
  14. )
  15. type CacheController struct {
  16. sync.RWMutex
  17. ips map[string]*record
  18. pub *pubsub.Service
  19. cacheCleanup *task.Periodic
  20. name string
  21. disableCache bool
  22. }
  // NewCacheController builds an empty CacheController for the DNS client
  // identified by name. The cleanup task is configured to run CacheCleanup
  // once per minute but is not started here; updateIP starts it.
  func NewCacheController(name string, disableCache bool) *CacheController {
  	ctrl := new(CacheController)
  	ctrl.name = name
  	ctrl.disableCache = disableCache
  	ctrl.ips = make(map[string]*record)
  	ctrl.pub = pubsub.NewService()

  	// Method value bound to ctrl, so the periodic task cleans this instance.
  	ctrl.cacheCleanup = &task.Periodic{Interval: time.Minute, Execute: ctrl.CacheCleanup}
  	return ctrl
  }
  // CacheCleanup clears expired items from cache.
  //
  // It is installed as the Execute callback of c.cacheCleanup. Expired A and
  // AAAA answers are dropped, and domains left with no answers are removed.
  // When the cache is already empty it returns an error ("nothing to do.
  // stopping..."). The message suggests this is meant to stop the periodic task;
  // presumably task.Periodic stops on a non-nil error (confirm against the task
  // package). updateIP calls c.cacheCleanup.Start() after every stored answer.
  func (c *CacheController) CacheCleanup() error {
  	now := time.Now()
  	c.Lock()
  	defer c.Unlock()

  	if len(c.ips) == 0 {
  		return errors.New("nothing to do. stopping...")
  	}

  	// Named rec rather than record to avoid shadowing the record type.
  	for domain, rec := range c.ips {
  		if rec.A != nil && rec.A.Expire.Before(now) {
  			rec.A = nil
  		}
  		if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
  			rec.AAAA = nil
  		}
  		// rec is a pointer, so the updates above already live in the map;
  		// only entries with no answers left need touching.
  		// Deleting during range is permitted by the Go spec.
  		if rec.A == nil && rec.AAAA == nil {
  			errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
  			delete(c.ips, domain)
  		}
  	}

  	// Go maps never shrink; swap an emptied map for a fresh one so the old
  	// buckets can be garbage collected.
  	if len(c.ips) == 0 {
  		c.ips = make(map[string]*record)
  	}
  	return nil
  }
  63. func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
  64. elapsed := time.Since(req.start)
  65. c.Lock()
  66. rec, found := c.ips[req.domain]
  67. if !found {
  68. rec = &record{}
  69. }
  70. switch req.reqType {
  71. case dnsmessage.TypeA:
  72. rec.A = ipRec
  73. case dnsmessage.TypeAAAA:
  74. rec.AAAA = ipRec
  75. }
  76. errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
  77. c.ips[req.domain] = rec
  78. switch req.reqType {
  79. case dnsmessage.TypeA:
  80. c.pub.Publish(req.domain+"4", nil)
  81. if !c.disableCache {
  82. _, _, err := rec.AAAA.getIPs()
  83. if !go_errors.Is(err, errRecordNotFound) {
  84. c.pub.Publish(req.domain+"6", nil)
  85. }
  86. }
  87. case dnsmessage.TypeAAAA:
  88. c.pub.Publish(req.domain+"6", nil)
  89. if !c.disableCache {
  90. _, _, err := rec.A.getIPs()
  91. if !go_errors.Is(err, errRecordNotFound) {
  92. c.pub.Publish(req.domain+"4", nil)
  93. }
  94. }
  95. }
  96. c.Unlock()
  97. common.Must(c.cacheCleanup.Start())
  98. }
  // findIPsForDomain returns cached addresses for domain according to option.
  //
  // Returns errRecordNotFound when the domain has no cache entry, or when an
  // enabled family's record reports errRecordNotFound from getIPs.
  //
  // With exactly one family enabled, that family's getIPs result is returned
  // unchanged. With both enabled, the addresses are merged (IPv4 first). The
  // returned TTL is the smallest of the per-family TTLs and
  // dns_feature.DefaultTTL. When neither family yields addresses, the two
  // per-family errors are returned: a single error if errors.Is considers them
  // equal, otherwise a combined error. An error is returned if option enables
  // neither family.
  func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
  	// Copy the per-family pointers while holding the read lock. updateIP and
  	// CacheCleanup reassign rec.A / rec.AAAA under the write lock, so reading
  	// them after RUnlock would be a data race.
  	c.RLock()
  	rec, found := c.ips[domain]
  	var a, aaaa *IPRecord
  	if found {
  		a, aaaa = rec.A, rec.AAAA
  	}
  	c.RUnlock()

  	if !found {
  		return nil, 0, errRecordNotFound
  	}

  	// Without this guard the merge logic below would index an empty errs
  	// slice and panic.
  	if !option.IPv4Enable && !option.IPv6Enable {
  		return nil, 0, errors.New("neither IPv4 nor IPv6 is enabled in the lookup option")
  	}

  	var errs []error
  	var allIPs []net.IP
  	var rTTL uint32 = dns_feature.DefaultTTL
  	mergeReq := option.IPv4Enable && option.IPv6Enable

  	if option.IPv4Enable {
  		ips, ttl, err := a.getIPs()
  		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
  			return ips, ttl, err
  		}
  		if ttl < rTTL {
  			rTTL = ttl
  		}
  		if len(ips) > 0 {
  			allIPs = append(allIPs, ips...)
  		} else {
  			errs = append(errs, err)
  		}
  	}

  	if option.IPv6Enable {
  		ips, ttl, err := aaaa.getIPs()
  		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
  			return ips, ttl, err
  		}
  		if ttl < rTTL {
  			rTTL = ttl
  		}
  		if len(ips) > 0 {
  			allIPs = append(allIPs, ips...)
  		} else {
  			errs = append(errs, err)
  		}
  	}

  	if len(allIPs) > 0 {
  		return allIPs, rTTL, nil
  	}

  	// Reaching here means both families were enabled (mergeReq) and neither
  	// produced addresses, so errs holds exactly two entries.
  	if go_errors.Is(errs[0], errs[1]) {
  		return nil, rTTL, errs[0]
  	}
  	return nil, rTTL, errors.Combine(errs...)
  }
  146. func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
  147. // ipv4 and ipv6 belong to different subscription groups
  148. if option.IPv4Enable {
  149. sub4 = c.pub.Subscribe(domain + "4")
  150. }
  151. if option.IPv6Enable {
  152. sub6 = c.pub.Subscribe(domain + "6")
  153. }
  154. return
  155. }
  156. func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
  157. if sub4 != nil {
  158. sub4.Close()
  159. }
  160. if sub6 != nil {
  161. sub6.Close()
  162. }
  163. }