|
|
@@ -3,24 +3,37 @@ package dns
|
|
|
import (
|
|
|
"context"
|
|
|
go_errors "errors"
|
|
|
+ "runtime"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+
|
|
|
"github.com/xtls/xray-core/common"
|
|
|
"github.com/xtls/xray-core/common/errors"
|
|
|
- "github.com/xtls/xray-core/common/net"
|
|
|
"github.com/xtls/xray-core/common/signal/pubsub"
|
|
|
"github.com/xtls/xray-core/common/task"
|
|
|
dns_feature "github.com/xtls/xray-core/features/dns"
|
|
|
+
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
|
- "sync"
|
|
|
- "time"
|
|
|
+ "golang.org/x/sync/singleflight"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ minSizeForEmptyRebuild = 512
|
|
|
+ shrinkAbsoluteThreshold = 10240
|
|
|
+ shrinkRatioThreshold = 0.65
|
|
|
+ migrationBatchSize = 4096
|
|
|
)
|
|
|
|
|
|
type CacheController struct {
|
|
|
sync.RWMutex
|
|
|
- ips map[string]*record
|
|
|
- pub *pubsub.Service
|
|
|
- cacheCleanup *task.Periodic
|
|
|
- name string
|
|
|
- disableCache bool
|
|
|
+ ips map[string]*record
|
|
|
+ dirtyips map[string]*record
|
|
|
+ pub *pubsub.Service
|
|
|
+ cacheCleanup *task.Periodic
|
|
|
+ name string
|
|
|
+ disableCache bool
|
|
|
+ highWatermark int
|
|
|
+ requestGroup singleflight.Group
|
|
|
}
|
|
|
|
|
|
func NewCacheController(name string, disableCache bool) *CacheController {
|
|
|
@@ -32,7 +45,7 @@ func NewCacheController(name string, disableCache bool) *CacheController {
|
|
|
}
|
|
|
|
|
|
c.cacheCleanup = &task.Periodic{
|
|
|
- Interval: time.Minute,
|
|
|
+ Interval: 300 * time.Second,
|
|
|
Execute: c.CacheCleanup,
|
|
|
}
|
|
|
return c
|
|
|
@@ -40,131 +53,253 @@ func NewCacheController(name string, disableCache bool) *CacheController {
|
|
|
|
|
|
// CacheCleanup clears expired items from cache
|
|
|
func (c *CacheController) CacheCleanup() error {
|
|
|
+ expiredKeys, err := c.collectExpiredKeys()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if len(expiredKeys) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ c.writeAndShrink(expiredKeys)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *CacheController) collectExpiredKeys() ([]string, error) {
|
|
|
+ c.RLock()
|
|
|
+ defer c.RUnlock()
|
|
|
+
|
|
|
+ if len(c.ips) == 0 {
|
|
|
+ return nil, errors.New("nothing to do. stopping...")
|
|
|
+ }
|
|
|
+
|
|
|
+ // skip collection if a migration is in progress
|
|
|
+ if c.dirtyips != nil {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
now := time.Now()
|
|
|
+ expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate
|
|
|
+
|
|
|
+ for domain, rec := range c.ips {
|
|
|
+ if (rec.A != nil && rec.A.Expire.Before(now)) ||
|
|
|
+ (rec.AAAA != nil && rec.AAAA.Expire.Before(now)) {
|
|
|
+ expiredKeys = append(expiredKeys, domain)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return expiredKeys, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *CacheController) writeAndShrink(expiredKeys []string) {
|
|
|
c.Lock()
|
|
|
defer c.Unlock()
|
|
|
|
|
|
- if len(c.ips) == 0 {
|
|
|
- return errors.New("nothing to do. stopping...")
|
|
|
+ // double check to prevent upper call multiple cleanup tasks
|
|
|
+ if c.dirtyips != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ lenBefore := len(c.ips)
|
|
|
+ if lenBefore > c.highWatermark {
|
|
|
+ c.highWatermark = lenBefore
|
|
|
}
|
|
|
|
|
|
- for domain, record := range c.ips {
|
|
|
- if record.A != nil && record.A.Expire.Before(now) {
|
|
|
- record.A = nil
|
|
|
+ now := time.Now()
|
|
|
+ for _, domain := range expiredKeys {
|
|
|
+ rec := c.ips[domain]
|
|
|
+ if rec == nil {
|
|
|
+ continue
|
|
|
}
|
|
|
- if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
|
|
- record.AAAA = nil
|
|
|
+ if rec.A != nil && rec.A.Expire.Before(now) {
|
|
|
+ rec.A = nil
|
|
|
}
|
|
|
-
|
|
|
- if record.A == nil && record.AAAA == nil {
|
|
|
- errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
|
|
|
+ if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
|
|
|
+ rec.AAAA = nil
|
|
|
+ }
|
|
|
+ if rec.A == nil && rec.AAAA == nil {
|
|
|
delete(c.ips, domain)
|
|
|
- } else {
|
|
|
- c.ips[domain] = record
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if len(c.ips) == 0 {
|
|
|
- c.ips = make(map[string]*record)
|
|
|
+ lenAfter := len(c.ips)
|
|
|
+
|
|
|
+ if lenAfter == 0 {
|
|
|
+ if c.highWatermark >= minSizeForEmptyRebuild {
|
|
|
+ errors.LogDebug(context.Background(), c.name,
|
|
|
+ " rebuilding empty cache map to reclaim memory.",
|
|
|
+ " size_before_cleanup=", lenBefore,
|
|
|
+ " peak_size_before_rebuild=", c.highWatermark,
|
|
|
+ )
|
|
|
+
|
|
|
+ c.ips = make(map[string]*record)
|
|
|
+ c.highWatermark = 0
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
|
|
|
+ float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
|
|
|
+ errors.LogDebug(context.Background(), c.name,
|
|
|
+ " shrinking cache map to reclaim memory.",
|
|
|
+ " new_size=", lenAfter,
|
|
|
+ " peak_size_before_shrink=", c.highWatermark,
|
|
|
+ " reduction_since_peak=", reductionFromPeak,
|
|
|
+ )
|
|
|
+
|
|
|
+ c.dirtyips = c.ips
|
|
|
+ c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
|
|
|
+ c.highWatermark = lenAfter
|
|
|
+ go c.migrate()
|
|
|
}
|
|
|
|
|
|
- return nil
|
|
|
}
|
|
|
|
|
|
-func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
|
|
- elapsed := time.Since(req.start)
|
|
|
+type migrationEntry struct {
|
|
|
+ key string
|
|
|
+ value *record
|
|
|
+}
|
|
|
|
|
|
- c.Lock()
|
|
|
- rec, found := c.ips[req.domain]
|
|
|
- if !found {
|
|
|
- rec = &record{}
|
|
|
+func (c *CacheController) migrate() {
|
|
|
+ defer func() {
|
|
|
+ if r := recover(); r != nil {
|
|
|
+ errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
|
|
|
+ c.Lock()
|
|
|
+ c.dirtyips = nil
|
|
|
+ // c.ips = make(map[string]*record)
|
|
|
+ // c.highWatermark = 0
|
|
|
+ c.Unlock()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ c.RLock()
|
|
|
+ dirtyips := c.dirtyips
|
|
|
+ c.RUnlock()
|
|
|
+
|
|
|
+ // double check to prevent upper call multiple cleanup tasks
|
|
|
+ if dirtyips == nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- switch req.reqType {
|
|
|
- case dnsmessage.TypeA:
|
|
|
- rec.A = ipRec
|
|
|
- case dnsmessage.TypeAAAA:
|
|
|
- rec.AAAA = ipRec
|
|
|
+ errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")
|
|
|
+
|
|
|
+ batch := make([]migrationEntry, 0, migrationBatchSize)
|
|
|
+ for domain, recD := range dirtyips {
|
|
|
+ batch = append(batch, migrationEntry{domain, recD})
|
|
|
+
|
|
|
+ if len(batch) >= migrationBatchSize {
|
|
|
+ c.flush(batch)
|
|
|
+ batch = batch[:0]
|
|
|
+ runtime.Gosched()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if len(batch) > 0 {
|
|
|
+ c.flush(batch)
|
|
|
}
|
|
|
|
|
|
- errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
|
|
- c.ips[req.domain] = rec
|
|
|
+ c.Lock()
|
|
|
+ c.dirtyips = nil
|
|
|
+ c.Unlock()
|
|
|
|
|
|
- switch req.reqType {
|
|
|
- case dnsmessage.TypeA:
|
|
|
- c.pub.Publish(req.domain+"4", nil)
|
|
|
- if !c.disableCache {
|
|
|
- _, _, err := rec.AAAA.getIPs()
|
|
|
- if !go_errors.Is(err, errRecordNotFound) {
|
|
|
- c.pub.Publish(req.domain+"6", nil)
|
|
|
+ errors.LogDebug(context.Background(), c.name, " cache migration completed.")
|
|
|
+}
|
|
|
+
|
|
|
+func (c *CacheController) flush(batch []migrationEntry) {
|
|
|
+ c.Lock()
|
|
|
+ defer c.Unlock()
|
|
|
+
|
|
|
+ for _, dirty := range batch {
|
|
|
+ if cur := c.ips[dirty.key]; cur != nil {
|
|
|
+ merge := &record{}
|
|
|
+ if cur.A == nil {
|
|
|
+ merge.A = dirty.value.A
|
|
|
+ } else {
|
|
|
+ merge.A = cur.A
|
|
|
}
|
|
|
- }
|
|
|
- case dnsmessage.TypeAAAA:
|
|
|
- c.pub.Publish(req.domain+"6", nil)
|
|
|
- if !c.disableCache {
|
|
|
- _, _, err := rec.A.getIPs()
|
|
|
- if !go_errors.Is(err, errRecordNotFound) {
|
|
|
- c.pub.Publish(req.domain+"4", nil)
|
|
|
+ if cur.AAAA == nil {
|
|
|
+ merge.AAAA = dirty.value.AAAA
|
|
|
+ } else {
|
|
|
+ merge.AAAA = cur.AAAA
|
|
|
}
|
|
|
+ c.ips[dirty.key] = merge
|
|
|
+ } else {
|
|
|
+ c.ips[dirty.key] = dirty.value
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- c.Unlock()
|
|
|
- common.Must(c.cacheCleanup.Start())
|
|
|
}
|
|
|
|
|
|
-func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
|
|
- c.RLock()
|
|
|
- record, found := c.ips[domain]
|
|
|
- c.RUnlock()
|
|
|
+func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
|
|
|
+ rtt := time.Since(req.start)
|
|
|
|
|
|
- if !found {
|
|
|
- return nil, 0, errRecordNotFound
|
|
|
+ switch req.reqType {
|
|
|
+ case dnsmessage.TypeA:
|
|
|
+ c.pub.Publish(req.domain+"4", rep)
|
|
|
+ case dnsmessage.TypeAAAA:
|
|
|
+ c.pub.Publish(req.domain+"6", rep)
|
|
|
}
|
|
|
|
|
|
- var errs []error
|
|
|
- var allIPs []net.IP
|
|
|
- var rTTL uint32 = dns_feature.DefaultTTL
|
|
|
+ if c.disableCache {
|
|
|
+ errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- mergeReq := option.IPv4Enable && option.IPv6Enable
|
|
|
+ c.Lock()
|
|
|
+ lockWait := time.Since(req.start) - rtt
|
|
|
|
|
|
- if option.IPv4Enable {
|
|
|
- ips, ttl, err := record.A.getIPs()
|
|
|
- if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
|
|
- return ips, ttl, err
|
|
|
- }
|
|
|
- if ttl < rTTL {
|
|
|
- rTTL = ttl
|
|
|
- }
|
|
|
- if len(ips) > 0 {
|
|
|
- allIPs = append(allIPs, ips...)
|
|
|
- } else {
|
|
|
- errs = append(errs, err)
|
|
|
- }
|
|
|
+ newRec := &record{}
|
|
|
+ oldRec := c.ips[req.domain]
|
|
|
+ var dirtyRec *record
|
|
|
+ if c.dirtyips != nil {
|
|
|
+ dirtyRec = c.dirtyips[req.domain]
|
|
|
}
|
|
|
|
|
|
- if option.IPv6Enable {
|
|
|
- ips, ttl, err := record.AAAA.getIPs()
|
|
|
- if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
|
|
- return ips, ttl, err
|
|
|
- }
|
|
|
- if ttl < rTTL {
|
|
|
- rTTL = ttl
|
|
|
+ var pubRecord *IPRecord
|
|
|
+ var pubSuffix string
|
|
|
+
|
|
|
+ switch req.reqType {
|
|
|
+ case dnsmessage.TypeA:
|
|
|
+ newRec.A = rep
|
|
|
+ if oldRec != nil && oldRec.AAAA != nil {
|
|
|
+ newRec.AAAA = oldRec.AAAA
|
|
|
+ pubRecord = oldRec.AAAA
|
|
|
+ } else if dirtyRec != nil && dirtyRec.AAAA != nil {
|
|
|
+ pubRecord = dirtyRec.AAAA
|
|
|
}
|
|
|
- if len(ips) > 0 {
|
|
|
- allIPs = append(allIPs, ips...)
|
|
|
- } else {
|
|
|
- errs = append(errs, err)
|
|
|
+ pubSuffix = "6"
|
|
|
+ case dnsmessage.TypeAAAA:
|
|
|
+ newRec.AAAA = rep
|
|
|
+ if oldRec != nil && oldRec.A != nil {
|
|
|
+ newRec.A = oldRec.A
|
|
|
+ pubRecord = oldRec.A
|
|
|
+ } else if dirtyRec != nil && dirtyRec.A != nil {
|
|
|
+ pubRecord = dirtyRec.A
|
|
|
}
|
|
|
+ pubSuffix = "4"
|
|
|
}
|
|
|
|
|
|
- if len(allIPs) > 0 {
|
|
|
- return allIPs, rTTL, nil
|
|
|
+ c.ips[req.domain] = newRec
|
|
|
+ c.Unlock()
|
|
|
+
|
|
|
+ if pubRecord != nil {
|
|
|
+ _, _ /*ttl*/, err := pubRecord.getIPs()
|
|
|
+ if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) {
|
|
|
+ c.pub.Publish(req.domain+pubSuffix, pubRecord)
|
|
|
+ }
|
|
|
}
|
|
|
- if go_errors.Is(errs[0], errs[1]) {
|
|
|
- return nil, rTTL, errs[0]
|
|
|
+
|
|
|
+ errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)
|
|
|
+
|
|
|
+ common.Must(c.cacheCleanup.Start())
|
|
|
+}
|
|
|
+
|
|
|
+func (c *CacheController) findRecords(domain string) *record {
|
|
|
+ c.RLock()
|
|
|
+ defer c.RUnlock()
|
|
|
+
|
|
|
+ rec := c.ips[domain]
|
|
|
+ if rec == nil && c.dirtyips != nil {
|
|
|
+ rec = c.dirtyips[domain]
|
|
|
}
|
|
|
- return nil, rTTL, errors.Combine(errs...)
|
|
|
+ return rec
|
|
|
}
|
|
|
|
|
|
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|