浏览代码

perf(GeoIPMatcher): faster heuristic matching with reduced memory usage (#5289)

Meow 2 月之前
父节点
当前提交
fcfb0a302a
共有 6 个文件被更改,包括 985 次插入192 次删除
  1. 18 20
      app/dns/nameserver.go
  2. 30 39
      app/router/condition.go
  3. 900 83
      app/router/condition_geoip.go
  4. 29 42
      app/router/condition_geoip_test.go
  5. 1 1
      app/router/condition_test.go
  6. 7 7
      app/router/config.go

+ 18 - 20
app/dns/nameserver.go

@@ -29,8 +29,8 @@ type Client struct {
 	server        Server
 	skipFallback  bool
 	domains       []string
-	expectedIPs   []*router.GeoIPMatcher
-	unexpectedIPs []*router.GeoIPMatcher
+	expectedIPs   router.GeoIPMatcher
+	unexpectedIPs router.GeoIPMatcher
 	actPrior      bool
 	actUnprior    bool
 	tag           string
@@ -154,23 +154,21 @@ func NewClient(
 		}
 
 		// Establish expected IPs
-		var expectedMatchers []*router.GeoIPMatcher
-		for _, geoip := range ns.ExpectedGeoip {
-			matcher, err := router.GlobalGeoIPContainer.Add(geoip)
+		var expectedMatcher router.GeoIPMatcher
+		if len(ns.ExpectedGeoip) > 0 {
+			expectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.ExpectedGeoip...)
 			if err != nil {
 				return errors.New("failed to create expected ip matcher").Base(err).AtWarning()
 			}
-			expectedMatchers = append(expectedMatchers, matcher)
 		}
 
 		// Establish unexpected IPs
-		var unexpectedMatchers []*router.GeoIPMatcher
-		for _, geoip := range ns.UnexpectedGeoip {
-			matcher, err := router.GlobalGeoIPContainer.Add(geoip)
+		var unexpectedMatcher router.GeoIPMatcher
+		if len(ns.UnexpectedGeoip) > 0 {
+			unexpectedMatcher, err = router.BuildOptimizedGeoIPMatcher(ns.UnexpectedGeoip...)
 			if err != nil {
 				return errors.New("failed to create unexpected ip matcher").Base(err).AtWarning()
 			}
-			unexpectedMatchers = append(unexpectedMatchers, matcher)
 		}
 
 		if len(clientIP) > 0 {
@@ -192,8 +190,8 @@ func NewClient(
 		client.server = server
 		client.skipFallback = ns.SkipFallback
 		client.domains = rules
-		client.expectedIPs = expectedMatchers
-		client.unexpectedIPs = unexpectedMatchers
+		client.expectedIPs = expectedMatcher
+		client.unexpectedIPs = unexpectedMatcher
 		client.actPrior = ns.ActPrior
 		client.actUnprior = ns.ActUnprior
 		client.tag = tag
@@ -243,32 +241,32 @@ func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption
 		return nil, 0, dns.ErrEmptyResponse
 	}
 
-	if len(c.expectedIPs) > 0 && !c.actPrior {
-		ips = router.MatchIPs(c.expectedIPs, ips, false)
+	if c.expectedIPs != nil && !c.actPrior {
+		ips, _ = c.expectedIPs.FilterIPs(ips)
 		errors.LogDebug(context.Background(), "domain ", domain, " expectedIPs ", ips, " matched at server ", c.Name())
 		if len(ips) == 0 {
 			return nil, 0, dns.ErrEmptyResponse
 		}
 	}
 
-	if len(c.unexpectedIPs) > 0 && !c.actUnprior {
-		ips = router.MatchIPs(c.unexpectedIPs, ips, true)
+	if c.unexpectedIPs != nil && !c.actUnprior {
+		_, ips = c.unexpectedIPs.FilterIPs(ips)
 		errors.LogDebug(context.Background(), "domain ", domain, " unexpectedIPs ", ips, " matched at server ", c.Name())
 		if len(ips) == 0 {
 			return nil, 0, dns.ErrEmptyResponse
 		}
 	}
 
-	if len(c.expectedIPs) > 0 && c.actPrior {
-		ipsNew := router.MatchIPs(c.expectedIPs, ips, false)
+	if c.expectedIPs != nil && c.actPrior {
+		ipsNew, _ := c.expectedIPs.FilterIPs(ips)
 		if len(ipsNew) > 0 {
 			ips = ipsNew
 			errors.LogDebug(context.Background(), "domain ", domain, " priorIPs ", ips, " matched at server ", c.Name())
 		}
 	}
 
-	if len(c.unexpectedIPs) > 0 && c.actUnprior {
-		ipsNew := router.MatchIPs(c.unexpectedIPs, ips, true)
+	if c.unexpectedIPs != nil && c.actUnprior {
+		_, ipsNew := c.unexpectedIPs.FilterIPs(ips)
 		if len(ipsNew) > 0 {
 			ips = ipsNew
 			errors.LogDebug(context.Background(), "domain ", domain, " unpriorIPs ", ips, " matched at server ", c.Name())

+ 30 - 39
app/router/condition.go

@@ -96,61 +96,53 @@ func (m *DomainMatcher) Apply(ctx routing.Context) bool {
 	return m.ApplyDomain(domain)
 }
 
-type MultiGeoIPMatcher struct {
-	matchers []*GeoIPMatcher
-	asType   string // local, source, target
-}
+type MatcherAsType byte
 
-func NewMultiGeoIPMatcher(geoips []*GeoIP, asType string) (*MultiGeoIPMatcher, error) {
-	var matchers []*GeoIPMatcher
-	for _, geoip := range geoips {
-		matcher, err := GlobalGeoIPContainer.Add(geoip)
-		if err != nil {
-			return nil, err
-		}
-		matchers = append(matchers, matcher)
-	}
+const (
+	MatcherAsType_Local MatcherAsType = iota
+	MatcherAsType_Source
+	MatcherAsType_Target
+	MatcherAsType_VlessRoute // for port
+)
 
-	matcher := &MultiGeoIPMatcher{
-		matchers: matchers,
-		asType:   asType,
-	}
+type IPMatcher struct {
+	matcher GeoIPMatcher
+	asType  MatcherAsType
+}
 
-	return matcher, nil
+func NewIPMatcher(geoips []*GeoIP, asType MatcherAsType) (*IPMatcher, error) {
+	matcher, err := BuildOptimizedGeoIPMatcher(geoips...)
+	if err != nil {
+		return nil, err
+	}
+	return &IPMatcher{matcher: matcher, asType: asType}, nil
 }
 
 // Apply implements Condition.
-func (m *MultiGeoIPMatcher) Apply(ctx routing.Context) bool {
+func (m *IPMatcher) Apply(ctx routing.Context) bool {
 	var ips []net.IP
 
 	switch m.asType {
-	case "local":
+	case MatcherAsType_Local:
 		ips = ctx.GetLocalIPs()
-	case "source":
+	case MatcherAsType_Source:
 		ips = ctx.GetSourceIPs()
-	case "target":
+	case MatcherAsType_Target:
 		ips = ctx.GetTargetIPs()
 	default:
-		panic("unreachable, asType should be local or source or target")
+		panic("unk asType")
 	}
 
-	for _, ip := range ips {
-		for _, matcher := range m.matchers {
-			if matcher.Match(ip) {
-				return true
-			}
-		}
-	}
-	return false
+	return m.matcher.AnyMatch(ips)
 }
 
 type PortMatcher struct {
 	port   net.MemoryPortList
-	asType string // local, source, target
+	asType MatcherAsType
 }
 
 // NewPortMatcher create a new port matcher that can match source or local or destination port
-func NewPortMatcher(list *net.PortList, asType string) *PortMatcher {
+func NewPortMatcher(list *net.PortList, asType MatcherAsType) *PortMatcher {
 	return &PortMatcher{
 		port:   net.PortListFromProto(list),
 		asType: asType,
@@ -160,18 +152,17 @@ func NewPortMatcher(list *net.PortList, asType string) *PortMatcher {
 // Apply implements Condition.
 func (v *PortMatcher) Apply(ctx routing.Context) bool {
 	switch v.asType {
-	case "local":
+	case MatcherAsType_Local:
 		return v.port.Contains(ctx.GetLocalPort())
-	case "source":
+	case MatcherAsType_Source:
 		return v.port.Contains(ctx.GetSourcePort())
-	case "target":
+	case MatcherAsType_Target:
 		return v.port.Contains(ctx.GetTargetPort())
-	case "vlessRoute":
+	case MatcherAsType_VlessRoute:
 		return v.port.Contains(ctx.GetVlessRoute())
 	default:
-		panic("unreachable, asType should be local or source or target")
+		panic("unk asType")
 	}
-
 }
 
 type NetworkMatcher struct {

+ 900 - 83
app/router/condition_geoip.go

@@ -1,144 +1,961 @@
 package router
 
 import (
+	"context"
 	"net/netip"
-	"strconv"
+	"sort"
+	"strings"
+	"sync"
 
+	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
+
 	"go4.org/netipx"
 )
 
-type GeoIPMatcher struct {
-	countryCode  string
-	reverseMatch bool
-	ip4          *netipx.IPSet
-	ip6          *netipx.IPSet
+type GeoIPMatcher interface {
+	// TODO: (PERF) all net.IP -> netipx.Addr
+
+	// Invalid IP always return false.
+	Match(ip net.IP) bool
+
+	// Returns true if *any* IP is valid and match.
+	AnyMatch(ips []net.IP) bool
+
+	// Returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false.
+	Matches(ips []net.IP) bool
+
+	// Filters IPs. Invalid IPs are silently dropped and not included in either result.
+	FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)
+
+	ToggleReverse()
+
+	SetReverse(reverse bool)
+}
+
+type GeoIPSet struct {
+	ipv4, ipv6 *netipx.IPSet
+	max4, max6 uint8
 }
 
-func (m *GeoIPMatcher) Init(cidrs []*CIDR) error {
-	var builder4, builder6 netipx.IPSetBuilder
+type HeuristicGeoIPMatcher struct {
+	ipset   *GeoIPSet
+	reverse bool
+}
+
+type ipBucket struct {
+	rep netip.Addr
+	ips []net.IP
+}
 
-	for _, cidr := range cidrs {
-		ip := net.IP(cidr.GetIp())
-		ipPrefixString := ip.String() + "/" + strconv.Itoa(int(cidr.GetPrefix()))
-		ipPrefix, err := netip.ParsePrefix(ipPrefixString)
-		if err != nil {
-			return err
+// Match implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool {
+	ipx, ok := netipx.FromStdIP(ip)
+	if !ok {
+		return false
+	}
+	return m.matchAddr(ipx)
+}
+
+func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool {
+	if ipx.Is4() {
+		return m.ipset.ipv4.Contains(ipx) != m.reverse
+	}
+	if ipx.Is6() {
+		return m.ipset.ipv6.Contains(ipx) != m.reverse
+	}
+	return false
+}
+
+// AnyMatch implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool {
+	n := len(ips)
+	if n == 0 {
+		return false
+	}
+
+	if n == 1 {
+		return m.Match(ips[0])
+	}
+
+	heur4 := m.ipset.max4 <= 24
+	heur6 := m.ipset.max6 <= 64
+	if !heur4 && !heur6 {
+		for _, ip := range ips {
+			if ipx, ok := netipx.FromStdIP(ip); ok {
+				if m.matchAddr(ipx) {
+					return true
+				}
+			}
 		}
+		return false
+	}
 
-		switch len(ip) {
-		case net.IPv4len:
-			builder4.AddPrefix(ipPrefix)
-		case net.IPv6len:
-			builder6.AddPrefix(ipPrefix)
+	buckets := make(map[[9]byte]struct{}, n)
+	for _, ip := range ips {
+		key, ok := prefixKeyFromIP(ip)
+		if !ok {
+			continue
+		}
+		heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6)
+		if heur {
+			if _, exists := buckets[key]; exists {
+				continue
+			}
+		}
+		ipx, ok := netipx.FromStdIP(ip)
+		if !ok {
+			continue
+		}
+		if m.matchAddr(ipx) {
+			return true
+		}
+		if heur {
+			buckets[key] = struct{}{}
 		}
 	}
+	return false
+}
 
-	if ip4, err := builder4.IPSet(); err != nil {
-		return err
-	} else {
-		m.ip4 = ip4
+// Matches implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool {
+	n := len(ips)
+	if n == 0 {
+		return false
 	}
 
-	if ip6, err := builder6.IPSet(); err != nil {
-		return err
-	} else {
-		m.ip6 = ip6
+	if n == 1 {
+		return m.Match(ips[0])
+	}
+
+	heur4 := m.ipset.max4 <= 24
+	heur6 := m.ipset.max6 <= 64
+	if !heur4 && !heur6 {
+		for _, ip := range ips {
+			ipx, ok := netipx.FromStdIP(ip)
+			if !ok {
+				return false
+			}
+			if !m.matchAddr(ipx) {
+				return false
+			}
+		}
+		return true
+	}
+
+	buckets := make(map[[9]byte]netip.Addr, n)
+	precise := make([]netip.Addr, 0, n)
+
+	for _, ip := range ips {
+		key, ok := prefixKeyFromIP(ip)
+		if !ok {
+			return false
+		}
+
+		if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
+			if _, exists := buckets[key]; !exists {
+				ipx, ok := netipx.FromStdIP(ip)
+				if !ok {
+					return false
+				}
+				buckets[key] = ipx
+			}
+		} else {
+			ipx, ok := netipx.FromStdIP(ip)
+			if !ok {
+				return false
+			}
+			precise = append(precise, ipx)
+		}
 	}
 
-	return nil
+	for _, ipx := range buckets {
+		if !m.matchAddr(ipx) {
+			return false
+		}
+	}
+	for _, ipx := range precise {
+		if !m.matchAddr(ipx) {
+			return false
+		}
+	}
+	return true
 }
 
-func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) {
-	m.reverseMatch = isReverseMatch
+func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
+	if ip4 := ip.To4(); ip4 != nil {
+		key[0] = 4
+		key[1] = ip4[0]
+		key[2] = ip4[1]
+		key[3] = ip4[2] // /24
+		return key, true
+	}
+	if ip16 := ip.To16(); ip16 != nil {
+		key[0] = 6
+		key[1] = ip16[0]
+		key[2] = ip16[1]
+		key[3] = ip16[2]
+		key[4] = ip16[3]
+		key[5] = ip16[4]
+		key[6] = ip16[5]
+		key[7] = ip16[6]
+		key[8] = ip16[7] // /64
+		return key, true
+	}
+	return key, false // illegal
 }
 
-func (m *GeoIPMatcher) match4(ip net.IP) bool {
-	nip, ok := netipx.FromStdIP(ip)
+// FilterIPs implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
+	n := len(ips)
+	if n == 0 {
+		return []net.IP{}, []net.IP{}
+	}
+
+	if n == 1 {
+		ipx, ok := netipx.FromStdIP(ips[0])
+		if !ok {
+			return []net.IP{}, []net.IP{}
+		}
+		if m.matchAddr(ipx) {
+			return ips, []net.IP{}
+		}
+		return []net.IP{}, ips
+	}
+
+	heur4 := m.ipset.max4 <= 24
+	heur6 := m.ipset.max6 <= 64
+	if !heur4 && !heur6 {
+		matched = make([]net.IP, 0, n)
+		unmatched = make([]net.IP, 0, n)
+		for _, ip := range ips {
+			ipx, ok := netipx.FromStdIP(ip)
+			if !ok {
+				continue // illegal ip, ignore
+			}
+			if m.matchAddr(ipx) {
+				matched = append(matched, ip)
+			} else {
+				unmatched = append(unmatched, ip)
+			}
+		}
+		return
+	}
+
+	buckets := make(map[[9]byte]*ipBucket, n)
+	precise := make([]net.IP, 0, n)
+
+	for _, ip := range ips {
+		key, ok := prefixKeyFromIP(ip)
+		if !ok {
+			continue // illegal ip, ignore
+		}
+
+		if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
+			precise = append(precise, ip)
+			continue
+		}
+
+		b, exists := buckets[key]
+		if !exists {
+			// build bucket
+			ipx, ok := netipx.FromStdIP(ip)
+			if !ok {
+				continue // illegal ip, ignore
+			}
+			b = &ipBucket{
+				rep: ipx,
+				ips: make([]net.IP, 0, 4), // for dns answer
+			}
+			buckets[key] = b
+		}
+		b.ips = append(b.ips, ip)
+	}
+
+	matched = make([]net.IP, 0, n)
+	unmatched = make([]net.IP, 0, n)
+	for _, b := range buckets {
+		if m.matchAddr(b.rep) {
+			matched = append(matched, b.ips...)
+		} else {
+			unmatched = append(unmatched, b.ips...)
+		}
+	}
+	for _, ip := range precise {
+		ipx, ok := netipx.FromStdIP(ip)
+		if !ok {
+			continue // illegal ip, ignore
+		}
+		if m.matchAddr(ipx) {
+			matched = append(matched, ip)
+		} else {
+			unmatched = append(unmatched, ip)
+		}
+	}
+	return
+}
+
+// ToggleReverse implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) ToggleReverse() {
+	m.reverse = !m.reverse
+}
+
+// SetReverse implements GeoIPMatcher.
+func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) {
+	m.reverse = reverse
+}
+
+type GeneralMultiGeoIPMatcher struct {
+	matchers []GeoIPMatcher
+}
+
+// Match implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool {
+	for _, m := range mm.matchers {
+		if m.Match(ip) {
+			return true
+		}
+	}
+	return false
+}
+
+// AnyMatch implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
+	for _, m := range mm.matchers {
+		if m.AnyMatch(ips) {
+			return true
+		}
+	}
+	return false
+}
+
+// Matches implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool {
+	for _, m := range mm.matchers {
+		if m.Matches(ips) {
+			return true
+		}
+	}
+	return false
+}
+
+// FilterIPs implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
+	matched = make([]net.IP, 0, len(ips))
+	unmatched = ips
+	for _, m := range mm.matchers {
+		if len(unmatched) == 0 {
+			break
+		}
+		var mtch []net.IP
+		mtch, unmatched = m.FilterIPs(unmatched)
+		if len(mtch) > 0 {
+			matched = append(matched, mtch...)
+		}
+	}
+	return
+}
+
+// ToggleReverse implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() {
+	for _, m := range mm.matchers {
+		m.ToggleReverse()
+	}
+}
+
+// SetReverse implements GeoIPMatcher.
+func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) {
+	for _, m := range mm.matchers {
+		m.SetReverse(reverse)
+	}
+}
+
+type HeuristicMultiGeoIPMatcher struct {
+	matchers []*HeuristicGeoIPMatcher
+}
+
+// Match implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool {
+	ipx, ok := netipx.FromStdIP(ip)
 	if !ok {
 		return false
 	}
 
-	return m.ip4.Contains(nip)
+	for _, m := range mm.matchers {
+		if m.matchAddr(ipx) {
+			return true
+		}
+	}
+	return false
 }
 
-func (m *GeoIPMatcher) match6(ip net.IP) bool {
-	nip, ok := netipx.FromStdIP(ip)
-	if !ok {
+// AnyMatch implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
+	n := len(ips)
+	if n == 0 {
 		return false
 	}
 
-	return m.ip6.Contains(nip)
+	if n == 1 {
+		return mm.Match(ips[0])
+	}
+
+	buckets := make(map[[9]byte]struct{}, n)
+	for _, ip := range ips {
+		var ipx netip.Addr
+		state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
+		for _, m := range mm.matchers {
+			heur4 := m.ipset.max4 <= 24
+			heur6 := m.ipset.max6 <= 64
+
+			if state == 0 && (heur4 || heur6) {
+				key, ok := prefixKeyFromIP(ip)
+				if !ok {
+					break
+				}
+				if _, exists := buckets[key]; exists {
+					state = key[0]
+				} else {
+					buckets[key] = struct{}{}
+					state = 1
+				}
+			}
+			if (heur4 && state == 4) || (heur6 && state == 6) {
+				continue
+			}
+
+			if !ipx.IsValid() {
+				nipx, ok := netipx.FromStdIP(ip)
+				if !ok {
+					break
+				}
+				ipx = nipx
+			}
+			if m.matchAddr(ipx) {
+				return true
+			}
+		}
+	}
+	return false
 }
 
-// Match returns true if the given ip is included by the GeoIP.
-func (m *GeoIPMatcher) Match(ip net.IP) bool {
-	isMatched := false
-	switch len(ip) {
-	case net.IPv4len:
-		isMatched = m.match4(ip)
-	case net.IPv6len:
-		isMatched = m.match6(ip)
+// Matches implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool {
+	n := len(ips)
+	if n == 0 {
+		return false
+	}
+
+	if n == 1 {
+		return mm.Match(ips[0])
 	}
-	if m.reverseMatch {
-		return !isMatched
+
+	var views ipViews
+	for _, m := range mm.matchers {
+		if !views.ensureForMatcher(m, ips) {
+			return false
+		}
+
+		matched := true
+		if m.ipset.max4 <= 24 {
+			for _, ipx := range views.buckets4 {
+				if !m.matchAddr(ipx) {
+					matched = false
+					break
+				}
+			}
+		} else {
+			for _, ipx := range views.precise4 {
+				if !m.matchAddr(ipx) {
+					matched = false
+					break
+				}
+			}
+		}
+		if !matched {
+			continue
+		}
+
+		if m.ipset.max6 <= 64 {
+			for _, ipx := range views.buckets6 {
+				if !m.matchAddr(ipx) {
+					matched = false
+					break
+				}
+			}
+		} else {
+			for _, ipx := range views.precise6 {
+				if !m.matchAddr(ipx) {
+					matched = false
+					break
+				}
+			}
+		}
+		if matched {
+			return true
+		}
 	}
-	return isMatched
+	return false
 }
 
-// GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code.
-type GeoIPMatcherContainer struct {
-	matchers []*GeoIPMatcher
+type ipViews struct {
+	buckets4, buckets6 map[[9]byte]netip.Addr
+	precise4, precise6 []netip.Addr
 }
 
-// Add adds a new GeoIP set into the container.
-// If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one.
-func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) {
-	if len(geoip.CountryCode) > 0 {
-		for _, m := range c.matchers {
-			if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch {
-				return m, nil
+func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool {
+	needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
+	needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
+	needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
+	needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
+
+	if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
+		return true
+	}
+
+	if needHeur4 {
+		v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
+	}
+	if needHeur6 {
+		v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
+	}
+	if needPrec4 {
+		v.precise4 = make([]netip.Addr, 0, len(ips))
+	}
+	if needPrec6 {
+		v.precise6 = make([]netip.Addr, 0, len(ips))
+	}
+
+	for _, ip := range ips {
+		key, ok := prefixKeyFromIP(ip)
+		if !ok {
+			return false
+		}
+
+		switch key[0] {
+		case 4:
+			var ipx netip.Addr
+			if needHeur4 {
+				if _, exists := v.buckets4[key]; !exists {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						return false
+					}
+					v.buckets4[key] = ipx
+				}
+			}
+			if needPrec4 {
+				if !ipx.IsValid() {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						return false
+					}
+				}
+				v.precise4 = append(v.precise4, ipx)
+			}
+		case 6:
+			var ipx netip.Addr
+			if needHeur6 {
+				if _, exists := v.buckets6[key]; !exists {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						return false
+					}
+					v.buckets6[key] = ipx
+				}
+			}
+			if needPrec6 {
+				if !ipx.IsValid() {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						return false
+					}
+				}
+				v.precise6 = append(v.precise6, ipx)
 			}
+		default:
+			return false
 		}
 	}
 
-	m := &GeoIPMatcher{
-		countryCode:  geoip.CountryCode,
-		reverseMatch: geoip.ReverseMatch,
+	return true
+}
+
+// FilterIPs implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
+	n := len(ips)
+	if n == 0 {
+		return []net.IP{}, []net.IP{}
 	}
-	if err := m.Init(geoip.Cidr); err != nil {
-		return nil, err
+
+	if n == 1 {
+		ipx, ok := netipx.FromStdIP(ips[0])
+		if !ok {
+			return []net.IP{}, []net.IP{}
+		}
+		for _, m := range mm.matchers {
+			if m.matchAddr(ipx) {
+				return ips, []net.IP{}
+			}
+		}
+		return []net.IP{}, ips
 	}
-	if len(geoip.CountryCode) > 0 {
-		c.matchers = append(c.matchers, m)
+
+	var views ipBucketViews
+
+	matched = make([]net.IP, 0, n)
+	for _, m := range mm.matchers {
+		views.ensureForMatcher(m, ips)
+
+		if m.ipset.max4 <= 24 {
+			for key, b := range views.buckets4 {
+				if b == nil {
+					continue
+				}
+				if m.matchAddr(b.rep) {
+					views.buckets4[key] = nil
+					matched = append(matched, b.ips...)
+				}
+			}
+		} else {
+			for ipx, ip := range views.precise4 {
+				if ip == nil {
+					continue
+				}
+				if m.matchAddr(ipx) {
+					views.precise4[ipx] = nil
+					matched = append(matched, ip)
+				}
+			}
+		}
+
+		if m.ipset.max6 <= 64 {
+			for key, b := range views.buckets6 {
+				if b == nil {
+					continue
+				}
+				if m.matchAddr(b.rep) {
+					views.buckets6[key] = nil
+					matched = append(matched, b.ips...)
+				}
+			}
+		} else {
+			for ipx, ip := range views.precise6 {
+				if ip == nil {
+					continue
+				}
+				if m.matchAddr(ipx) {
+					views.precise6[ipx] = nil
+					matched = append(matched, ip)
+				}
+			}
+		}
+	}
+
+	unmatched = make([]net.IP, 0, n-len(matched))
+	if views.buckets4 != nil {
+		for _, b := range views.buckets4 {
+			if b == nil {
+				continue
+			}
+			unmatched = append(unmatched, b.ips...)
+		}
+	}
+	if views.precise4 != nil {
+		for _, ip := range views.precise4 {
+			if ip == nil {
+				continue
+			}
+			unmatched = append(unmatched, ip)
+		}
+	}
+	if views.buckets6 != nil {
+		for _, b := range views.buckets6 {
+			if b == nil {
+				continue
+			}
+			unmatched = append(unmatched, b.ips...)
+		}
+	}
+	if views.precise6 != nil {
+		for _, ip := range views.precise6 {
+			if ip == nil {
+				continue
+			}
+			unmatched = append(unmatched, ip)
+		}
 	}
-	return m, nil
+
+	return
+}
+
+type ipBucketViews struct {
+	buckets4, buckets6 map[[9]byte]*ipBucket
+	precise4, precise6 map[netip.Addr]net.IP
 }
 
-var GlobalGeoIPContainer GeoIPMatcherContainer
+func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) {
+	needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
+	needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
+	needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
+	needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
 
-func MatchIPs(matchers []*GeoIPMatcher, ips []net.IP, reverse bool) []net.IP {
-	if len(matchers) == 0 {
-		panic("GeoIP matchers should not be empty to avoid ambiguity")
+	if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
+		return
 	}
-	newIPs := make([]net.IP, 0, len(ips))
-	var isFound bool
+
+	if needHeur4 {
+		v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
+	}
+	if needHeur6 {
+		v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
+	}
+	if needPrec4 {
+		v.precise4 = make(map[netip.Addr]net.IP, len(ips))
+	}
+	if needPrec6 {
+		v.precise6 = make(map[netip.Addr]net.IP, len(ips))
+	}
+
 	for _, ip := range ips {
-		isFound = false
-		for _, matcher := range matchers {
-			if matcher.Match(ip) {
-				isFound = true
-				break
+		key, ok := prefixKeyFromIP(ip)
+		if !ok {
+			continue // illegal ip, ignore
+		}
+
+		switch key[0] {
+		case 4:
+			var ipx netip.Addr
+			if needHeur4 {
+				b, exists := v.buckets4[key]
+				if !exists {
+					// build bucket
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						continue // illegal ip, ignore
+					}
+					b = &ipBucket{
+						rep: ipx,
+						ips: make([]net.IP, 0, 4), // for dns answer
+					}
+					v.buckets4[key] = b
+				}
+				b.ips = append(b.ips, ip)
+			}
+			if needPrec4 {
+				if !ipx.IsValid() {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						continue // illegal ip, ignore
+					}
+				}
+				v.precise4[ipx] = ip
+			}
+		case 6:
+			var ipx netip.Addr
+			if needHeur6 {
+				b, exists := v.buckets6[key]
+				if !exists {
+					// build bucket
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						continue // illegal ip, ignore
+					}
+					b = &ipBucket{
+						rep: ipx,
+						ips: make([]net.IP, 0, 4), // for dns answer
+					}
+					v.buckets6[key] = b
+				}
+				b.ips = append(b.ips, ip)
+			}
+			if needPrec6 {
+				if !ipx.IsValid() {
+					ipx, ok = netipx.FromStdIP(ip)
+					if !ok {
+						continue // illegal ip, ignore
+					}
+				}
+				v.precise6[ipx] = ip
 			}
 		}
-		if isFound && !reverse {
-			newIPs = append(newIPs, ip)
-			continue
+	}
+}
+
+// ToggleReverse implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() {
+	for _, m := range mm.matchers {
+		m.ToggleReverse()
+	}
+}
+
+// SetReverse implements GeoIPMatcher.
+func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) {
+	for _, m := range mm.matchers {
+		m.SetReverse(reverse)
+	}
+}
+
+type GeoIPSetFactory struct {
+	sync.Mutex
+	shared map[string]*GeoIPSet // TODO: cleanup
+}
+
+var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)}
+
+func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) {
+	f.Lock()
+	defer f.Unlock()
+
+	if ipset := f.shared[key]; ipset != nil {
+		return ipset, nil
+	}
+
+	ipset, err := f.Create(cidrGroups...)
+	if err == nil {
+		f.shared[key] = ipset
+	}
+	return ipset, err
+}
+
+func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
+	var ipv4Builder, ipv6Builder netipx.IPSetBuilder
+
+	for _, cidrGroup := range cidrGroups {
+		for _, cidrEntry := range cidrGroup {
+			ipBytes := cidrEntry.GetIp()
+			prefixLen := int(cidrEntry.GetPrefix())
+
+			addr, ok := netip.AddrFromSlice(ipBytes)
+			if !ok {
+				errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
+				continue
+			}
+
+			prefix := netip.PrefixFrom(addr, prefixLen)
+			if !prefix.IsValid() {
+				errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
+				continue
+			}
+
+			if addr.Is4() {
+				ipv4Builder.AddPrefix(prefix)
+			} else if addr.Is6() {
+				ipv6Builder.AddPrefix(prefix)
+			}
+		}
+	}
+
+	ipv4, err := ipv4Builder.IPSet()
+	if err != nil {
+		return nil, errors.New("failed to build IPv4 set").Base(err)
+	}
+	ipv6, err := ipv6Builder.IPSet()
+	if err != nil {
+		return nil, errors.New("failed to build IPv6 set").Base(err)
+	}
+
+	var max4, max6 int
+
+	for _, p := range ipv4.Prefixes() {
+		if b := p.Bits(); b > max4 {
+			max4 = b
+		}
+	}
+	for _, p := range ipv6.Prefixes() {
+		if b := p.Bits(); b > max6 {
+			max6 = b
 		}
-		if !isFound && reverse {
-			newIPs = append(newIPs, ip)
+	}
+
+	if max4 == 0 {
+		max4 = 0xff
+	}
+	if max6 == 0 {
+		max6 = 0xff
+	}
+
+	return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil
+}
+
+func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) {
+	n := len(geoips)
+	if n == 0 {
+		return nil, errors.New("no geoip configs provided")
+	}
+
+	var subs []*HeuristicGeoIPMatcher
+	pos := make([]*GeoIP, 0, n)
+	neg := make([]*GeoIP, 0, n/2)
+
+	for _, geoip := range geoips {
+		if geoip == nil {
+			return nil, errors.New("geoip entry is nil")
+		}
+		if geoip.CountryCode == "" {
+			ipset, err := ipsetFactory.Create(geoip.Cidr)
+			if err != nil {
+				return nil, err
+			}
+			subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch})
 			continue
 		}
+		if !geoip.ReverseMatch {
+			pos = append(pos, geoip)
+		} else {
+			neg = append(neg, geoip)
+		}
+	}
+
+	buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) {
+		n := len(mergeables)
+		if n == 0 {
+			return nil, nil
+		}
+
+		sort.Slice(mergeables, func(i, j int) bool {
+			gi, gj := mergeables[i], mergeables[j]
+			return gi.CountryCode < gj.CountryCode
+		})
+
+		var sb strings.Builder
+		sb.Grow(n * 3) // xx,
+		cidrGroups := make([][]*CIDR, 0, n)
+		var last *GeoIP
+		for i, geoip := range mergeables {
+			if i == 0 || (geoip.CountryCode != last.CountryCode) {
+				last = geoip
+				sb.WriteString(geoip.CountryCode)
+				sb.WriteString(",")
+				cidrGroups = append(cidrGroups, geoip.Cidr)
+			}
+		}
+
+		return ipsetFactory.GetOrCreate(sb.String(), cidrGroups)
+	}
+
+	ipset, err := buildIPSet(pos)
+	if err != nil {
+		return nil, err
+	}
+	if ipset != nil {
+		subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false})
+	}
+
+	ipset, err = buildIPSet(neg)
+	if err != nil {
+		return nil, err
+	}
+	if ipset != nil {
+		subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true})
+	}
+
+	switch len(subs) {
+	case 0:
+		return nil, errors.New("no valid geoip matcher")
+	case 1:
+		return subs[0], nil
+	default:
+		return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil
 	}
