世界 4 месяцев назад
Родитель
Сommit
2594745ef8
3 измененных файлов с 69 добавлено и 51 удалено
  1. 0 0
      common/compatible/map.go
  2. 68 50
      dns/client.go
  3. 1 1
      experimental/clashapi/trafficontrol/manager.go

+ 0 - 0
experimental/clashapi/compatible/map.go → common/compatible/map.go


+ 68 - 50
dns/client.go

@@ -2,12 +2,14 @@ package dns
 
 import (
 	"context"
+	"errors"
 	"net"
 	"net/netip"
 	"strings"
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/compatible"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -30,16 +32,18 @@ var (
 var _ adapter.DNSClient = (*Client)(nil)
 
 type Client struct {
-	timeout          time.Duration
-	disableCache     bool
-	disableExpire    bool
-	independentCache bool
-	clientSubnet     netip.Prefix
-	rdrc             adapter.RDRCStore
-	initRDRCFunc     func() adapter.RDRCStore
-	logger           logger.ContextLogger
-	cache            freelru.Cache[dns.Question, *dns.Msg]
-	transportCache   freelru.Cache[transportCacheKey, *dns.Msg]
+	timeout            time.Duration
+	disableCache       bool
+	disableExpire      bool
+	independentCache   bool
+	clientSubnet       netip.Prefix
+	rdrc               adapter.RDRCStore
+	initRDRCFunc       func() adapter.RDRCStore
+	logger             logger.ContextLogger
+	cache              freelru.Cache[dns.Question, *dns.Msg]
+	cacheLock          compatible.Map[dns.Question, chan struct{}]
+	transportCache     freelru.Cache[transportCacheKey, *dns.Msg]
+	transportCacheLock compatible.Map[dns.Question, chan struct{}]
 }
 
 type ClientOptions struct {
@@ -96,17 +100,15 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 		if c.logger != nil {
 			c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
 		}
-		responseMessage := dns.Msg{
-			MsgHdr: dns.MsgHdr{
-				Id:       message.Id,
-				Response: true,
-				Rcode:    dns.RcodeFormatError,
-			},
-			Question: message.Question,
-		}
-		return &responseMessage, nil
+		return FixedResponseStatus(message, dns.RcodeFormatError), nil
 	}
 	question := message.Question[0]
+	if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
+		if c.logger != nil {
+			c.logger.DebugContext(ctx, "strategy rejected")
+		}
+		return FixedResponseStatus(message, dns.RcodeSuccess), nil
+	}
 	clientSubnet := options.ClientSubnet
 	if !clientSubnet.IsValid() {
 		clientSubnet = c.clientSubnet
@@ -120,6 +122,27 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 		!options.ClientSubnet.IsValid()
 	disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
 	if !disableCache {
+		if c.cache != nil {
+			cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{}))
+			if loaded {
+				<-cond
+			} else {
+				defer func() {
+					c.cacheLock.Delete(question)
+					close(cond)
+				}()
+			}
+		} else if c.transportCache != nil {
+			cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{}))
+			if loaded {
+				<-cond
+			} else {
+				defer func() {
+					c.transportCacheLock.Delete(question)
+					close(cond)
+				}()
+			}
+		}
 		response, ttl := c.loadResponse(question, transport)
 		if response != nil {
 			logCachedResponse(c.logger, ctx, response, ttl)
@@ -127,27 +150,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 			return response, nil
 		}
 	}
-	if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
-		responseMessage := dns.Msg{
-			MsgHdr: dns.MsgHdr{
-				Id:       message.Id,
-				Response: true,
-				Rcode:    dns.RcodeSuccess,
-			},
-			Question: []dns.Question{question},
-		}
-		if c.logger != nil {
-			c.logger.DebugContext(ctx, "strategy rejected")
-		}
-		return &responseMessage, nil
-	}
+
 	messageId := message.Id
 	contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
 	if clientSubnetLoaded && transport.Tag() == contextTransport {
 		return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
 	}
 	ctx = contextWithTransportTag(ctx, transport.Tag())
-	if responseChecker != nil && c.rdrc != nil {
+	if !disableCache && responseChecker != nil && c.rdrc != nil {
 		rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype)
 		if rejected {
 			return nil, ErrResponseRejectedCached
@@ -157,7 +167,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 	response, err := transport.Exchange(ctx, message)
 	cancel()
 	if err != nil {
-		return nil, err
+		var rcodeError RcodeError
+		if errors.As(err, &rcodeError) {
+			response = FixedResponseStatus(message, int(rcodeError))
+		} else {
+			return nil, err
+		}
 	}
 	/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
 		validResponse := response
@@ -196,13 +211,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
 	}*/
 	if responseChecker != nil {
 		var rejected bool
-		if !(response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError) {
+		// TODO: add accept_any rule and support to check response instead of addresses
+		if response.Rcode != dns.RcodeSuccess || len(response.Answer) == 0 {
 			rejected = true
 		} else {
 			rejected = !responseChecker(MessageToAddresses(response))
 		}
 		if rejected {
-			if c.rdrc != nil {
+			if !disableCache && c.rdrc != nil {
 				c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
 			}
 			logRejectedResponse(c.logger, ctx, response)
@@ -305,8 +321,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
 func (c *Client) ClearCache() {
 	if c.cache != nil {
 		c.cache.Purge()
-	}
-	if c.transportCache != nil {
+	} else if c.transportCache != nil {
 		c.transportCache.Purge()
 	}
 }
@@ -390,15 +405,15 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
 				transportTag: transport.Tag(),
 			}, message)
 		}
-		return
-	}
-	if !c.independentCache {
-		c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
 	} else {
-		c.transportCache.AddWithLifetime(transportCacheKey{
-			Question:     question,
-			transportTag: transport.Tag(),
-		}, message, time.Second*time.Duration(timeToLive))
+		if !c.independentCache {
+			c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
+		} else {
+			c.transportCache.AddWithLifetime(transportCacheKey{
+				Question:     question,
+				transportTag: transport.Tag(),
+			}, message, time.Second*time.Duration(timeToLive))
+		}
 	}
 }
 
@@ -564,9 +579,12 @@ func transportTagFromContext(ctx context.Context) (string, bool) {
 func FixedResponseStatus(message *dns.Msg, rcode int) *dns.Msg {
 	return &dns.Msg{
 		MsgHdr: dns.MsgHdr{
-			Id:       message.Id,
-			Rcode:    rcode,
-			Response: true,
+			Id:                 message.Id,
+			Response:           true,
+			Authoritative:      true,
+			RecursionDesired:   true,
+			RecursionAvailable: true,
+			Rcode:              rcode,
 		},
 		Question: message.Question,
 	}

+ 1 - 1
experimental/clashapi/trafficontrol/manager.go

@@ -6,8 +6,8 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/sagernet/sing-box/common/compatible"
 	C "github.com/sagernet/sing-box/constant"
-	"github.com/sagernet/sing-box/experimental/clashapi/compatible"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/json"
 	"github.com/sagernet/sing/common/x/list"