Ver Fonte

Allow nested logical rules

世界 há 1 ano atrás
pai
commit
5948ffb965

+ 28 - 14
experimental/clashapi.go

@@ -5,6 +5,7 @@ import (
 	"os"
 
 	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing/common"
@@ -27,24 +28,37 @@ func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.O
 
 func CalculateClashModeList(options option.Options) []string {
 	var clashMode []string
-	for _, dnsRule := range common.PtrValueOrDefault(options.DNS).Rules {
-		if dnsRule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, dnsRule.DefaultOptions.ClashMode) {
-			clashMode = append(clashMode, dnsRule.DefaultOptions.ClashMode)
-		}
-		for _, defaultRule := range dnsRule.LogicalOptions.Rules {
-			if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) {
-				clashMode = append(clashMode, defaultRule.ClashMode)
+	clashMode = append(clashMode, extraClashModeFromRule(common.PtrValueOrDefault(options.Route).Rules)...)
+	clashMode = append(clashMode, extraClashModeFromDNSRule(common.PtrValueOrDefault(options.DNS).Rules)...)
+	clashMode = common.FilterNotDefault(common.Uniq(clashMode))
+	return clashMode
+}
+
+func extraClashModeFromRule(rules []option.Rule) []string {
+	var clashMode []string
+	for _, rule := range rules {
+		switch rule.Type {
+		case C.RuleTypeDefault:
+			if rule.DefaultOptions.ClashMode != "" {
+				clashMode = append(clashMode, rule.DefaultOptions.ClashMode)
 			}
+		case C.RuleTypeLogical:
+			clashMode = append(clashMode, extraClashModeFromRule(rule.LogicalOptions.Rules)...)
 		}
 	}
-	for _, rule := range common.PtrValueOrDefault(options.Route).Rules {
-		if rule.DefaultOptions.ClashMode != "" && !common.Contains(clashMode, rule.DefaultOptions.ClashMode) {
-			clashMode = append(clashMode, rule.DefaultOptions.ClashMode)
-		}
-		for _, defaultRule := range rule.LogicalOptions.Rules {
-			if defaultRule.ClashMode != "" && !common.Contains(clashMode, defaultRule.ClashMode) {
-				clashMode = append(clashMode, defaultRule.ClashMode)
+	return clashMode
+}
+
+func extraClashModeFromDNSRule(rules []option.DNSRule) []string {
+	var clashMode []string
+	for _, rule := range rules {
+		switch rule.Type {
+		case C.RuleTypeDefault:
+			if rule.DefaultOptions.ClashMode != "" {
+				clashMode = append(clashMode, rule.DefaultOptions.ClashMode)
 			}
+		case C.RuleTypeLogical:
+			clashMode = append(clashMode, extraClashModeFromDNSRule(rule.LogicalOptions.Rules)...)
 		}
 	}
 	return clashMode

+ 16 - 5
option/rule.go

@@ -53,6 +53,17 @@ func (r *Rule) UnmarshalJSON(bytes []byte) error {
 	return nil
 }
 
+func (r Rule) IsValid() bool {
+	switch r.Type {
+	case C.RuleTypeDefault:
+		return r.DefaultOptions.IsValid()
+	case C.RuleTypeLogical:
+		return r.LogicalOptions.IsValid()
+	default:
+		panic("unknown rule type: " + r.Type)
+	}
+}
+
 type DefaultRule struct {
 	Inbound         Listable[string] `json:"inbound,omitempty"`
 	IPVersion       int              `json:"ip_version,omitempty"`
@@ -92,12 +103,12 @@ func (r DefaultRule) IsValid() bool {
 }
 
 type LogicalRule struct {
-	Mode     string        `json:"mode"`
-	Rules    []DefaultRule `json:"rules,omitempty"`
-	Invert   bool          `json:"invert,omitempty"`
-	Outbound string        `json:"outbound,omitempty"`
+	Mode     string `json:"mode"`
+	Rules    []Rule `json:"rules,omitempty"`
+	Invert   bool   `json:"invert,omitempty"`
+	Outbound string `json:"outbound,omitempty"`
 }
 
 func (r LogicalRule) IsValid() bool {
-	return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid)
+	return len(r.Rules) > 0 && common.All(r.Rules, Rule.IsValid)
 }

+ 18 - 7
option/rule_dns.go

@@ -53,6 +53,17 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error {
 	return nil
 }
 
+func (r DNSRule) IsValid() bool {
+	switch r.Type {
+	case C.RuleTypeDefault:
+		return r.DefaultOptions.IsValid()
+	case C.RuleTypeLogical:
+		return r.LogicalOptions.IsValid()
+	default:
+		panic("unknown DNS rule type: " + r.Type)
+	}
+}
+
 type DefaultDNSRule struct {
 	Inbound         Listable[string]       `json:"inbound,omitempty"`
 	IPVersion       int                    `json:"ip_version,omitempty"`
@@ -96,14 +107,14 @@ func (r DefaultDNSRule) IsValid() bool {
 }
 
 type LogicalDNSRule struct {
-	Mode         string           `json:"mode"`
-	Rules        []DefaultDNSRule `json:"rules,omitempty"`
-	Invert       bool             `json:"invert,omitempty"`
-	Server       string           `json:"server,omitempty"`
-	DisableCache bool             `json:"disable_cache,omitempty"`
-	RewriteTTL   *uint32          `json:"rewrite_ttl,omitempty"`
+	Mode         string    `json:"mode"`
+	Rules        []DNSRule `json:"rules,omitempty"`
+	Invert       bool      `json:"invert,omitempty"`
+	Server       string    `json:"server,omitempty"`
+	DisableCache bool      `json:"disable_cache,omitempty"`
+	RewriteTTL   *uint32   `json:"rewrite_ttl,omitempty"`
 }
 
 func (r LogicalDNSRule) IsValid() bool {
-	return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid)
+	return len(r.Rules) > 0 && common.All(r.Rules, DNSRule.IsValid)
 }

+ 2 - 2
route/router.go

@@ -128,14 +128,14 @@ func NewRouter(
 		Logger:           router.dnsLogger,
 	})
 	for i, ruleOptions := range options.Rules {
-		routeRule, err := NewRule(router, router.logger, ruleOptions)
+		routeRule, err := NewRule(router, router.logger, ruleOptions, true)
 		if err != nil {
 			return nil, E.Cause(err, "parse rule[", i, "]")
 		}
 		router.rules = append(router.rules, routeRule)
 	}
 	for i, dnsRuleOptions := range dnsOptions.Rules {
-		dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions)
+		dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions, true)
 		if err != nil {
 			return nil, E.Cause(err, "parse dns rule[", i, "]")
 		}

+ 4 - 8
route/router_geo_resources.go

@@ -252,10 +252,8 @@ func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool
 				return true
 			}
 		case C.RuleTypeLogical:
-			for _, subRule := range rule.LogicalOptions.Rules {
-				if cond(subRule) {
-					return true
-				}
+			if hasRule(rule.LogicalOptions.Rules, cond) {
+				return true
 			}
 		}
 	}
@@ -270,10 +268,8 @@ func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bo
 				return true
 			}
 		case C.RuleTypeLogical:
-			for _, subRule := range rule.LogicalOptions.Rules {
-				if cond(subRule) {
-					return true
-				}
+			if hasDNSRule(rule.LogicalOptions.Rules, cond) {
+				return true
 			}
 		}
 	}