-	return newIPs
 }

+ 29 - 42
app/router/condition_geoip_test.go

@@ -35,33 +35,6 @@ func getAssetPath(file string) (string, error) {
 	return path, nil
 }
 
-func TestGeoIPMatcherContainer(t *testing.T) {
-	container := &router.GeoIPMatcherContainer{}
-
-	m1, err := container.Add(&router.GeoIP{
-		CountryCode: "CN",
-	})
-	common.Must(err)
-
-	m2, err := container.Add(&router.GeoIP{
-		CountryCode: "US",
-	})
-	common.Must(err)
-
-	m3, err := container.Add(&router.GeoIP{
-		CountryCode: "CN",
-	})
-	common.Must(err)
-
-	if m1 != m3 {
-		t.Error("expect same matcher for same geoip, but not")
-	}
-
-	if m1 == m2 {
-		t.Error("expect different matcher for different geoip, but actually same")
-	}
-}
-
 func TestGeoIPMatcher(t *testing.T) {
 	cidrList := []*router.CIDR{
 		{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -80,8 +53,10 @@ func TestGeoIPMatcher(t *testing.T) {
 		{Ip: []byte{91, 108, 4, 0}, Prefix: 16},
 	}
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(cidrList))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: cidrList,
+	})
+	common.Must(err)
 
 	testCases := []struct {
 		Input  string
@@ -140,8 +115,10 @@ func TestGeoIPMatcherRegression(t *testing.T) {
 		{Ip: []byte{98, 108, 20, 0}, Prefix: 23},
 	}
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(cidrList))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: cidrList,
+	})
+	common.Must(err)
 
 	testCases := []struct {
 		Input  string
@@ -171,9 +148,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
 		{Ip: []byte{8, 8, 8, 8}, Prefix: 32},
 		{Ip: []byte{91, 108, 4, 0}, Prefix: 16},
 	}
