Ver código fonte

Add disable_cache option to dns rule

世界 3 anos atrás
pai
commit
af19ba6119
6 arquivos alterados com 127 adições e 82 exclusões
  1. 5 0
      adapter/router.go
  2. 13 5
      option/dns.go
  3. 2 0
      option/route.go
  4. 15 9
      route/router.go
  5. 36 30
      route/rule.go
  6. 56 38
      route/rule_dns.go

+ 5 - 0
adapter/router.go

@@ -50,3 +50,8 @@ type Rule interface {
 	Outbound() string
 	Outbound() string
 	String() string
 	String() string
 }
 }
+
+type DNSRule interface {
+	Rule
+	DisableCache() bool
+}

+ 13 - 5
option/dns.go

@@ -55,6 +55,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
 	var v any
 	var v any
 	switch r.Type {
 	switch r.Type {
 	case C.RuleTypeDefault:
 	case C.RuleTypeDefault:
+		r.Type = ""
 		v = r.DefaultOptions
 		v = r.DefaultOptions
 	case C.RuleTypeLogical:
 	case C.RuleTypeLogical:
 		v = r.LogicalOptions
 		v = r.LogicalOptions
@@ -109,6 +110,7 @@ type DefaultDNSRule struct {
 	Outbound      Listable[string] `json:"outbound,omitempty"`
 	Outbound      Listable[string] `json:"outbound,omitempty"`
 	Invert        bool             `json:"invert,omitempty"`
 	Invert        bool             `json:"invert,omitempty"`
 	Server        string           `json:"server,omitempty"`
 	Server        string           `json:"server,omitempty"`
+	DisableCache  bool             `json:"disable_cache,omitempty"`
 }
 }
 
 
 func (r DefaultDNSRule) IsValid() bool {
 func (r DefaultDNSRule) IsValid() bool {
@@ -135,13 +137,17 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool {
 		common.ComparableSliceEquals(r.UserID, other.UserID) &&
 		common.ComparableSliceEquals(r.UserID, other.UserID) &&
 		common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
 		common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
 		common.ComparableSliceEquals(r.Outbound, other.Outbound) &&
 		common.ComparableSliceEquals(r.Outbound, other.Outbound) &&
-		r.Server == other.Server
+		r.Invert == other.Invert &&
+		r.Server == other.Server &&
+		r.DisableCache == other.DisableCache
 }
 }
 
 
 type LogicalDNSRule struct {
 type LogicalDNSRule struct {
-	Mode   string           `json:"mode"`
-	Rules  []DefaultDNSRule `json:"rules,omitempty"`
-	Server string           `json:"server,omitempty"`
+	Mode         string           `json:"mode"`
+	Rules        []DefaultDNSRule `json:"rules,omitempty"`
+	Invert       bool             `json:"invert,omitempty"`
+	Server       string           `json:"server,omitempty"`
+	DisableCache bool             `json:"disable_cache,omitempty"`
 }
 }
 
 
 func (r LogicalDNSRule) IsValid() bool {
 func (r LogicalDNSRule) IsValid() bool {
@@ -151,5 +157,7 @@ func (r LogicalDNSRule) IsValid() bool {
 func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool {
 func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool {
 	return r.Mode == other.Mode &&
 	return r.Mode == other.Mode &&
 		common.SliceEquals(r.Rules, other.Rules) &&
 		common.SliceEquals(r.Rules, other.Rules) &&
-		r.Server == other.Server
+		r.Invert == other.Invert &&
+		r.Server == other.Server &&
+		r.DisableCache == other.DisableCache
 }
 }

+ 2 - 0
option/route.go

@@ -145,6 +145,7 @@ func (r DefaultRule) Equals(other DefaultRule) bool {
 type LogicalRule struct {
 type LogicalRule struct {
 	Mode     string        `json:"mode"`
 	Mode     string        `json:"mode"`
 	Rules    []DefaultRule `json:"rules,omitempty"`
 	Rules    []DefaultRule `json:"rules,omitempty"`
+	Invert   bool          `json:"invert,omitempty"`
 	Outbound string        `json:"outbound,omitempty"`
 	Outbound string        `json:"outbound,omitempty"`
 }
 }
 
 
@@ -155,5 +156,6 @@ func (r LogicalRule) IsValid() bool {
 func (r LogicalRule) Equals(other LogicalRule) bool {
 func (r LogicalRule) Equals(other LogicalRule) bool {
 	return r.Mode == other.Mode &&
 	return r.Mode == other.Mode &&
 		common.SliceEquals(r.Rules, other.Rules) &&
 		common.SliceEquals(r.Rules, other.Rules) &&
+		r.Invert == other.Invert &&
 		r.Outbound == other.Outbound
 		r.Outbound == other.Outbound
 }
 }

+ 15 - 9
route/router.go

@@ -59,7 +59,7 @@ type Router struct {
 	geositeCache                       map[string]adapter.Rule
 	geositeCache                       map[string]adapter.Rule
 	dnsClient                          *dns.Client
 	dnsClient                          *dns.Client
 	defaultDomainStrategy              dns.DomainStrategy
 	defaultDomainStrategy              dns.DomainStrategy
-	dnsRules                           []adapter.Rule
+	dnsRules                           []adapter.DNSRule
 	defaultTransport                   dns.Transport
 	defaultTransport                   dns.Transport
 	transports                         []dns.Transport
 	transports                         []dns.Transport
 	transportMap                       map[string]dns.Transport
 	transportMap                       map[string]dns.Transport
@@ -80,7 +80,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
 		dnsLogger:             dnsLogger,
 		dnsLogger:             dnsLogger,
 		outboundByTag:         make(map[string]adapter.Outbound),
 		outboundByTag:         make(map[string]adapter.Outbound),
 		rules:                 make([]adapter.Rule, 0, len(options.Rules)),
 		rules:                 make([]adapter.Rule, 0, len(options.Rules)),
-		dnsRules:              make([]adapter.Rule, 0, len(dnsOptions.Rules)),
+		dnsRules:              make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
 		needGeoIPDatabase:     hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
 		needGeoIPDatabase:     hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
 		needGeositeDatabase:   hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
 		needGeositeDatabase:   hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
 		geoIPOptions:          common.PtrValueOrDefault(options.GeoIP),
 		geoIPOptions:          common.PtrValueOrDefault(options.GeoIP),
@@ -536,15 +536,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 }
 }
 
 
 func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
 func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
-	return r.dnsClient.Exchange(ctx, r.matchDNS(ctx), message)
+	ctx, transport := r.matchDNS(ctx)
+	return r.dnsClient.Exchange(ctx, transport, message)
 }
 }
 
 
 func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
 func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
-	return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, strategy)
+	ctx, transport := r.matchDNS(ctx)
+	return r.dnsClient.Lookup(ctx, transport, domain, strategy)
 }
 }
 
 
 func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
 func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
-	return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy)
+	ctx, transport := r.matchDNS(ctx)
+	return r.dnsClient.Lookup(ctx, transport, domain, r.defaultDomainStrategy)
 }
 }
 
 
 func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
 func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
