| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961 |
- package router
- import (
- "context"
- "net/netip"
- "sort"
- "strings"
- "sync"
- "github.com/xtls/xray-core/common/errors"
- "github.com/xtls/xray-core/common/net"
- "go4.org/netipx"
- )
// GeoIPMatcher tests IP addresses against a set of IP ranges.
//
// TODO: (PERF) all net.IP -> netipx.Addr
type GeoIPMatcher interface {
	// Match reports whether ip matches. An invalid IP never matches.
	Match(ip net.IP) bool

	// AnyMatch reports whether at least one IP in ips is valid and matches.
	AnyMatch(ips []net.IP) bool

	// Matches reports whether every IP in ips is valid and matches. A single
	// invalid IP, or a single valid IP that does not match, yields false.
	Matches(ips []net.IP) bool

	// FilterIPs partitions ips into matched and unmatched IPs. Invalid IPs are
	// silently dropped and appear in neither result.
	FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)

	// ToggleReverse flips the matcher's reverse setting.
	ToggleReverse()

	// SetReverse sets the matcher's reverse setting to reverse.
	SetReverse(reverse bool)
}
- type GeoIPSet struct {
- ipv4, ipv6 *netipx.IPSet
- max4, max6 uint8
- }
- type HeuristicGeoIPMatcher struct {
- ipset *GeoIPSet
- reverse bool
- }
- type ipBucket struct {
- rep netip.Addr
- ips []net.IP
- }
- // Match implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- return m.matchAddr(ipx)
- }
- func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool {
- if ipx.Is4() {
- return m.ipset.ipv4.Contains(ipx) != m.reverse
- }
- if ipx.Is6() {
- return m.ipset.ipv6.Contains(ipx) != m.reverse
- }
- return false
- }
- // AnyMatch implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool {
- n := len(ips)
- if n == 0 {
- return false
- }
- if n == 1 {
- return m.Match(ips[0])
- }
- heur4 := m.ipset.max4 <= 24
- heur6 := m.ipset.max6 <= 64
- if !heur4 && !heur6 {
- for _, ip := range ips {
- if ipx, ok := netipx.FromStdIP(ip); ok {
- if m.matchAddr(ipx) {
- return true
- }
- }
- }
- return false
- }
- buckets := make(map[[9]byte]struct{}, n)
- for _, ip := range ips {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- continue
- }
- heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6)
- if heur {
- if _, exists := buckets[key]; exists {
- continue
- }
- }
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- continue
- }
- if m.matchAddr(ipx) {
- return true
- }
- if heur {
- buckets[key] = struct{}{}
- }
- }
- return false
- }
- // Matches implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool {
- n := len(ips)
- if n == 0 {
- return false
- }
- if n == 1 {
- return m.Match(ips[0])
- }
- heur4 := m.ipset.max4 <= 24
- heur6 := m.ipset.max6 <= 64
- if !heur4 && !heur6 {
- for _, ip := range ips {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- if !m.matchAddr(ipx) {
- return false
- }
- }
- return true
- }
- buckets := make(map[[9]byte]netip.Addr, n)
- precise := make([]netip.Addr, 0, n)
- for _, ip := range ips {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- return false
- }
- if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
- if _, exists := buckets[key]; !exists {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- buckets[key] = ipx
- }
- } else {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- precise = append(precise, ipx)
- }
- }
- for _, ipx := range buckets {
- if !m.matchAddr(ipx) {
- return false
- }
- }
- for _, ipx := range precise {
- if !m.matchAddr(ipx) {
- return false
- }
- }
- return true
- }
// prefixKeyFromIP returns the bucket key used by the heuristic matchers.
// Byte 0 holds the address family (4 or 6). It is followed by the first 3
// bytes of an IPv4 address (its /24) or the first 8 bytes of an IPv6 address
// (its /64); unused trailing bytes stay zero. IPv4-mapped IPv6 addresses are
// keyed as IPv4 because net.IP.To4 accepts them. ok is false, and key is all
// zeros, when ip is neither 4 nor 16 bytes long.
func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
	if v4 := ip.To4(); v4 != nil {
		key[0] = 4
		copy(key[1:], v4[:3]) // /24
		return key, true
	}
	if v6 := ip.To16(); v6 != nil {
		key[0] = 6
		copy(key[1:], v6[:8]) // /64
		return key, true
	}
	return key, false // illegal
}
- // FilterIPs implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
- n := len(ips)
- if n == 0 {
- return []net.IP{}, []net.IP{}
- }
- if n == 1 {
- ipx, ok := netipx.FromStdIP(ips[0])
- if !ok {
- return []net.IP{}, []net.IP{}
- }
- if m.matchAddr(ipx) {
- return ips, []net.IP{}
- }
- return []net.IP{}, ips
- }
- heur4 := m.ipset.max4 <= 24
- heur6 := m.ipset.max6 <= 64
- if !heur4 && !heur6 {
- matched = make([]net.IP, 0, n)
- unmatched = make([]net.IP, 0, n)
- for _, ip := range ips {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- if m.matchAddr(ipx) {
- matched = append(matched, ip)
- } else {
- unmatched = append(unmatched, ip)
- }
- }
- return
- }
- buckets := make(map[[9]byte]*ipBucket, n)
- precise := make([]net.IP, 0, n)
- for _, ip := range ips {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
- precise = append(precise, ip)
- continue
- }
- b, exists := buckets[key]
- if !exists {
- // build bucket
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- b = &ipBucket{
- rep: ipx,
- ips: make([]net.IP, 0, 4), // for dns answer
- }
- buckets[key] = b
- }
- b.ips = append(b.ips, ip)
- }
- matched = make([]net.IP, 0, n)
- unmatched = make([]net.IP, 0, n)
- for _, b := range buckets {
- if m.matchAddr(b.rep) {
- matched = append(matched, b.ips...)
- } else {
- unmatched = append(unmatched, b.ips...)
- }
- }
- for _, ip := range precise {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- if m.matchAddr(ipx) {
- matched = append(matched, ip)
- } else {
- unmatched = append(unmatched, ip)
- }
- }
- return
- }
- // ToggleReverse implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) ToggleReverse() {
- m.reverse = !m.reverse
- }
- // SetReverse implements GeoIPMatcher.
- func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) {
- m.reverse = reverse
- }
- type GeneralMultiGeoIPMatcher struct {
- matchers []GeoIPMatcher
- }
- // Match implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool {
- for _, m := range mm.matchers {
- if m.Match(ip) {
- return true
- }
- }
- return false
- }
- // AnyMatch implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
- for _, m := range mm.matchers {
- if m.AnyMatch(ips) {
- return true
- }
- }
- return false
- }
- // Matches implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool {
- for _, m := range mm.matchers {
- if m.Matches(ips) {
- return true
- }
- }
- return false
- }
- // FilterIPs implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
- matched = make([]net.IP, 0, len(ips))
- unmatched = ips
- for _, m := range mm.matchers {
- if len(unmatched) == 0 {
- break
- }
- var mtch []net.IP
- mtch, unmatched = m.FilterIPs(unmatched)
- if len(mtch) > 0 {
- matched = append(matched, mtch...)
- }
- }
- return
- }
- // ToggleReverse implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() {
- for _, m := range mm.matchers {
- m.ToggleReverse()
- }
- }
- // SetReverse implements GeoIPMatcher.
- func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) {
- for _, m := range mm.matchers {
- m.SetReverse(reverse)
- }
- }
- type HeuristicMultiGeoIPMatcher struct {
- matchers []*HeuristicGeoIPMatcher
- }
- // Match implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool {
- ipx, ok := netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- for _, m := range mm.matchers {
- if m.matchAddr(ipx) {
- return true
- }
- }
- return false
- }
- // AnyMatch implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
- n := len(ips)
- if n == 0 {
- return false
- }
- if n == 1 {
- return mm.Match(ips[0])
- }
- buckets := make(map[[9]byte]struct{}, n)
- for _, ip := range ips {
- var ipx netip.Addr
- state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
- for _, m := range mm.matchers {
- heur4 := m.ipset.max4 <= 24
- heur6 := m.ipset.max6 <= 64
- if state == 0 && (heur4 || heur6) {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- break
- }
- if _, exists := buckets[key]; exists {
- state = key[0]
- } else {
- buckets[key] = struct{}{}
- state = 1
- }
- }
- if (heur4 && state == 4) || (heur6 && state == 6) {
- continue
- }
- if !ipx.IsValid() {
- nipx, ok := netipx.FromStdIP(ip)
- if !ok {
- break
- }
- ipx = nipx
- }
- if m.matchAddr(ipx) {
- return true
- }
- }
- }
- return false
- }
- // Matches implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool {
- n := len(ips)
- if n == 0 {
- return false
- }
- if n == 1 {
- return mm.Match(ips[0])
- }
- var views ipViews
- for _, m := range mm.matchers {
- if !views.ensureForMatcher(m, ips) {
- return false
- }
- matched := true
- if m.ipset.max4 <= 24 {
- for _, ipx := range views.buckets4 {
- if !m.matchAddr(ipx) {
- matched = false
- break
- }
- }
- } else {
- for _, ipx := range views.precise4 {
- if !m.matchAddr(ipx) {
- matched = false
- break
- }
- }
- }
- if !matched {
- continue
- }
- if m.ipset.max6 <= 64 {
- for _, ipx := range views.buckets6 {
- if !m.matchAddr(ipx) {
- matched = false
- break
- }
- }
- } else {
- for _, ipx := range views.precise6 {
- if !m.matchAddr(ipx) {
- matched = false
- break
- }
- }
- }
- if matched {
- return true
- }
- }
- return false
- }
- type ipViews struct {
- buckets4, buckets6 map[[9]byte]netip.Addr
- precise4, precise6 []netip.Addr
- }
- func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool {
- needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
- needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
- needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
- needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
- if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
- return true
- }
- if needHeur4 {
- v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
- }
- if needHeur6 {
- v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
- }
- if needPrec4 {
- v.precise4 = make([]netip.Addr, 0, len(ips))
- }
- if needPrec6 {
- v.precise6 = make([]netip.Addr, 0, len(ips))
- }
- for _, ip := range ips {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- return false
- }
- switch key[0] {
- case 4:
- var ipx netip.Addr
- if needHeur4 {
- if _, exists := v.buckets4[key]; !exists {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- v.buckets4[key] = ipx
- }
- }
- if needPrec4 {
- if !ipx.IsValid() {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- }
- v.precise4 = append(v.precise4, ipx)
- }
- case 6:
- var ipx netip.Addr
- if needHeur6 {
- if _, exists := v.buckets6[key]; !exists {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- v.buckets6[key] = ipx
- }
- }
- if needPrec6 {
- if !ipx.IsValid() {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- return false
- }
- }
- v.precise6 = append(v.precise6, ipx)
- }
- default:
- return false
- }
- }
- return true
- }
- // FilterIPs implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
- n := len(ips)
- if n == 0 {
- return []net.IP{}, []net.IP{}
- }
- if n == 1 {
- ipx, ok := netipx.FromStdIP(ips[0])
- if !ok {
- return []net.IP{}, []net.IP{}
- }
- for _, m := range mm.matchers {
- if m.matchAddr(ipx) {
- return ips, []net.IP{}
- }
- }
- return []net.IP{}, ips
- }
- var views ipBucketViews
- matched = make([]net.IP, 0, n)
- for _, m := range mm.matchers {
- views.ensureForMatcher(m, ips)
- if m.ipset.max4 <= 24 {
- for key, b := range views.buckets4 {
- if b == nil {
- continue
- }
- if m.matchAddr(b.rep) {
- views.buckets4[key] = nil
- matched = append(matched, b.ips...)
- }
- }
- } else {
- for ipx, ip := range views.precise4 {
- if ip == nil {
- continue
- }
- if m.matchAddr(ipx) {
- views.precise4[ipx] = nil
- matched = append(matched, ip)
- }
- }
- }
- if m.ipset.max6 <= 64 {
- for key, b := range views.buckets6 {
- if b == nil {
- continue
- }
- if m.matchAddr(b.rep) {
- views.buckets6[key] = nil
- matched = append(matched, b.ips...)
- }
- }
- } else {
- for ipx, ip := range views.precise6 {
- if ip == nil {
- continue
- }
- if m.matchAddr(ipx) {
- views.precise6[ipx] = nil
- matched = append(matched, ip)
- }
- }
- }
- }
- unmatched = make([]net.IP, 0, n-len(matched))
- if views.buckets4 != nil {
- for _, b := range views.buckets4 {
- if b == nil {
- continue
- }
- unmatched = append(unmatched, b.ips...)
- }
- }
- if views.precise4 != nil {
- for _, ip := range views.precise4 {
- if ip == nil {
- continue
- }
- unmatched = append(unmatched, ip)
- }
- }
- if views.buckets6 != nil {
- for _, b := range views.buckets6 {
- if b == nil {
- continue
- }
- unmatched = append(unmatched, b.ips...)
- }
- }
- if views.precise6 != nil {
- for _, ip := range views.precise6 {
- if ip == nil {
- continue
- }
- unmatched = append(unmatched, ip)
- }
- }
- return
- }
- type ipBucketViews struct {
- buckets4, buckets6 map[[9]byte]*ipBucket
- precise4, precise6 map[netip.Addr]net.IP
- }
- func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) {
- needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
- needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
- needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
- needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
- if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
- return
- }
- if needHeur4 {
- v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
- }
- if needHeur6 {
- v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
- }
- if needPrec4 {
- v.precise4 = make(map[netip.Addr]net.IP, len(ips))
- }
- if needPrec6 {
- v.precise6 = make(map[netip.Addr]net.IP, len(ips))
- }
- for _, ip := range ips {
- key, ok := prefixKeyFromIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- switch key[0] {
- case 4:
- var ipx netip.Addr
- if needHeur4 {
- b, exists := v.buckets4[key]
- if !exists {
- // build bucket
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- b = &ipBucket{
- rep: ipx,
- ips: make([]net.IP, 0, 4), // for dns answer
- }
- v.buckets4[key] = b
- }
- b.ips = append(b.ips, ip)
- }
- if needPrec4 {
- if !ipx.IsValid() {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- }
- v.precise4[ipx] = ip
- }
- case 6:
- var ipx netip.Addr
- if needHeur6 {
- b, exists := v.buckets6[key]
- if !exists {
- // build bucket
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- b = &ipBucket{
- rep: ipx,
- ips: make([]net.IP, 0, 4), // for dns answer
- }
- v.buckets6[key] = b
- }
- b.ips = append(b.ips, ip)
- }
- if needPrec6 {
- if !ipx.IsValid() {
- ipx, ok = netipx.FromStdIP(ip)
- if !ok {
- continue // illegal ip, ignore
- }
- }
- v.precise6[ipx] = ip
- }
- }
- }
- }
- // ToggleReverse implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() {
- for _, m := range mm.matchers {
- m.ToggleReverse()
- }
- }
- // SetReverse implements GeoIPMatcher.
- func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) {
- for _, m := range mm.matchers {
- m.SetReverse(reverse)
- }
- }
- type GeoIPSetFactory struct {
- sync.Mutex
- shared map[string]*GeoIPSet // TODO: cleanup
- }
- var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)}
- func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) {
- f.Lock()
- defer f.Unlock()
- if ipset := f.shared[key]; ipset != nil {
- return ipset, nil
- }
- ipset, err := f.Create(cidrGroups...)
- if err == nil {
- f.shared[key] = ipset
- }
- return ipset, err
- }
- func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
- var ipv4Builder, ipv6Builder netipx.IPSetBuilder
- for _, cidrGroup := range cidrGroups {
- for _, cidrEntry := range cidrGroup {
- ipBytes := cidrEntry.GetIp()
- prefixLen := int(cidrEntry.GetPrefix())
- addr, ok := netip.AddrFromSlice(ipBytes)
- if !ok {
- errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
- continue
- }
- prefix := netip.PrefixFrom(addr, prefixLen)
- if !prefix.IsValid() {
- errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
- continue
- }
- if addr.Is4() {
- ipv4Builder.AddPrefix(prefix)
- } else if addr.Is6() {
- ipv6Builder.AddPrefix(prefix)
- }
- }
- }
- ipv4, err := ipv4Builder.IPSet()
- if err != nil {
- return nil, errors.New("failed to build IPv4 set").Base(err)
- }
- ipv6, err := ipv6Builder.IPSet()
- if err != nil {
- return nil, errors.New("failed to build IPv6 set").Base(err)
- }
- var max4, max6 int
- for _, p := range ipv4.Prefixes() {
- if b := p.Bits(); b > max4 {
- max4 = b
- }
- }
- for _, p := range ipv6.Prefixes() {
- if b := p.Bits(); b > max6 {
- max6 = b
- }
- }
- if max4 == 0 {
- max4 = 0xff
- }
- if max6 == 0 {
- max6 = 0xff
- }
- return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil
- }
- func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) {
- n := len(geoips)
- if n == 0 {
- return nil, errors.New("no geoip configs provided")
- }
- var subs []*HeuristicGeoIPMatcher
- pos := make([]*GeoIP, 0, n)
- neg := make([]*GeoIP, 0, n/2)
- for _, geoip := range geoips {
- if geoip == nil {
- return nil, errors.New("geoip entry is nil")
- }
- if geoip.CountryCode == "" {
- ipset, err := ipsetFactory.Create(geoip.Cidr)
- if err != nil {
- return nil, err
- }
- subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch})
- continue
- }
- if !geoip.ReverseMatch {
- pos = append(pos, geoip)
- } else {
- neg = append(neg, geoip)
- }
- }
- buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) {
- n := len(mergeables)
- if n == 0 {
- return nil, nil
- }
- sort.Slice(mergeables, func(i, j int) bool {
- gi, gj := mergeables[i], mergeables[j]
- return gi.CountryCode < gj.CountryCode
- })
- var sb strings.Builder
- sb.Grow(n * 3) // xx,
- cidrGroups := make([][]*CIDR, 0, n)
- var last *GeoIP
- for i, geoip := range mergeables {
- if i == 0 || (geoip.CountryCode != last.CountryCode) {
- last = geoip
- sb.WriteString(geoip.CountryCode)
- sb.WriteString(",")
- cidrGroups = append(cidrGroups, geoip.Cidr)
- }
- }
- return ipsetFactory.GetOrCreate(sb.String(), cidrGroups)
- }
- ipset, err := buildIPSet(pos)
- if err != nil {
- return nil, err
- }
- if ipset != nil {
- subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false})
- }
- ipset, err = buildIPSet(neg)
- if err != nil {
- return nil, err
- }
- if ipset != nil {
- subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true})
- }
- switch len(subs) {
- case 0:
- return nil, errors.New("no valid geoip matcher")
- case 1:
- return subs[0], nil
- default:
- return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil
- }
- }
|