-	matcher := &router.GeoIPMatcher{}
-	matcher.SetReverseMatch(true) // Reverse match
-	common.Must(matcher.Init(cidrList))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: cidrList,
+	})
+	common.Must(err)
+	matcher.SetReverse(true) // Reverse match
 
 	testCases := []struct {
 		Input  string
@@ -206,8 +185,10 @@ func TestGeoIPMatcher4CN(t *testing.T) {
 	ips, err := loadGeoIP("CN")
 	common.Must(err)
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(ips))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: ips,
+	})
+	common.Must(err)
 
 	if matcher.Match([]byte{8, 8, 8, 8}) {
 		t.Error("expect CN geoip doesn't contain 8.8.8.8, but actually does")
@@ -218,8 +199,10 @@ func TestGeoIPMatcher6US(t *testing.T) {
 	ips, err := loadGeoIP("US")
 	common.Must(err)
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(ips))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: ips,
+	})
+	common.Must(err)
 
 	if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
 		t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not")
@@ -254,8 +237,10 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
 	ips, err := loadGeoIP("CN")
 	common.Must(err)
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(ips))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: ips,
+	})
+	common.Must(err)
 
 	b.ResetTimer()
 
@@ -268,8 +253,10 @@ func BenchmarkGeoIPMatcher6US(b *testing.B) {
 	ips, err := loadGeoIP("US")
 	common.Must(err)
 
-	matcher := &router.GeoIPMatcher{}
-	common.Must(matcher.Init(ips))
+	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
+		Cidr: ips,
+	})
+	common.Must(err)
 
 	b.ResetTimer()
 