@@ -586,23 +589,26 @@ func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, de
 	return nil, defaultOutbound
 	return nil, defaultOutbound
 }
 }
 
 
-func (r *Router) matchDNS(ctx context.Context) dns.Transport {
+func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport) {
 	metadata := adapter.ContextFrom(ctx)
 	metadata := adapter.ContextFrom(ctx)
 	if metadata == nil {
 	if metadata == nil {
 		r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx))
 		r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx))
-		return r.defaultTransport
+		return ctx, r.defaultTransport
 	}
 	}
 	for i, rule := range r.dnsRules {
 	for i, rule := range r.dnsRules {
 		if rule.Match(metadata) {
 		if rule.Match(metadata) {
+			if rule.DisableCache() {
+				ctx = dns.ContextWithDisableCache(ctx, true)
+			}
 			detour := rule.Outbound()
 			detour := rule.Outbound()
 			r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
 			r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
 			if transport, loaded := r.transportMap[detour]; loaded {
 			if transport, loaded := r.transportMap[detour]; loaded {
-				return transport
+				return ctx, transport
 			}
 			}
 			r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
 			r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
 		}
 		}
 	}
 	}
-	return r.defaultTransport
+	return ctx, r.defaultTransport
 }
 }
 
 
 func (r *Router) InterfaceBindManager() control.BindManager {
 func (r *Router) InterfaceBindManager() control.BindManager {

+ 36 - 30
route/rule.go

@@ -49,10 +49,6 @@ type DefaultRule struct {
 	outbound                string
 	outbound                string
 }
 }
 
 
-func (r *DefaultRule) Type() string {
-	return C.RuleTypeDefault
-}
-
 type RuleItem interface {
 type RuleItem interface {
 	Match(metadata *adapter.InboundContext) bool
 	Match(metadata *adapter.InboundContext) bool
 	String() string
 	String() string
@@ -180,6 +176,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
 	return rule, nil
 	return rule, nil
 }
 }
 
 
+func (r *DefaultRule) Type() string {
+	return C.RuleTypeDefault
+}
+
 func (r *DefaultRule) Start() error {
 func (r *DefaultRule) Start() error {
 	for _, item := range r.allItems {
 	for _, item := range r.allItems {
 		err := common.Start(item)
 		err := common.Start(item)
@@ -261,9 +261,34 @@ var _ adapter.Rule = (*LogicalRule)(nil)
 type LogicalRule struct {
 type LogicalRule struct {
 	mode     string
 	mode     string
 	rules    []*DefaultRule
 	rules    []*DefaultRule
+	invert   bool
 	outbound string
 	outbound string
 }
 }
 
 
+func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
+	r := &LogicalRule{
+		rules:    make([]*DefaultRule, len(options.Rules)),
+		invert:   options.Invert,
+		outbound: options.Outbound,
+	}
+	switch options.Mode {
+	case C.LogicalTypeAnd:
+		r.mode = C.LogicalTypeAnd
+	case C.LogicalTypeOr:
+		r.mode = C.LogicalTypeOr
+	default:
+		return nil, E.New("unknown logical mode: ", options.Mode)
+	}
+	for i, subRule := range options.Rules {
+		rule, err := NewDefaultRule(router, logger, subRule)
+		if err != nil {
+			return nil, E.Cause(err, "sub rule[", i, "]")
+		}
+		r.rules[i] = rule
+	}
+	return r, nil
+}
+
 func (r *LogicalRule) Type() string {
 func (r *LogicalRule) Type() string {
 	return C.RuleTypeLogical
 	return C.RuleTypeLogical
 }
 }
@@ -298,38 +323,15 @@ func (r *LogicalRule) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
-	r := &LogicalRule{
-		rules:    make([]*DefaultRule, len(options.Rules)),
-		outbound: options.Outbound,
-	}
-	switch options.Mode {
-	case C.LogicalTypeAnd:
-		r.mode = C.LogicalTypeAnd
-	case C.LogicalTypeOr:
-		r.mode = C.LogicalTypeOr
-	default:
-		return nil, E.New("unknown logical mode: ", options.Mode)
-	}
-	for i, subRule := range options.Rules {
-		rule, err := NewDefaultRule(router, logger, subRule)
-		if err != nil {
-			return nil, E.Cause(err, "sub rule[", i, "]")
-		}
-		r.rules[i] = rule
-	}
-	return r, nil
-}
-
 func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool {
 func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool {
 	if r.mode == C.LogicalTypeAnd {
 	if r.mode == C.LogicalTypeAnd {
 		return common.All(r.rules, func(it *DefaultRule) bool {
 		return common.All(r.rules, func(it *DefaultRule) bool {
 			return it.Match(metadata)
 			return it.Match(metadata)
-		})
+		}) != r.invert
 	} else {
 	} else {
 		return common.Any(r.rules, func(it *DefaultRule) bool {
 		return common.Any(r.rules, func(it *DefaultRule) bool {
 			return it.Match(metadata)
 			return it.Match(metadata)
-		})
+		}) != r.invert
 	}
 	}
 }
 }
 
 
@@ -345,5 +347,9 @@ func (r *LogicalRule) String() string {
 	case C.LogicalTypeOr:
 	case C.LogicalTypeOr:
 		op = "||"
 		op = "||"
 	}
 	}
-	return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
+	if !r.invert {
+		return strings.Join(F.MapToString(r.rules), " "+op+" ")
+	} else {
+		return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
+	}
 }
 }

