Browse Source

Fix DNS negative caching to comply with RFC 2308

世界 1 week ago
parent
commit
36dc883c7c
1 changed files with 26 additions and 5 deletions
  1. 26 5
      dns/client.go

+ 26 - 5
dns/client.go

@@ -95,6 +95,20 @@ func (c *Client) Start() {
 	}
 	}
 }
 }
 
 
+func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
+	for _, record := range response.Ns {
+		if soa, isSOA := record.(*dns.SOA); isSOA {
+			soaTTL := soa.Header().Ttl
+			soaMinimum := soa.Minttl
+			if soaTTL < soaMinimum {
+				return soaTTL, true
+			}
+			return soaMinimum, true
+		}
+	}
+	return 0, false
+}
+
 func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
 func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
 	if len(message.Question) == 0 {
 	if len(message.Question) == 0 {
 		if c.logger != nil {
 		if c.logger != nil {
@@ -214,7 +228,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 			response.Answer = append(response.Answer, validResponse.Answer...)
 			response.Answer = append(response.Answer, validResponse.Answer...)
 		}
 		}
 	}*/
 	}*/
-	disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError) || len(response.Answer) == 0
+	disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
 	if responseChecker != nil {
 	if responseChecker != nil {
 		var rejected bool
 		var rejected bool
 		// TODO: add accept_any rule and support to check response instead of addresses
 		// TODO: add accept_any rule and support to check response instead of addresses
@@ -251,10 +265,17 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 		}
 		}
 	}
 	}
 	var timeToLive uint32
 	var timeToLive uint32
-	for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
-		for _, record := range recordList {
-			if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
-				timeToLive = record.Header().Ttl
+	if len(response.Answer) == 0 {
+		if soaTTL, hasSOA := extractNegativeTTL(response); hasSOA {
+			timeToLive = soaTTL
+		}
+	}
+	if timeToLive == 0 {
+		for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
+			for _, record := range recordList {
+				if timeToLive == 0 || record.Header().Ttl > 0 && record.Header().Ttl < timeToLive {
+					timeToLive = record.Header().Ttl
+				}
 			}
 			}
 		}
 		}
 	}
 	}