+ 1 - 1
app/router/condition_test.go

@@ -447,7 +447,7 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
 		})
 	}
 
-	matcher, err := NewMultiGeoIPMatcher(geoips, "target")
+	matcher, err := NewIPMatcher(geoips, MatcherAsType_Target)
 	common.Must(err)
 
 	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("8.8.8.8"), 80)})

+ 7 - 7
app/router/config.go

@@ -46,7 +46,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if rr.VlessRouteList != nil {
-		conds.Add(NewPortMatcher(rr.VlessRouteList, "vlessRoute"))
+		conds.Add(NewPortMatcher(rr.VlessRouteList, MatcherAsType_VlessRoute))
 	}
 
 	if len(rr.InboundTag) > 0 {
@@ -54,15 +54,15 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if rr.PortList != nil {
-		conds.Add(NewPortMatcher(rr.PortList, "target"))
+		conds.Add(NewPortMatcher(rr.PortList, MatcherAsType_Target))
 	}
 
 	if rr.SourcePortList != nil {
-		conds.Add(NewPortMatcher(rr.SourcePortList, "source"))
+		conds.Add(NewPortMatcher(rr.SourcePortList, MatcherAsType_Source))
 	}
 
 	if rr.LocalPortList != nil {
-		conds.Add(NewPortMatcher(rr.LocalPortList, "local"))
+		conds.Add(NewPortMatcher(rr.LocalPortList, MatcherAsType_Local))
 	}
 
 	if len(rr.Networks) > 0 {
@@ -70,7 +70,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if len(rr.Geoip) > 0 {
-		cond, err := NewMultiGeoIPMatcher(rr.Geoip, "target")
+		cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target)
 		if err != nil {
 			return nil, err
 		}
@@ -78,7 +78,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if len(rr.SourceGeoip) > 0 {
-		cond, err := NewMultiGeoIPMatcher(rr.SourceGeoip, "source")
+		cond, err := NewIPMatcher(rr.SourceGeoip, MatcherAsType_Source)
 		if err != nil {
 			return nil, err
 		}
@@ -86,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if len(rr.LocalGeoip) > 0 {
-		cond, err := NewMultiGeoIPMatcher(rr.LocalGeoip, "local")
+		cond, err := NewIPMatcher(rr.LocalGeoip, MatcherAsType_Local)
 		if err != nil {
 			return nil, err
 		}