Explorar el Código

Add disableCache/disableExpire option for dns client

世界 hace 3 años
padre
commit
ecac383477
Se han modificado 5 ficheros con 137 adiciones y 104 borrados
  1. 127 100
      dns/client.go
  2. 1 1
      dns/transport_tcp.go
  3. 1 1
      dns/transport_tls.go
  4. 1 1
      dns/transport_udp.go
  5. 7 1
      option/dns.go

+ 127 - 100
dns/client.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/option"
 
 	"golang.org/x/net/dns/dnsmessage"
 )
@@ -27,12 +28,21 @@ var (
 var _ adapter.DNSClient = (*Client)(nil)
 
 type Client struct {
-	cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message]
+	cache         *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
+	disableCache  bool
+	disableExpire bool
 }
 
-func NewClient() *Client {
-	return &Client{
-		cache: cache.New[dnsmessage.Question, dnsmessage.Message](),
+func NewClient(options option.DNSClientOptions) *Client {
+	if options.DisableCache {
+		return &Client{
+			disableCache: true,
+		}
+	} else {
+		return &Client{
+			cache:         cache.New[dnsmessage.Question, *dnsmessage.Message](),
+			disableExpire: options.DisableExpire,
+		}
 	}
 }
 
@@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 		return nil, E.New("empty query")
 	}
 	question := message.Questions[0]
-	cachedAnswer, cached := c.cache.Load(question)
-	if cached {
-		cachedAnswer.ID = message.ID
-		return &cachedAnswer, nil
+	if !c.disableCache {
+		cachedAnswer, cached := c.cache.Load(question)
+		if cached {
+			cachedAnswer.ID = message.ID
+			return cachedAnswer, nil
+		}
 	}
 	if !transport.Raw() {
 		if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
@@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 	if err != nil {
 		return nil, err
 	}
-	c.cache.StoreWithExpire(question, *response, calculateExpire(message))
+	if !c.disableCache {
+		c.storeCache(question, response)
+	}
 	return message, err
 }
 
@@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
 		}
 		return sortAddresses(response4, response6, strategy), nil
 	}
-	if strategy == C.DomainStrategyUseIPv4 {
-		response, err := c.questionCache(dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeA,
-			Class: dnsmessage.ClassINET,
-		})
-		if err != ErrNotCached {
-			return response, err
-		}
-	} else if strategy == C.DomainStrategyUseIPv6 {
-		response, err := c.questionCache(dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeAAAA,
-			Class: dnsmessage.ClassINET,
-		})
-		if err != ErrNotCached {
-			return response, err
-		}
-	} else {
-		response4, _ := c.questionCache(dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeA,
-			Class: dnsmessage.ClassINET,
-		})
-		response6, _ := c.questionCache(dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeAAAA,
-			Class: dnsmessage.ClassINET,
-		})
-		if len(response4) > 0 || len(response6) > 0 {
-			return sortAddresses(response4, response6, strategy), nil
+	if !c.disableCache {
+		if strategy == C.DomainStrategyUseIPv4 {
+			response, err := c.questionCache(dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+			})
+			if err != ErrNotCached {
+				return response, err
+			}
+		} else if strategy == C.DomainStrategyUseIPv6 {
+			response, err := c.questionCache(dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeAAAA,
+				Class: dnsmessage.ClassINET,
+			})
+			if err != ErrNotCached {
+				return response, err
+			}
+		} else {
+			response4, _ := c.questionCache(dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+			})
+			response6, _ := c.questionCache(dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeAAAA,
+				Class: dnsmessage.ClassINET,
+			})
+			if len(response4) > 0 || len(response6) > 0 {
+				return sortAddresses(response4, response6, strategy), nil
+			}
 		}
 	}
 	var rCode dnsmessage.RCode
@@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
 		} else {
 			rCode = dnsmessage.RCode(rCodeError)
 		}
+		if c.disableCache {
+			return nil, err
+		}
 	}
 	header := dnsmessage.Header{
 		Response:      true,
 		Authoritative: true,
 		RCode:         rCode,
 	}
-	expire := time.Now().Add(time.Second * time.Duration(DefaultTTL))
-	if strategy != C.DomainStrategyUseIPv6 {
-		question4 := dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeA,
-			Class: dnsmessage.ClassINET,
-		}
-		response4 := common.Filter(response, func(addr netip.Addr) bool {
-			return addr.Is4() || addr.Is4In6()
-		})
-		message4 := dnsmessage.Message{
-			Header:    header,
-			Questions: []dnsmessage.Question{question4},
-		}
-		if len(response4) > 0 {
-			for _, address := range response4 {
-				message4.Answers = append(message4.Answers, dnsmessage.Resource{
-					Header: dnsmessage.ResourceHeader{
-						Name:  question4.Name,
-						Class: question4.Class,
-						TTL:   DefaultTTL,
-					},
-					Body: &dnsmessage.AResource{
-						A: address.As4(),
-					},
-				})
+	if !c.disableCache {
+		if strategy != C.DomainStrategyUseIPv6 {
+			question4 := dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
 			}
+			response4 := common.Filter(response, func(addr netip.Addr) bool {
+				return addr.Is4() || addr.Is4In6()
+			})
+			message4 := &dnsmessage.Message{
+				Header:    header,
+				Questions: []dnsmessage.Question{question4},
+			}
+			if len(response4) > 0 {
+				for _, address := range response4 {
+					message4.Answers = append(message4.Answers, dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:  question4.Name,
+							Class: question4.Class,
+							TTL:   DefaultTTL,
+						},
+						Body: &dnsmessage.AResource{
+							A: address.As4(),
+						},
+					})
+				}
+			}
+			c.storeCache(question4, message4)
 		}