+ 4 - 4
route/rule_default.go

@@ -8,13 +8,13 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 )
 
-func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) {
+func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule, checkOutbound bool) (adapter.Rule, error) {
 	switch options.Type {
 	case "", C.RuleTypeDefault:
 		if !options.DefaultOptions.IsValid() {
 			return nil, E.New("missing conditions")
 		}
-		if options.DefaultOptions.Outbound == "" {
+		if options.DefaultOptions.Outbound == "" && checkOutbound {
 			return nil, E.New("missing outbound field")
 		}
 		return NewDefaultRule(router, logger, options.DefaultOptions)
@@ -22,7 +22,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul
 		if !options.LogicalOptions.IsValid() {
 			return nil, E.New("missing conditions")
 		}
-		if options.LogicalOptions.Outbound == "" {
+		if options.LogicalOptions.Outbound == "" && checkOutbound {
 			return nil, E.New("missing outbound field")
 		}
 		return NewLogicalRule(router, logger, options.LogicalOptions)
@@ -220,7 +220,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
 		return nil, E.New("unknown logical mode: ", options.Mode)
 	}
 	for i, subRule := range options.Rules {
-		rule, err := NewDefaultRule(router, logger, subRule)
+		rule, err := NewRule(router, logger, subRule, false)
 		if err != nil {
 			return nil, E.Cause(err, "sub rule[", i, "]")
 		}

+ 4 - 4
route/rule_dns.go

@@ -8,13 +8,13 @@ import (
 	E "github.com/sagernet/sing/common/exceptions"
 )
 
-func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) {
+func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) {
 	switch options.Type {
 	case "", C.RuleTypeDefault:
 		if !options.DefaultOptions.IsValid() {
 			return nil, E.New("missing conditions")
 		}
-		if options.DefaultOptions.Server == "" {
+		if options.DefaultOptions.Server == "" && checkServer {
 			return nil, E.New("missing server field")
 		}
 		return NewDefaultDNSRule(router, logger, options.DefaultOptions)
@@ -22,7 +22,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
 		if !options.LogicalOptions.IsValid() {
 			return nil, E.New("missing conditions")
 		}
-		if options.LogicalOptions.Server == "" {
+		if options.LogicalOptions.Server == "" && checkServer {
 			return nil, E.New("missing server field")
 		}
 		return NewLogicalDNSRule(router, logger, options.LogicalOptions)
@@ -228,7 +228,7 @@ func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options
 		return nil, E.New("unknown logical mode: ", options.Mode)
 	}
 	for i, subRule := range options.Rules {
-		rule, err := NewDefaultDNSRule(router, logger, subRule)
+		rule, err := NewDNSRule(router, logger, subRule, false)
 		if err != nil {
 			return nil, E.Cause(err, "sub rule[", i, "]")
 		}