Browse Source

Fix port rule match logic

世界 3 years ago
parent
commit
d9aa0a67d6

+ 3 - 1
docs/configuration/dns/rule.md

@@ -104,8 +104,10 @@
 
     The default rule uses the following matching logic:  
     (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite`) &&  
+    (`port` || `port_range`) &&  
     (`source_geoip` || `source_ip_cidr`) &&  
-    `other fields`  
+    (`source_port` || `source_port_range`) &&  
+    `other fields`
 
 #### inbound
 

+ 3 - 1
docs/configuration/dns/rule.zh.md

@@ -103,8 +103,10 @@
 
     默认规则使用以下匹配逻辑:  
     (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite`) &&  
+    (`port` || `port_range`) &&  
     (`source_geoip` || `source_ip_cidr`) &&  
-    `other fields`  
+    (`source_port` || `source_port_range`) &&  
+    `other fields`
 
 #### inbound
 

+ 3 - 1
docs/configuration/route/rule.md

@@ -107,8 +107,10 @@
 
     The default rule uses the following matching logic:  
     (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite` || `geoip` || `ip_cidr`) &&  
+    (`port` || `port_range`) &&  
     (`source_geoip` || `source_ip_cidr`) &&  
-    `other fields`  
+    (`source_port` || `source_port_range`) &&  
+    `other fields`
 
 #### inbound
 

+ 3 - 1
docs/configuration/route/rule.zh.md

@@ -105,8 +105,10 @@
 
     默认规则使用以下匹配逻辑:  
     (`domain` || `domain_suffix` || `domain_keyword` || `domain_regex` || `geosite` || `geoip` || `ip_cidr`) &&  
+    (`port` || `port_range`) &&  
     (`source_geoip` || `source_ip_cidr`) &&  
-    `other fields`  
+    (`source_port` || `source_port_range`) &&  
+    `other fields`
 
 #### inbound
 

+ 32 - 4
route/rule.go

@@ -41,7 +41,9 @@ var _ adapter.Rule = (*DefaultRule)(nil)
 type DefaultRule struct {
 	items                   []RuleItem
 	sourceAddressItems      []RuleItem
+	sourcePortItems         []RuleItem
 	destinationAddressItems []RuleItem
+	destinationPortItems    []RuleItem
 	allItems                []RuleItem
 	invert                  bool
 	outbound                string
@@ -143,7 +145,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
 	}
 	if len(options.SourcePort) > 0 {
 		item := NewPortItem(true, options.SourcePort)
-		rule.items = append(rule.items, item)
+		rule.sourcePortItems = append(rule.sourcePortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.SourcePortRange) > 0 {
@@ -151,12 +153,12 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
 		if err != nil {
 			return nil, E.Cause(err, "source_port_range")
 		}
-		rule.items = append(rule.items, item)
+		rule.sourcePortItems = append(rule.sourcePortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.Port) > 0 {
 		item := NewPortItem(false, options.Port)
-		rule.items = append(rule.items, item)
+		rule.destinationPortItems = append(rule.destinationPortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.PortRange) > 0 {
@@ -164,7 +166,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
 		if err != nil {
 			return nil, E.Cause(err, "port_range")
 		}
-		rule.items = append(rule.items, item)
+		rule.destinationPortItems = append(rule.destinationPortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.ProcessName) > 0 {
@@ -256,6 +258,19 @@ func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool {
 		}
 	}
 
+	if len(r.sourcePortItems) > 0 {
+		var sourcePortMatch bool
+		for _, item := range r.sourcePortItems {
+			if item.Match(metadata) {
+				sourcePortMatch = true
+				break
+			}
+		}
+		if !sourcePortMatch {
+			return r.invert
+		}
+	}
+
 	if len(r.destinationAddressItems) > 0 {
 		var destinationAddressMatch bool
 		for _, item := range r.destinationAddressItems {
@@ -269,6 +284,19 @@ func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool {
 		}
 	}
 
+	if len(r.destinationPortItems) > 0 {
+		var destinationPortMatch bool
+		for _, item := range r.destinationPortItems {
+			if item.Match(metadata) {
+				destinationPortMatch = true
+				break
+			}
+		}
+		if !destinationPortMatch {
+			return r.invert
+		}
+	}
+
 	return !r.invert
 }
 

+ 65 - 21
route/rule_dns.go

@@ -39,12 +39,15 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
 var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
 
 type DefaultDNSRule struct {
-	items        []RuleItem
-	addressItems []RuleItem
-	allItems     []RuleItem
-	invert       bool
-	outbound     string
-	disableCache bool
+	items                   []RuleItem
+	sourceAddressItems      []RuleItem
+	sourcePortItems         []RuleItem
+	destinationAddressItems []RuleItem
+	destinationPortItems    []RuleItem
+	allItems                []RuleItem
+	invert                  bool
+	outbound                string
+	disableCache            bool
 }
 
 func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
@@ -90,12 +93,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 	}
 	if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 {
 		item := NewDomainItem(options.Domain, options.DomainSuffix)
-		rule.addressItems = append(rule.addressItems, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.DomainKeyword) > 0 {
 		item := NewDomainKeywordItem(options.DomainKeyword)
-		rule.addressItems = append(rule.addressItems, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.DomainRegex) > 0 {
@@ -103,17 +106,17 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 		if err != nil {
 			return nil, E.Cause(err, "domain_regex")
 		}
-		rule.addressItems = append(rule.addressItems, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.Geosite) > 0 {
 		item := NewGeositeItem(router, logger, options.Geosite)
-		rule.addressItems = append(rule.addressItems, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.SourceGeoIP) > 0 {
 		item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
-		rule.items = append(rule.items, item)
+		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.SourceIPCIDR) > 0 {
@@ -121,12 +124,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 		if err != nil {
 			return nil, E.Cause(err, "source_ipcidr")
 		}
-		rule.items = append(rule.items, item)
+		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.SourcePort) > 0 {
 		item := NewPortItem(true, options.SourcePort)
-		rule.items = append(rule.items, item)
+		rule.sourcePortItems = append(rule.sourcePortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.SourcePortRange) > 0 {
@@ -134,12 +137,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 		if err != nil {
 			return nil, E.Cause(err, "source_port_range")
 		}
-		rule.items = append(rule.items, item)
+		rule.sourcePortItems = append(rule.sourcePortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.Port) > 0 {
 		item := NewPortItem(false, options.Port)
-		rule.items = append(rule.items, item)
+		rule.destinationPortItems = append(rule.destinationPortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.PortRange) > 0 {
@@ -147,7 +150,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 		if err != nil {
 			return nil, E.Cause(err, "port_range")
 		}
-		rule.items = append(rule.items, item)
+		rule.destinationPortItems = append(rule.destinationPortItems, item)
 		rule.allItems = append(rule.allItems, item)
 	}
 	if len(options.ProcessName) > 0 {
@@ -230,18 +233,59 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
 			return r.invert
 		}
 	}
-	if len(r.addressItems) > 0 {
-		var addressMatch bool
-		for _, item := range r.addressItems {
+
+	if len(r.sourceAddressItems) > 0 {
+		var sourceAddressMatch bool
+		for _, item := range r.sourceAddressItems {
+			if item.Match(metadata) {
+				sourceAddressMatch = true
+				break
+			}
+		}
+		if !sourceAddressMatch {
+			return r.invert
+		}
+	}
+
+	if len(r.sourcePortItems) > 0 {
+		var sourcePortMatch bool
+		for _, item := range r.sourcePortItems {
 			if item.Match(metadata) {
-				addressMatch = true
+				sourcePortMatch = true
 				break
 			}
 		}
-		if !addressMatch {
+		if !sourcePortMatch {
 			return r.invert
 		}
 	}
+
+	if len(r.destinationAddressItems) > 0 {
+		var destinationAddressMatch bool
+		for _, item := range r.destinationAddressItems {
+			if item.Match(metadata) {
+				destinationAddressMatch = true
+				break
+			}
+		}
+		if !destinationAddressMatch {
+			return r.invert
+		}
+	}
+
+	if len(r.destinationPortItems) > 0 {
+		var destinationPortMatch bool
+		for _, item := range r.destinationPortItems {
+			if item.Match(metadata) {
+				destinationPortMatch = true
+				break
+			}
+		}
+		if !destinationPortMatch {
+			return r.invert
+		}
+	}
+
 	return !r.invert
 }