dns_cache.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package cachefile
  2. import (
  3. "encoding/binary"
  4. "time"
  5. "github.com/sagernet/bbolt"
  6. "github.com/sagernet/sing/common/buf"
  7. "github.com/sagernet/sing/common/logger"
  8. )
  9. var bucketDNSCache = []byte("dns_cache")
  10. func (c *CacheFile) StoreDNS() bool {
  11. return c.storeDNS
  12. }
  13. func (c *CacheFile) LoadDNSCache(transportName string, qName string, qType uint16) (rawMessage []byte, expireAt time.Time, loaded bool) {
  14. c.saveDNSCacheAccess.RLock()
  15. entry, cached := c.saveDNSCache[saveCacheKey{transportName, qName, qType}]
  16. c.saveDNSCacheAccess.RUnlock()
  17. if cached {
  18. return entry.rawMessage, entry.expireAt, true
  19. }
  20. key := buf.Get(2 + len(qName))
  21. binary.BigEndian.PutUint16(key, qType)
  22. copy(key[2:], qName)
  23. defer buf.Put(key)
  24. err := c.view(func(tx *bbolt.Tx) error {
  25. bucket := c.bucket(tx, bucketDNSCache)
  26. if bucket == nil {
  27. return nil
  28. }
  29. bucket = bucket.Bucket([]byte(transportName))
  30. if bucket == nil {
  31. return nil
  32. }
  33. content := bucket.Get(key)
  34. if len(content) < 8 {
  35. return nil
  36. }
  37. expireAt = time.Unix(int64(binary.BigEndian.Uint64(content[:8])), 0)
  38. rawMessage = make([]byte, len(content)-8)
  39. copy(rawMessage, content[8:])
  40. loaded = true
  41. return nil
  42. })
  43. if err != nil {
  44. return nil, time.Time{}, false
  45. }
  46. return
  47. }
  48. func (c *CacheFile) SaveDNSCache(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time) error {
  49. return c.batch(func(tx *bbolt.Tx) error {
  50. bucket, err := c.createBucket(tx, bucketDNSCache)
  51. if err != nil {
  52. return err
  53. }
  54. bucket, err = bucket.CreateBucketIfNotExists([]byte(transportName))
  55. if err != nil {
  56. return err
  57. }
  58. key := buf.Get(2 + len(qName))
  59. binary.BigEndian.PutUint16(key, qType)
  60. copy(key[2:], qName)
  61. defer buf.Put(key)
  62. value := buf.Get(8 + len(rawMessage))
  63. defer buf.Put(value)
  64. binary.BigEndian.PutUint64(value[:8], uint64(expireAt.Unix()))
  65. copy(value[8:], rawMessage)
  66. return bucket.Put(key, value)
  67. })
  68. }
  69. func (c *CacheFile) SaveDNSCacheAsync(transportName string, qName string, qType uint16, rawMessage []byte, expireAt time.Time, logger logger.Logger) {
  70. saveKey := saveCacheKey{transportName, qName, qType}
  71. if !c.queueDNSCacheSave(saveKey, rawMessage, expireAt) {
  72. return
  73. }
  74. go c.flushPendingDNSCache(saveKey, logger)
  75. }
  76. func (c *CacheFile) queueDNSCacheSave(saveKey saveCacheKey, rawMessage []byte, expireAt time.Time) bool {
  77. c.saveDNSCacheAccess.Lock()
  78. defer c.saveDNSCacheAccess.Unlock()
  79. entry := c.saveDNSCache[saveKey]
  80. entry.rawMessage = append([]byte(nil), rawMessage...)
  81. entry.expireAt = expireAt
  82. entry.sequence++
  83. startFlush := !entry.saving
  84. entry.saving = true
  85. c.saveDNSCache[saveKey] = entry
  86. return startFlush
  87. }
  88. func (c *CacheFile) flushPendingDNSCache(saveKey saveCacheKey, logger logger.Logger) {
  89. c.flushPendingDNSCacheWith(saveKey, logger, func(entry saveDNSCacheEntry) error {
  90. return c.SaveDNSCache(saveKey.TransportName, saveKey.QuestionName, saveKey.QType, entry.rawMessage, entry.expireAt)
  91. })
  92. }
  93. func (c *CacheFile) flushPendingDNSCacheWith(saveKey saveCacheKey, logger logger.Logger, save func(saveDNSCacheEntry) error) {
  94. for {
  95. c.saveDNSCacheAccess.RLock()
  96. entry, loaded := c.saveDNSCache[saveKey]
  97. c.saveDNSCacheAccess.RUnlock()
  98. if !loaded {
  99. return
  100. }
  101. err := save(entry)
  102. if err != nil {
  103. logger.Warn("save DNS cache: ", err)
  104. }
  105. c.saveDNSCacheAccess.Lock()
  106. currentEntry, loaded := c.saveDNSCache[saveKey]
  107. if !loaded {
  108. c.saveDNSCacheAccess.Unlock()
  109. return
  110. }
  111. if currentEntry.sequence != entry.sequence {
  112. c.saveDNSCacheAccess.Unlock()
  113. continue
  114. }
  115. delete(c.saveDNSCache, saveKey)
  116. c.saveDNSCacheAccess.Unlock()
  117. return
  118. }
  119. }
  120. func (c *CacheFile) ClearDNSCache() error {
  121. c.saveDNSCacheAccess.Lock()
  122. clear(c.saveDNSCache)
  123. c.saveDNSCacheAccess.Unlock()
  124. return c.batch(func(tx *bbolt.Tx) error {
  125. if c.cacheID == nil {
  126. bucket := tx.Bucket(bucketDNSCache)
  127. if bucket == nil {
  128. return nil
  129. }
  130. return tx.DeleteBucket(bucketDNSCache)
  131. }
  132. bucket := tx.Bucket(c.cacheID)
  133. if bucket == nil || bucket.Bucket(bucketDNSCache) == nil {
  134. return nil
  135. }
  136. return bucket.DeleteBucket(bucketDNSCache)
  137. })
  138. }
  139. func (c *CacheFile) loopCacheCleanup(interval time.Duration, cleanupFunc func()) {
  140. ticker := time.NewTicker(interval)
  141. defer ticker.Stop()
  142. for {
  143. select {
  144. case <-c.ctx.Done():
  145. return
  146. case <-ticker.C:
  147. cleanupFunc()
  148. }
  149. }
  150. }
  151. func (c *CacheFile) cleanupDNSCache() {
  152. now := time.Now()
  153. err := c.batch(func(tx *bbolt.Tx) error {
  154. bucket := c.bucket(tx, bucketDNSCache)
  155. if bucket == nil {
  156. return nil
  157. }
  158. var emptyTransports [][]byte
  159. err := bucket.ForEachBucket(func(transportName []byte) error {
  160. transportBucket := bucket.Bucket(transportName)
  161. if transportBucket == nil {
  162. return nil
  163. }
  164. var expiredKeys [][]byte
  165. err := transportBucket.ForEach(func(key, value []byte) error {
  166. if len(value) < 8 {
  167. expiredKeys = append(expiredKeys, append([]byte(nil), key...))
  168. return nil
  169. }
  170. if c.disableExpire {
  171. return nil
  172. }
  173. expireAt := time.Unix(int64(binary.BigEndian.Uint64(value[:8])), 0)
  174. if now.After(expireAt.Add(c.optimisticTimeout)) {
  175. expiredKeys = append(expiredKeys, append([]byte(nil), key...))
  176. }
  177. return nil
  178. })
  179. if err != nil {
  180. return err
  181. }
  182. for _, key := range expiredKeys {
  183. err = transportBucket.Delete(key)
  184. if err != nil {
  185. return err
  186. }
  187. }
  188. first, _ := transportBucket.Cursor().First()
  189. if first == nil {
  190. emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
  191. }
  192. return nil
  193. })
  194. if err != nil {
  195. return err
  196. }
  197. for _, name := range emptyTransports {
  198. err = bucket.DeleteBucket(name)
  199. if err != nil {
  200. return err
  201. }
  202. }
  203. return nil
  204. })
  205. if err != nil {
  206. c.logger.Warn("cleanup DNS cache: ", err)
  207. }
  208. }
  209. func (c *CacheFile) clearRDRC() {
  210. c.saveRDRCAccess.Lock()
  211. clear(c.saveRDRC)
  212. c.saveRDRCAccess.Unlock()
  213. err := c.batch(func(tx *bbolt.Tx) error {
  214. if c.cacheID == nil {
  215. if tx.Bucket(bucketRDRC) == nil {
  216. return nil
  217. }
  218. return tx.DeleteBucket(bucketRDRC)
  219. }
  220. bucket := tx.Bucket(c.cacheID)
  221. if bucket == nil || bucket.Bucket(bucketRDRC) == nil {
  222. return nil
  223. }
  224. return bucket.DeleteBucket(bucketRDRC)
  225. })
  226. if err != nil {
  227. c.logger.Warn("clear RDRC: ", err)
  228. }
  229. }
  230. func (c *CacheFile) cleanupRDRC() {
  231. now := time.Now()
  232. err := c.batch(func(tx *bbolt.Tx) error {
  233. bucket := c.bucket(tx, bucketRDRC)
  234. if bucket == nil {
  235. return nil
  236. }
  237. var emptyTransports [][]byte
  238. err := bucket.ForEachBucket(func(transportName []byte) error {
  239. transportBucket := bucket.Bucket(transportName)
  240. if transportBucket == nil {
  241. return nil
  242. }
  243. var expiredKeys [][]byte
  244. err := transportBucket.ForEach(func(key, value []byte) error {
  245. if len(value) < 8 {
  246. expiredKeys = append(expiredKeys, append([]byte(nil), key...))
  247. return nil
  248. }
  249. expiresAt := time.Unix(int64(binary.BigEndian.Uint64(value)), 0)
  250. if now.After(expiresAt) {
  251. expiredKeys = append(expiredKeys, append([]byte(nil), key...))
  252. }
  253. return nil
  254. })
  255. if err != nil {
  256. return err
  257. }
  258. for _, key := range expiredKeys {
  259. err = transportBucket.Delete(key)
  260. if err != nil {
  261. return err
  262. }
  263. }
  264. first, _ := transportBucket.Cursor().First()
  265. if first == nil {
  266. emptyTransports = append(emptyTransports, append([]byte(nil), transportName...))
  267. }
  268. return nil
  269. })
  270. if err != nil {
  271. return err
  272. }
  273. for _, name := range emptyTransports {
  274. err = bucket.DeleteBucket(name)
  275. if err != nil {
  276. return err
  277. }
  278. }
  279. return nil
  280. })
  281. if err != nil {
  282. c.logger.Warn("cleanup RDRC: ", err)
  283. }
  284. }