cache_controller.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "runtime"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/errors"
  10. "github.com/xtls/xray-core/common/signal/pubsub"
  11. "github.com/xtls/xray-core/common/task"
  12. dns_feature "github.com/xtls/xray-core/features/dns"
  13. "golang.org/x/net/dns/dnsmessage"
  14. "golang.org/x/sync/singleflight"
  15. )
  // Tuning knobs for the periodic cleanup and for reclaiming map memory.
  const (
  	// minSizeForEmptyRebuild is the smallest peak size at which a cache map
  	// that has become empty is replaced by a freshly allocated one.
  	minSizeForEmptyRebuild = 1 << 9

  	// shrinkAbsoluteThreshold: the map must have lost strictly more than this
  	// many entries since its peak before a shrink is considered.
  	shrinkAbsoluteThreshold = 10 * 1024

  	// shrinkRatioThreshold: the loss since the peak must also exceed this
  	// fraction of the peak size before a shrink is considered.
  	shrinkRatioThreshold = 0.65

  	// migrationBatchSize is how many entries migrate copies per write-lock
  	// acquisition when moving entries into a shrunken map.
  	migrationBatchSize = 1 << 12
  )
  // CacheController is the answer cache of one DNS name server.
  //
  // Within this file, ips, dirtyips and highWatermark are only read or written
  // while holding the embedded RWMutex. The exception is migrate, which ranges
  // over a local copy of the old dirtyips map after the swap; nothing in this
  // file writes to that map afterwards.
  type CacheController struct {
  	name            string // prefixed to every log message
  	disableCache    bool   // when true, updateRecord publishes answers but stores nothing
  	serveStale      bool   // keep expired answers for a grace period (see serveExpiredTTL)
  	serveExpiredTTL int32  // grace period in seconds, stored negated; 0 with serveStale means the cleanup is never started
  	ips             map[string]*record // domain -> cached A/AAAA answers
  	dirtyips        map[string]*record // previous map; non-nil only while migrate copies it into ips
  	sync.RWMutex
  	pub             *pubsub.Service    // answers are published on "<domain>4" (A) and "<domain>6" (AAAA)
  	cacheCleanup    *task.Periodic     // runs CacheCleanup every 300s; (re)started from updateRecord
  	highWatermark   int                // peak len(ips) seen at cleanup time since the last rebuild or shrink
  	requestGroup    singleflight.Group // not used in this file; presumably deduplicates concurrent queries elsewhere — confirm
  }
  35. func NewCacheController(name string, disableCache bool, serveStale bool, serveExpiredTTL uint32) *CacheController {
  36. c := &CacheController{
  37. name: name,
  38. disableCache: disableCache,
  39. serveStale: serveStale,
  40. serveExpiredTTL: -int32(serveExpiredTTL),
  41. ips: make(map[string]*record),
  42. pub: pubsub.NewService(),
  43. }
  44. c.cacheCleanup = &task.Periodic{
  45. Interval: 300 * time.Second,
  46. Execute: c.CacheCleanup,
  47. }
  48. return c
  49. }
  50. // CacheCleanup clears expired items from cache
  51. func (c *CacheController) CacheCleanup() error {
  52. expiredKeys, err := c.collectExpiredKeys()
  53. if err != nil {
  54. return err
  55. }
  56. if len(expiredKeys) == 0 {
  57. return nil
  58. }
  59. c.writeAndShrink(expiredKeys)
  60. return nil
  61. }
  // collectExpiredKeys scans the cache under the read lock. It returns the
  // domains that have at least one A or AAAA answer whose expiry is before the
  // cutoff. With serveStale and a non-zero grace period, the cutoff is moved
  // back by that period.
  //
  // Return values:
  //   - an error if the cache is empty;
  //   - (nil, nil) while a migration is still running, so this round is skipped.
  func (c *CacheController) collectExpiredKeys() ([]string, error) {
  	c.RLock()
  	defer c.RUnlock()

  	switch {
  	case len(c.ips) == 0:
  		return nil, errors.New("nothing to do. stopping...")
  	case c.dirtyips != nil:
  		// Entries are still being copied out of the old map; try next round.
  		return nil, nil
  	}

  	cutoff := time.Now()
  	if c.serveStale && c.serveExpiredTTL != 0 {
  		// serveExpiredTTL is stored negated, so this moves the cutoff into
  		// the past and keeps stale answers for the grace period.
  		cutoff = cutoff.Add(time.Duration(c.serveExpiredTTL) * time.Second)
  	}

  	// Guess that roughly a quarter of the entries expire per round.
  	expired := make([]string, 0, len(c.ips)/4)
  	for domain, rec := range c.ips {
  		staleA := rec.A != nil && rec.A.Expire.Before(cutoff)
  		staleAAAA := rec.AAAA != nil && rec.AAAA.Expire.Before(cutoff)
  		if staleA || staleAAAA {
  			expired = append(expired, domain)
  		}
  	}
  	return expired, nil
  }
  // writeAndShrink evicts the expired answers of the domains in expiredKeys.
  // It then reclaims map memory if the cache has shrunk a lot.
  //
  // Each key is re-checked under the write lock, because updateRecord may have
  // refreshed it since collectExpiredKeys ran.
  //
  // Published records are never modified in place. findRecords returns
  // *record pointers that callers read after the lock is released, so clearing
  // a field on a shared record would race with those reads. A partially
  // expired record is therefore replaced by a fresh copy, the same
  // copy-on-write style updateRecord and flush already use.
  //
  // Memory reclamation:
  //   - if the map became empty and its peak size reached
  //     minSizeForEmptyRebuild, it is replaced by a new empty map;
  //   - if it lost more than shrinkAbsoluteThreshold entries and more than
  //     shrinkRatioThreshold of its peak, a right-sized map is swapped in. The
  //     old map is parked in c.dirtyips and copied over by a background
  //     migrate goroutine.
  func (c *CacheController) writeAndShrink(expiredKeys []string) {
  	c.Lock()
  	defer c.Unlock()

  	// A migration started by an earlier cleanup is still running. Skip this
  	// round so a second migration is never started on top of it.
  	if c.dirtyips != nil {
  		return
  	}

  	lenBefore := len(c.ips)
  	if lenBefore > c.highWatermark {
  		c.highWatermark = lenBefore
  	}

  	now := time.Now()
  	if c.serveStale && c.serveExpiredTTL != 0 {
  		// serveExpiredTTL is negative: keep stale answers for the grace period.
  		now = now.Add(time.Duration(c.serveExpiredTTL) * time.Second)
  	}

  	for _, domain := range expiredKeys {
  		rec := c.ips[domain]
  		if rec == nil {
  			continue
  		}
  		a, aaaa := rec.A, rec.AAAA
  		if a != nil && a.Expire.Before(now) {
  			a = nil
  		}
  		if aaaa != nil && aaaa.Expire.Before(now) {
  			aaaa = nil
  		}
  		switch {
  		case a == nil && aaaa == nil:
  			delete(c.ips, domain)
  		case a != rec.A || aaaa != rec.AAAA:
  			// Copy-on-write: readers may still hold rec.
  			updated := *rec
  			updated.A, updated.AAAA = a, aaaa
  			c.ips[domain] = &updated
  		}
  	}

  	lenAfter := len(c.ips)
  	if lenAfter == 0 {
  		// Go maps generally keep their storage after deletes, so a large map
  		// that has emptied out is replaced outright.
  		if c.highWatermark >= minSizeForEmptyRebuild {
  			errors.LogDebug(context.Background(), c.name,
  				" rebuilding empty cache map to reclaim memory.",
  				" size_before_cleanup=", lenBefore,
  				" peak_size_before_rebuild=", c.highWatermark,
  			)
  			c.ips = make(map[string]*record)
  			c.highWatermark = 0
  		}
  		return
  	}

  	if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
  		float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
  		errors.LogDebug(context.Background(), c.name,
  			" shrinking cache map to reclaim memory.",
  			" new_size=", lenAfter,
  			" peak_size_before_shrink=", c.highWatermark,
  			" reduction_since_peak=", reductionFromPeak,
  		)
  		// Swap in a right-sized map now. migrate copies the survivors over in
  		// batches, taking the lock once per batch instead of holding it for
  		// the whole copy.
  		c.dirtyips = c.ips
  		c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
  		c.highWatermark = lenAfter
  		go c.migrate()
  	}
  }
  // migrationEntry is one key/value pair taken from the old map (c.dirtyips)
  // by migrate. Entries are buffered up to migrationBatchSize and written into
  // c.ips by flush.
  //
  // Field order matters: migrate builds entries with a positional literal.
  type migrationEntry struct {
  	key   string  // domain, the key used in c.ips / c.dirtyips
  	value *record // the record as stored in the old map
  }
  // migrate copies every entry of c.dirtyips into c.ips and then clears
  // c.dirtyips. It is started as a goroutine by writeAndShrink after a shrink.
  //
  // Entries are flushed in batches of migrationBatchSize, each batch under
  // its own write-lock acquisition. The goroutine yields between batches so
  // lookups and updates are not starved.
  //
  // The old map is ranged over without holding the lock. That is safe because
  // nothing in this file writes to it after the swap.
  //
  // If migration panics, the panic is logged and c.dirtyips is dropped. Any
  // entries not yet copied are then gone from the cache.
  func (c *CacheController) migrate() {
  	defer func() {
  		r := recover()
  		if r == nil {
  			return
  		}
  		errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
  		c.Lock()
  		c.dirtyips = nil
  		c.Unlock()
  	}()

  	c.RLock()
  	pending := c.dirtyips
  	c.RUnlock()

  	// Nothing left to migrate (already finished or abandoned).
  	if pending == nil {
  		return
  	}

  	errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(pending), " items")

  	buf := make([]migrationEntry, 0, migrationBatchSize)
  	for key, val := range pending {
  		buf = append(buf, migrationEntry{key: key, value: val})
  		if len(buf) < migrationBatchSize {
  			continue
  		}
  		c.flush(buf)
  		buf = buf[:0]
  		runtime.Gosched()
  	}
  	if len(buf) != 0 {
  		c.flush(buf)
  	}

  	c.Lock()
  	c.dirtyips = nil
  	c.Unlock()

  	errors.LogDebug(context.Background(), c.name, " cache migration completed")
  }
  182. func (c *CacheController) flush(batch []migrationEntry) {
  183. c.Lock()
  184. defer c.Unlock()
  185. for _, dirty := range batch {
  186. if cur := c.ips[dirty.key]; cur != nil {
  187. merge := &record{}
  188. if cur.A == nil {
  189. merge.A = dirty.value.A
  190. } else {
  191. merge.A = cur.A
  192. }
  193. if cur.AAAA == nil {
  194. merge.AAAA = dirty.value.AAAA
  195. } else {
  196. merge.AAAA = cur.AAAA
  197. }
  198. c.ips[dirty.key] = merge
  199. } else {
  200. c.ips[dirty.key] = dirty.value
  201. }
  202. }
  203. }
  // updateRecord handles a fresh answer rep for req.
  //
  // Steps:
  //  1. Publish the answer on "<domain>4" (A) or "<domain>6" (AAAA).
  //  2. Unless caching is disabled, replace the domain's record in c.ips with
  //     a new record holding rep. The other address family is carried over
  //     from the current record.
  //  3. If that other family is cached (in c.ips, or in c.dirtyips during a
  //     migration) and getIPs reports a positive TTL without errRecordNotFound,
  //     publish it on its own topic as well.
  //  4. (Re)start the periodic cleanup, unless stale answers are kept forever
  //     (serveStale with serveExpiredTTL == 0).
  func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
  	rtt := time.Since(req.start)

  	if req.reqType == dnsmessage.TypeA {
  		c.pub.Publish(req.domain+"4", rep)
  	} else if req.reqType == dnsmessage.TypeAAAA {
  		c.pub.Publish(req.domain+"6", rep)
  	}

  	if c.disableCache {
  		errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
  		return
  	}

  	c.Lock()
  	// Time spent waiting for the lock, on top of the query round trip.
  	lockWait := time.Since(req.start) - rtt

  	current := c.ips[req.domain]
  	var migrating *record
  	if c.dirtyips != nil {
  		migrating = c.dirtyips[req.domain]
  	}

  	fresh := &record{}
  	var other *IPRecord // opposite address family to re-publish, if any
  	var otherSuffix string
  	switch req.reqType {
  	case dnsmessage.TypeA:
  		fresh.A = rep
  		otherSuffix = "6"
  		if current != nil && current.AAAA != nil {
  			fresh.AAAA = current.AAAA
  			other = current.AAAA
  		} else if migrating != nil && migrating.AAAA != nil {
  			// Not copied into fresh. When flush migrates this key, it fills
  			// the missing family from the old map.
  			other = migrating.AAAA
  		}
  	case dnsmessage.TypeAAAA:
  		fresh.AAAA = rep
  		otherSuffix = "4"
  		if current != nil && current.A != nil {
  			fresh.A = current.A
  			other = current.A
  		} else if migrating != nil && migrating.A != nil {
  			other = migrating.A
  		}
  	}
  	c.ips[req.domain] = fresh
  	c.Unlock()

  	if other != nil {
  		if _, ttl, err := other.getIPs(); ttl > 0 && !go_errors.Is(err, errRecordNotFound) {
  			c.pub.Publish(req.domain+otherSuffix, other)
  		}
  	}

  	errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)

  	// Stale answers served forever: there is nothing for the cleanup to do.
  	if c.serveStale && c.serveExpiredTTL == 0 {
  		return
  	}
  	common.Must(c.cacheCleanup.Start())
  }
  259. func (c *CacheController) findRecords(domain string) *record {
  260. c.RLock()
  261. defer c.RUnlock()
  262. rec := c.ips[domain]
  263. if rec == nil && c.dirtyips != nil {
  264. rec = c.dirtyips[domain]
  265. }
  266. return rec
  267. }
  268. func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
  269. // ipv4 and ipv6 belong to different subscription groups
  270. if option.IPv4Enable {
  271. sub4 = c.pub.Subscribe(domain + "4")
  272. }
  273. if option.IPv6Enable {
  274. sub6 = c.pub.Subscribe(domain + "6")
  275. }
  276. return
  277. }
  278. func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
  279. if sub4 != nil {
  280. sub4.Close()
  281. }
  282. if sub6 != nil {
  283. sub6.Close()
  284. }
  285. }