-		c.cache.StoreWithExpire(question4, message4, expire)
-	}
-	if strategy != C.DomainStrategyUseIPv4 {
-		question6 := dnsmessage.Question{
-			Name:  dnsName,
-			Type:  dnsmessage.TypeAAAA,
-			Class: dnsmessage.ClassINET,
-		}
-		response6 := common.Filter(response, func(addr netip.Addr) bool {
-			return addr.Is6() && !addr.Is4In6()
-		})
-		message6 := dnsmessage.Message{
-			Header:    header,
-			Questions: []dnsmessage.Question{question6},
-		}
-		if len(response6) > 0 {
-			for _, address := range response6 {
-				message6.Answers = append(message6.Answers, dnsmessage.Resource{
-					Header: dnsmessage.ResourceHeader{
-						Name:  question6.Name,
-						Class: question6.Class,
-						TTL:   DefaultTTL,
-					},
-					Body: &dnsmessage.AAAAResource{
-						AAAA: address.As16(),
-					},
-				})
+		if strategy != C.DomainStrategyUseIPv4 {
+			question6 := dnsmessage.Question{
+				Name:  dnsName,
+				Type:  dnsmessage.TypeAAAA,
+				Class: dnsmessage.ClassINET,
+			}
+			response6 := common.Filter(response, func(addr netip.Addr) bool {
+				return addr.Is6() && !addr.Is4In6()
+			})
+			message6 := &dnsmessage.Message{
+				Header:    header,
+				Questions: []dnsmessage.Question{question6},
+			}
+			if len(response6) > 0 {
+				for _, address := range response6 {
+					message6.Answers = append(message6.Answers, dnsmessage.Resource{
+						Header: dnsmessage.ResourceHeader{
+							Name:  question6.Name,
+							Class: question6.Class,
+							TTL:   DefaultTTL,
+						},
+						Body: &dnsmessage.AAAAResource{
+							AAAA: address.As16(),
+						},
+					})
+				}
 			}
+			c.storeCache(question6, message6)
 		}
-		c.cache.StoreWithExpire(question6, message6, expire)
 	}
 	return response, err
 }
@@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
 	}
 }
 
-func calculateExpire(message *dnsmessage.Message) time.Time {
+func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
+	if c.disableExpire {
+		c.cache.Store(question, message)
+		return
+	}
 	timeToLive := DefaultTTL
 	for _, answer := range message.Answers {
 		if int(answer.Header.TTL) < timeToLive {
 			timeToLive = int(answer.Header.TTL)
 		}
 	}
-	return time.Now().Add(time.Second * time.Duration(timeToLive))
+	expire := time.Now().Add(time.Second * time.Duration(timeToLive))
+	c.cache.StoreWithExpire(question, message, expire)
 }
 
 func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
@@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
 		Type:  qType,
 		Class: dnsmessage.ClassINET,
 	}
-	cachedAddresses, err := c.questionCache(question)
-	if err != ErrNotCached {
-		return cachedAddresses, err
+	if !c.disableCache {
+		cachedAddresses, err := c.questionCache(question)
+		if err != ErrNotCached {
+			return cachedAddresses, err
+		}
 	}
 	message := dnsmessage.Message{
 		Header: dnsmessage.Header{
@@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
 	if !cached {
 		return nil, ErrNotCached
 	}
-	return messageToAddresses(&response)
+	return messageToAddresses(response)
 }
 
 func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {

+ 1 - 1
dns/transport_tcp.go

@@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
 	cancel()
 	conn.err = err
 	if err != nil {
-		t.logger.Warn("connection closed: ", err)
+		t.logger.Debug("connection closed: ", err)
 	}
 }
 

+ 1 - 1
dns/transport_tls.go

@@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
 	cancel()
 	conn.err = err
 	if err != nil {
-		t.logger.Warn("connection closed: ", err)
+		t.logger.Debug("connection closed: ", err)
 	}
 }
 

+ 1 - 1
dns/transport_udp.go

@@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
 	cancel()
 	conn.err = err
 	if err != nil {
-		t.logger.Warn("connection closed: ", err)
+		t.logger.Debug("connection closed: ", err)
 	}
 }
 

+ 7 - 1
option/dns.go

@@ -2,11 +2,17 @@ package option
 
 type DNSOptions struct {
 	Servers []DNSServerOptions `json:"servers,omitempty"`
+	DNSClientOptions
+}
+
+type DNSClientOptions struct {
+	DisableCache  bool `json:"disable_cache,omitempty"`
+	DisableExpire bool `json:"disable_expire,omitempty"`
 }
 
 type DNSServerOptions struct {
 	Tag             string `json:"tag,omitempty"`
 	Address         string `json:"address"`
-	Detour          string `json:"detour,omitempty"`
 	AddressResolver string `json:"address_resolver,omitempty"`
+	DialerOptions
 }