+ 56 - 38
route/rule_dns.go

@@ -12,7 +12,7 @@ import (
 	F "github.com/sagernet/sing/common/format"
 	F "github.com/sagernet/sing/common/format"
 )
 )
 
 
-func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.Rule, error) {
+func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) {
 	if common.IsEmptyByEquals(options) {
 	if common.IsEmptyByEquals(options) {
 		return nil, E.New("empty rule config")
 		return nil, E.New("empty rule config")
 	}
 	}
@@ -38,7 +38,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
 	}
 	}
 }
 }
 
 
-var _ adapter.Rule = (*DefaultDNSRule)(nil)
+var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
 
 
 type DefaultDNSRule struct {
 type DefaultDNSRule struct {
 	items        []RuleItem
 	items        []RuleItem
@@ -46,16 +46,14 @@ type DefaultDNSRule struct {
 	allItems     []RuleItem
 	allItems     []RuleItem
 	invert       bool
 	invert       bool
 	outbound     string
 	outbound     string
-}
-
-func (r *DefaultDNSRule) Type() string {
-	return C.RuleTypeDefault
+	disableCache bool
 }
 }
 
 
 func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
 func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
 	rule := &DefaultDNSRule{
 	rule := &DefaultDNSRule{
-		invert:   true,
-		outbound: options.Server,
+		invert:       options.Invert,
+		outbound:     options.Server,
+		disableCache: options.DisableCache,
 	}
 	}
 	if len(options.Inbound) > 0 {
 	if len(options.Inbound) > 0 {
 		item := NewInboundRule(options.Inbound)
 		item := NewInboundRule(options.Inbound)
@@ -156,6 +154,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 	return rule, nil
 	return rule, nil
 }
 }
 
 
+func (r *DefaultDNSRule) Type() string {
+	return C.RuleTypeDefault
+}
+
 func (r *DefaultDNSRule) Start() error {
 func (r *DefaultDNSRule) Start() error {
 	for _, item := range r.allItems {
 	for _, item := range r.allItems {
 		err := common.Start(item)
 		err := common.Start(item)
@@ -213,16 +215,47 @@ func (r *DefaultDNSRule) Outbound() string {
 	return r.outbound
 	return r.outbound
 }
 }
 
 
+func (r *DefaultDNSRule) DisableCache() bool {
+	return r.disableCache
+}
+
 func (r *DefaultDNSRule) String() string {
 func (r *DefaultDNSRule) String() string {
 	return strings.Join(F.MapToString(r.allItems), " ")
 	return strings.Join(F.MapToString(r.allItems), " ")
 }
 }
 
 
-var _ adapter.Rule = (*LogicalRule)(nil)
+var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
 
 
 type LogicalDNSRule struct {
 type LogicalDNSRule struct {
-	mode     string
-	rules    []*DefaultDNSRule
-	outbound string
+	mode         string
+	rules        []*DefaultDNSRule
+	invert       bool
+	outbound     string
+	disableCache bool
+}
+
+func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
+	r := &LogicalDNSRule{
+		rules:        make([]*DefaultDNSRule, len(options.Rules)),
+		invert:       options.Invert,
+		outbound:     options.Server,
+		disableCache: options.DisableCache,
+	}
+	switch options.Mode {
+	case C.LogicalTypeAnd:
+		r.mode = C.LogicalTypeAnd
+	case C.LogicalTypeOr:
+		r.mode = C.LogicalTypeOr
+	default:
+		return nil, E.New("unknown logical mode: ", options.Mode)
+	}
+	for i, subRule := range options.Rules {
+		rule, err := NewDefaultDNSRule(router, logger, subRule)
+		if err != nil {
+			return nil, E.Cause(err, "sub rule[", i, "]")
+		}
+		r.rules[i] = rule
+	}
+	return r, nil
 }
 }
 
 
 func (r *LogicalDNSRule) Type() string {
 func (r *LogicalDNSRule) Type() string {
@@ -259,38 +292,15 @@ func (r *LogicalDNSRule) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
-	r := &LogicalDNSRule{
-		rules:    make([]*DefaultDNSRule, len(options.Rules)),
-		outbound: options.Server,
-	}
-	switch options.Mode {
-	case C.LogicalTypeAnd:
-		r.mode = C.LogicalTypeAnd
-	case C.LogicalTypeOr:
-		r.mode = C.LogicalTypeOr
-	default:
-		return nil, E.New("unknown logical mode: ", options.Mode)
-	}
-	for i, subRule := range options.Rules {
-		rule, err := NewDefaultDNSRule(router, logger, subRule)
-		if err != nil {
-			return nil, E.Cause(err, "sub rule[", i, "]")
-		}
-		r.rules[i] = rule
-	}
-	return r, nil
-}
-
 func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
 func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
 	if r.mode == C.LogicalTypeAnd {
 	if r.mode == C.LogicalTypeAnd {
 		return common.All(r.rules, func(it *DefaultDNSRule) bool {
 		return common.All(r.rules, func(it *DefaultDNSRule) bool {
 			return it.Match(metadata)
 			return it.Match(metadata)
-		})
+		}) != r.invert
 	} else {
 	} else {
 		return common.Any(r.rules, func(it *DefaultDNSRule) bool {
 		return common.Any(r.rules, func(it *DefaultDNSRule) bool {
 			return it.Match(metadata)
 			return it.Match(metadata)
-		})
+		}) != r.invert
 	}
 	}
 }
 }
 
 
@@ -298,6 +308,10 @@ func (r *LogicalDNSRule) Outbound() string {
 	return r.outbound
 	return r.outbound
 }
 }
 
 
+func (r *LogicalDNSRule) DisableCache() bool {
+	return r.disableCache
+}
+
 func (r *LogicalDNSRule) String() string {
 func (r *LogicalDNSRule) String() string {
 	var op string
 	var op string
 	switch r.mode {
 	switch r.mode {
@@ -306,5 +320,9 @@ func (r *LogicalDNSRule) String() string {
 	case C.LogicalTypeOr:
 	case C.LogicalTypeOr:
 		op = "||"
 		op = "||"
 	}
 	}
-	return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
+	if !r.invert {
+		return strings.Join(F.MapToString(r.rules), " "+op+" ")
+	} else {
+		return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
+	}
 }
 }