|
|
@@ -1,144 +1,961 @@
|
|
|
package router
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"net/netip"
|
|
|
- "strconv"
|
|
|
+ "sort"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
|
|
|
+ "github.com/xtls/xray-core/common/errors"
|
|
|
"github.com/xtls/xray-core/common/net"
|
|
|
+
|
|
|
"go4.org/netipx"
|
|
|
)
|
|
|
|
|
|
-type GeoIPMatcher struct {
|
|
|
- countryCode string
|
|
|
- reverseMatch bool
|
|
|
- ip4 *netipx.IPSet
|
|
|
- ip6 *netipx.IPSet
|
|
|
+type GeoIPMatcher interface {
|
|
|
+ // TODO: (PERF) all net.IP -> netipx.Addr
|
|
|
+
|
|
|
+ // Invalid IP always return false.
|
|
|
+ Match(ip net.IP) bool
|
|
|
+
|
|
|
+ // Returns true if *any* IP is valid and match.
|
|
|
+ AnyMatch(ips []net.IP) bool
|
|
|
+
|
|
|
+ // Returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false.
|
|
|
+ Matches(ips []net.IP) bool
|
|
|
+
|
|
|
+ // Filters IPs. Invalid IPs are silently dropped and not included in either result.
|
|
|
+ FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)
|
|
|
+
|
|
|
+ ToggleReverse()
|
|
|
+
|
|
|
+ SetReverse(reverse bool)
|
|
|
+}
|
|
|
+
|
|
|
+type GeoIPSet struct {
|
|
|
+ ipv4, ipv6 *netipx.IPSet
|
|
|
+ max4, max6 uint8
|
|
|
}
|
|
|
|
|
|
-func (m *GeoIPMatcher) Init(cidrs []*CIDR) error {
|
|
|
- var builder4, builder6 netipx.IPSetBuilder
|
|
|
+type HeuristicGeoIPMatcher struct {
|
|
|
+ ipset *GeoIPSet
|
|
|
+ reverse bool
|
|
|
+}
|
|
|
+
|
|
|
+type ipBucket struct {
|
|
|
+ rep netip.Addr
|
|
|
+ ips []net.IP
|
|
|
+}
|
|
|
|
|
|
- for _, cidr := range cidrs {
|
|
|
- ip := net.IP(cidr.GetIp())
|
|
|
- ipPrefixString := ip.String() + "/" + strconv.Itoa(int(cidr.GetPrefix()))
|
|
|
- ipPrefix, err := netip.ParsePrefix(ipPrefixString)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+// Match implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return m.matchAddr(ipx)
|
|
|
+}
|
|
|
+
|
|
|
+func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool {
|
|
|
+ if ipx.Is4() {
|
|
|
+ return m.ipset.ipv4.Contains(ipx) != m.reverse
|
|
|
+ }
|
|
|
+ if ipx.Is6() {
|
|
|
+ return m.ipset.ipv6.Contains(ipx) != m.reverse
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+// AnyMatch implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if n == 1 {
|
|
|
+ return m.Match(ips[0])
|
|
|
+ }
|
|
|
+
|
|
|
+ heur4 := m.ipset.max4 <= 24
|
|
|
+ heur6 := m.ipset.max6 <= 64
|
|
|
+ if !heur4 && !heur6 {
|
|
|
+ for _, ip := range ips {
|
|
|
+ if ipx, ok := netipx.FromStdIP(ip); ok {
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
+ return false
|
|
|
+ }
|
|
|
|
|
|
- switch len(ip) {
|
|
|
- case net.IPv4len:
|
|
|
- builder4.AddPrefix(ipPrefix)
|
|
|
- case net.IPv6len:
|
|
|
- builder6.AddPrefix(ipPrefix)
|
|
|
+ buckets := make(map[[9]byte]struct{}, n)
|
|
|
+ for _, ip := range ips {
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ heur := (key[0] == 4 && heur4) || (key[0] == 6 && heur6)
|
|
|
+ if heur {
|
|
|
+ if _, exists := buckets[key]; exists {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ if heur {
|
|
|
+ buckets[key] = struct{}{}
|
|
|
}
|
|
|
}
|
|
|
+ return false
|
|
|
+}
|
|
|
|
|
|
- if ip4, err := builder4.IPSet(); err != nil {
|
|
|
- return err
|
|
|
- } else {
|
|
|
- m.ip4 = ip4
|
|
|
+// Matches implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
+ return false
|
|
|
}
|
|
|
|
|
|
- if ip6, err := builder6.IPSet(); err != nil {
|
|
|
- return err
|
|
|
- } else {
|
|
|
- m.ip6 = ip6
|
|
|
+ if n == 1 {
|
|
|
+ return m.Match(ips[0])
|
|
|
+ }
|
|
|
+
|
|
|
+ heur4 := m.ipset.max4 <= 24
|
|
|
+ heur6 := m.ipset.max6 <= 64
|
|
|
+ if !heur4 && !heur6 {
|
|
|
+ for _, ip := range ips {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ buckets := make(map[[9]byte]netip.Addr, n)
|
|
|
+ precise := make([]netip.Addr, 0, n)
|
|
|
+
|
|
|
+ for _, ip := range ips {
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
|
|
|
+ if _, exists := buckets[key]; !exists {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ buckets[key] = ipx
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ precise = append(precise, ipx)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- return nil
|
|
|
+ for _, ipx := range buckets {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for _, ipx := range precise {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true
|
|
|
}
|
|
|
|
|
|
-func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) {
|
|
|
- m.reverseMatch = isReverseMatch
|
|
|
+func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
|
|
|
+ if ip4 := ip.To4(); ip4 != nil {
|
|
|
+ key[0] = 4
|
|
|
+ key[1] = ip4[0]
|
|
|
+ key[2] = ip4[1]
|
|
|
+ key[3] = ip4[2] // /24
|
|
|
+ return key, true
|
|
|
+ }
|
|
|
+ if ip16 := ip.To16(); ip16 != nil {
|
|
|
+ key[0] = 6
|
|
|
+ key[1] = ip16[0]
|
|
|
+ key[2] = ip16[1]
|
|
|
+ key[3] = ip16[2]
|
|
|
+ key[4] = ip16[3]
|
|
|
+ key[5] = ip16[4]
|
|
|
+ key[6] = ip16[5]
|
|
|
+ key[7] = ip16[6]
|
|
|
+ key[8] = ip16[7] // /64
|
|
|
+ return key, true
|
|
|
+ }
|
|
|
+ return key, false // illegal
|
|
|
}
|
|
|
|
|
|
-func (m *GeoIPMatcher) match4(ip net.IP) bool {
|
|
|
- nip, ok := netipx.FromStdIP(ip)
|
|
|
+// FilterIPs implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
+ return []net.IP{}, []net.IP{}
|
|
|
+ }
|
|
|
+
|
|
|
+ if n == 1 {
|
|
|
+ ipx, ok := netipx.FromStdIP(ips[0])
|
|
|
+ if !ok {
|
|
|
+ return []net.IP{}, []net.IP{}
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return ips, []net.IP{}
|
|
|
+ }
|
|
|
+ return []net.IP{}, ips
|
|
|
+ }
|
|
|
+
|
|
|
+ heur4 := m.ipset.max4 <= 24
|
|
|
+ heur6 := m.ipset.max6 <= 64
|
|
|
+ if !heur4 && !heur6 {
|
|
|
+ matched = make([]net.IP, 0, n)
|
|
|
+ unmatched = make([]net.IP, 0, n)
|
|
|
+ for _, ip := range ips {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ matched = append(matched, ip)
|
|
|
+ } else {
|
|
|
+ unmatched = append(unmatched, ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ buckets := make(map[[9]byte]*ipBucket, n)
|
|
|
+ precise := make([]net.IP, 0, n)
|
|
|
+
|
|
|
+ for _, ip := range ips {
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+
|
|
|
+ if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
|
|
|
+ precise = append(precise, ip)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ b, exists := buckets[key]
|
|
|
+ if !exists {
|
|
|
+ // build bucket
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ b = &ipBucket{
|
|
|
+ rep: ipx,
|
|
|
+ ips: make([]net.IP, 0, 4), // for dns answer
|
|
|
+ }
|
|
|
+ buckets[key] = b
|
|
|
+ }
|
|
|
+ b.ips = append(b.ips, ip)
|
|
|
+ }
|
|
|
+
|
|
|
+ matched = make([]net.IP, 0, n)
|
|
|
+ unmatched = make([]net.IP, 0, n)
|
|
|
+ for _, b := range buckets {
|
|
|
+ if m.matchAddr(b.rep) {
|
|
|
+ matched = append(matched, b.ips...)
|
|
|
+ } else {
|
|
|
+ unmatched = append(unmatched, b.ips...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for _, ip := range precise {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ matched = append(matched, ip)
|
|
|
+ } else {
|
|
|
+ unmatched = append(unmatched, ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+// ToggleReverse implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) ToggleReverse() {
|
|
|
+ m.reverse = !m.reverse
|
|
|
+}
|
|
|
+
|
|
|
+// SetReverse implements GeoIPMatcher.
|
|
|
+func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) {
|
|
|
+ m.reverse = reverse
|
|
|
+}
|
|
|
+
|
|
|
+type GeneralMultiGeoIPMatcher struct {
|
|
|
+ matchers []GeoIPMatcher
|
|
|
+}
|
|
|
+
|
|
|
+// Match implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if m.Match(ip) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+// AnyMatch implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if m.AnyMatch(ips) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+// Matches implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if m.Matches(ips) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+// FilterIPs implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
|
+ matched = make([]net.IP, 0, len(ips))
|
|
|
+ unmatched = ips
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if len(unmatched) == 0 {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ var mtch []net.IP
|
|
|
+ mtch, unmatched = m.FilterIPs(unmatched)
|
|
|
+ if len(mtch) > 0 {
|
|
|
+ matched = append(matched, mtch...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+// ToggleReverse implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ m.ToggleReverse()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// SetReverse implements GeoIPMatcher.
|
|
|
+func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ m.SetReverse(reverse)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type HeuristicMultiGeoIPMatcher struct {
|
|
|
+ matchers []*HeuristicGeoIPMatcher
|
|
|
+}
|
|
|
+
|
|
|
+// Match implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool {
|
|
|
+ ipx, ok := netipx.FromStdIP(ip)
|
|
|
if !ok {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
- return m.ip4.Contains(nip)
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
}
|
|
|
|
|
|
-func (m *GeoIPMatcher) match6(ip net.IP) bool {
|
|
|
- nip, ok := netipx.FromStdIP(ip)
|
|
|
- if !ok {
|
|
|
+// AnyMatch implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
- return m.ip6.Contains(nip)
|
|
|
+ if n == 1 {
|
|
|
+ return mm.Match(ips[0])
|
|
|
+ }
|
|
|
+
|
|
|
+ buckets := make(map[[9]byte]struct{}, n)
|
|
|
+ for _, ip := range ips {
|
|
|
+ var ipx netip.Addr
|
|
|
+ state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ heur4 := m.ipset.max4 <= 24
|
|
|
+ heur6 := m.ipset.max6 <= 64
|
|
|
+
|
|
|
+ if state == 0 && (heur4 || heur6) {
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if _, exists := buckets[key]; exists {
|
|
|
+ state = key[0]
|
|
|
+ } else {
|
|
|
+ buckets[key] = struct{}{}
|
|
|
+ state = 1
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (heur4 && state == 4) || (heur6 && state == 6) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if !ipx.IsValid() {
|
|
|
+ nipx, ok := netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ ipx = nipx
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
}
|
|
|
|
|
|
-// Match returns true if the given ip is included by the GeoIP.
|
|
|
-func (m *GeoIPMatcher) Match(ip net.IP) bool {
|
|
|
- isMatched := false
|
|
|
- switch len(ip) {
|
|
|
- case net.IPv4len:
|
|
|
- isMatched = m.match4(ip)
|
|
|
- case net.IPv6len:
|
|
|
- isMatched = m.match6(ip)
|
|
|
+// Matches implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if n == 1 {
|
|
|
+ return mm.Match(ips[0])
|
|
|
}
|
|
|
- if m.reverseMatch {
|
|
|
- return !isMatched
|
|
|
+
|
|
|
+ var views ipViews
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if !views.ensureForMatcher(m, ips) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ matched := true
|
|
|
+ if m.ipset.max4 <= 24 {
|
|
|
+ for _, ipx := range views.buckets4 {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ matched = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for _, ipx := range views.precise4 {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ matched = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if !matched {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if m.ipset.max6 <= 64 {
|
|
|
+ for _, ipx := range views.buckets6 {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ matched = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for _, ipx := range views.precise6 {
|
|
|
+ if !m.matchAddr(ipx) {
|
|
|
+ matched = false
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if matched {
|
|
|
+ return true
|
|
|
+ }
|
|
|
}
|
|
|
- return isMatched
|
|
|
+ return false
|
|
|
}
|
|
|
|
|
|
-// GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code.
|
|
|
-type GeoIPMatcherContainer struct {
|
|
|
- matchers []*GeoIPMatcher
|
|
|
+type ipViews struct {
|
|
|
+ buckets4, buckets6 map[[9]byte]netip.Addr
|
|
|
+ precise4, precise6 []netip.Addr
|
|
|
}
|
|
|
|
|
|
-// Add adds a new GeoIP set into the container.
|
|
|
-// If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one.
|
|
|
-func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) {
|
|
|
- if len(geoip.CountryCode) > 0 {
|
|
|
- for _, m := range c.matchers {
|
|
|
- if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch {
|
|
|
- return m, nil
|
|
|
+func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool {
|
|
|
+ needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
|
|
+ needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
|
|
+ needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
|
|
+ needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
|
|
+
|
|
|
+ if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ if needHeur4 {
|
|
|
+ v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
|
|
|
+ }
|
|
|
+ if needHeur6 {
|
|
|
+ v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
|
|
|
+ }
|
|
|
+ if needPrec4 {
|
|
|
+ v.precise4 = make([]netip.Addr, 0, len(ips))
|
|
|
+ }
|
|
|
+ if needPrec6 {
|
|
|
+ v.precise6 = make([]netip.Addr, 0, len(ips))
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, ip := range ips {
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ switch key[0] {
|
|
|
+ case 4:
|
|
|
+ var ipx netip.Addr
|
|
|
+ if needHeur4 {
|
|
|
+ if _, exists := v.buckets4[key]; !exists {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ v.buckets4[key] = ipx
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if needPrec4 {
|
|
|
+ if !ipx.IsValid() {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ v.precise4 = append(v.precise4, ipx)
|
|
|
+ }
|
|
|
+ case 6:
|
|
|
+ var ipx netip.Addr
|
|
|
+ if needHeur6 {
|
|
|
+ if _, exists := v.buckets6[key]; !exists {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ v.buckets6[key] = ipx
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if needPrec6 {
|
|
|
+ if !ipx.IsValid() {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ v.precise6 = append(v.precise6, ipx)
|
|
|
}
|
|
|
+ default:
|
|
|
+ return false
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- m := &GeoIPMatcher{
|
|
|
- countryCode: geoip.CountryCode,
|
|
|
- reverseMatch: geoip.ReverseMatch,
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+// FilterIPs implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
|
|
|
+ n := len(ips)
|
|
|
+ if n == 0 {
|
|
|
+ return []net.IP{}, []net.IP{}
|
|
|
}
|
|
|
- if err := m.Init(geoip.Cidr); err != nil {
|
|
|
- return nil, err
|
|
|
+
|
|
|
+ if n == 1 {
|
|
|
+ ipx, ok := netipx.FromStdIP(ips[0])
|
|
|
+ if !ok {
|
|
|
+ return []net.IP{}, []net.IP{}
|
|
|
+ }
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ return ips, []net.IP{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return []net.IP{}, ips
|
|
|
}
|
|
|
- if len(geoip.CountryCode) > 0 {
|
|
|
- c.matchers = append(c.matchers, m)
|
|
|
+
|
|
|
+ var views ipBucketViews
|
|
|
+
|
|
|
+ matched = make([]net.IP, 0, n)
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ views.ensureForMatcher(m, ips)
|
|
|
+
|
|
|
+ if m.ipset.max4 <= 24 {
|
|
|
+ for key, b := range views.buckets4 {
|
|
|
+ if b == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if m.matchAddr(b.rep) {
|
|
|
+ views.buckets4[key] = nil
|
|
|
+ matched = append(matched, b.ips...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for ipx, ip := range views.precise4 {
|
|
|
+ if ip == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ views.precise4[ipx] = nil
|
|
|
+ matched = append(matched, ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if m.ipset.max6 <= 64 {
|
|
|
+ for key, b := range views.buckets6 {
|
|
|
+ if b == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if m.matchAddr(b.rep) {
|
|
|
+ views.buckets6[key] = nil
|
|
|
+ matched = append(matched, b.ips...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for ipx, ip := range views.precise6 {
|
|
|
+ if ip == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if m.matchAddr(ipx) {
|
|
|
+ views.precise6[ipx] = nil
|
|
|
+ matched = append(matched, ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ unmatched = make([]net.IP, 0, n-len(matched))
|
|
|
+ if views.buckets4 != nil {
|
|
|
+ for _, b := range views.buckets4 {
|
|
|
+ if b == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ unmatched = append(unmatched, b.ips...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if views.precise4 != nil {
|
|
|
+ for _, ip := range views.precise4 {
|
|
|
+ if ip == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ unmatched = append(unmatched, ip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if views.buckets6 != nil {
|
|
|
+ for _, b := range views.buckets6 {
|
|
|
+ if b == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ unmatched = append(unmatched, b.ips...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if views.precise6 != nil {
|
|
|
+ for _, ip := range views.precise6 {
|
|
|
+ if ip == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ unmatched = append(unmatched, ip)
|
|
|
+ }
|
|
|
}
|
|
|
- return m, nil
|
|
|
+
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+type ipBucketViews struct {
|
|
|
+ buckets4, buckets6 map[[9]byte]*ipBucket
|
|
|
+ precise4, precise6 map[netip.Addr]net.IP
|
|
|
}
|
|
|
|
|
|
-var GlobalGeoIPContainer GeoIPMatcherContainer
|
|
|
+func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) {
|
|
|
+ needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
|
|
|
+ needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
|
|
|
+ needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
|
|
|
+ needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
|
|
|
|
|
|
-func MatchIPs(matchers []*GeoIPMatcher, ips []net.IP, reverse bool) []net.IP {
|
|
|
- if len(matchers) == 0 {
|
|
|
- panic("GeoIP matchers should not be empty to avoid ambiguity")
|
|
|
+ if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
|
|
|
+ return
|
|
|
}
|
|
|
- newIPs := make([]net.IP, 0, len(ips))
|
|
|
- var isFound bool
|
|
|
+
|
|
|
+ if needHeur4 {
|
|
|
+ v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
|
|
|
+ }
|
|
|
+ if needHeur6 {
|
|
|
+ v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
|
|
|
+ }
|
|
|
+ if needPrec4 {
|
|
|
+ v.precise4 = make(map[netip.Addr]net.IP, len(ips))
|
|
|
+ }
|
|
|
+ if needPrec6 {
|
|
|
+ v.precise6 = make(map[netip.Addr]net.IP, len(ips))
|
|
|
+ }
|
|
|
+
|
|
|
for _, ip := range ips {
|
|
|
- isFound = false
|
|
|
- for _, matcher := range matchers {
|
|
|
- if matcher.Match(ip) {
|
|
|
- isFound = true
|
|
|
- break
|
|
|
+ key, ok := prefixKeyFromIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+
|
|
|
+ switch key[0] {
|
|
|
+ case 4:
|
|
|
+ var ipx netip.Addr
|
|
|
+ if needHeur4 {
|
|
|
+ b, exists := v.buckets4[key]
|
|
|
+ if !exists {
|
|
|
+ // build bucket
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ b = &ipBucket{
|
|
|
+ rep: ipx,
|
|
|
+ ips: make([]net.IP, 0, 4), // for dns answer
|
|
|
+ }
|
|
|
+ v.buckets4[key] = b
|
|
|
+ }
|
|
|
+ b.ips = append(b.ips, ip)
|
|
|
+ }
|
|
|
+ if needPrec4 {
|
|
|
+ if !ipx.IsValid() {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ }
|
|
|
+ v.precise4[ipx] = ip
|
|
|
+ }
|
|
|
+ case 6:
|
|
|
+ var ipx netip.Addr
|
|
|
+ if needHeur6 {
|
|
|
+ b, exists := v.buckets6[key]
|
|
|
+ if !exists {
|
|
|
+ // build bucket
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ b = &ipBucket{
|
|
|
+ rep: ipx,
|
|
|
+ ips: make([]net.IP, 0, 4), // for dns answer
|
|
|
+ }
|
|
|
+ v.buckets6[key] = b
|
|
|
+ }
|
|
|
+ b.ips = append(b.ips, ip)
|
|
|
+ }
|
|
|
+ if needPrec6 {
|
|
|
+ if !ipx.IsValid() {
|
|
|
+ ipx, ok = netipx.FromStdIP(ip)
|
|
|
+ if !ok {
|
|
|
+ continue // illegal ip, ignore
|
|
|
+ }
|
|
|
+ }
|
|
|
+ v.precise6[ipx] = ip
|
|
|
}
|
|
|
}
|
|
|
- if isFound && !reverse {
|
|
|
- newIPs = append(newIPs, ip)
|
|
|
- continue
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ToggleReverse implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ m.ToggleReverse()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// SetReverse implements GeoIPMatcher.
|
|
|
+func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) {
|
|
|
+ for _, m := range mm.matchers {
|
|
|
+ m.SetReverse(reverse)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type GeoIPSetFactory struct {
|
|
|
+ sync.Mutex
|
|
|
+ shared map[string]*GeoIPSet // TODO: cleanup
|
|
|
+}
|
|
|
+
|
|
|
+var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)}
|
|
|
+
|
|
|
+func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) {
|
|
|
+ f.Lock()
|
|
|
+ defer f.Unlock()
|
|
|
+
|
|
|
+ if ipset := f.shared[key]; ipset != nil {
|
|
|
+ return ipset, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ ipset, err := f.Create(cidrGroups...)
|
|
|
+ if err == nil {
|
|
|
+ f.shared[key] = ipset
|
|
|
+ }
|
|
|
+ return ipset, err
|
|
|
+}
|
|
|
+
|
|
|
+func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
|
|
|
+ var ipv4Builder, ipv6Builder netipx.IPSetBuilder
|
|
|
+
|
|
|
+ for _, cidrGroup := range cidrGroups {
|
|
|
+ for _, cidrEntry := range cidrGroup {
|
|
|
+ ipBytes := cidrEntry.GetIp()
|
|
|
+ prefixLen := int(cidrEntry.GetPrefix())
|
|
|
+
|
|
|
+ addr, ok := netip.AddrFromSlice(ipBytes)
|
|
|
+ if !ok {
|
|
|
+ errors.LogError(context.Background(), "ignore invalid IP byte slice: ", ipBytes)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ prefix := netip.PrefixFrom(addr, prefixLen)
|
|
|
+ if !prefix.IsValid() {
|
|
|
+ errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", prefixLen)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if addr.Is4() {
|
|
|
+ ipv4Builder.AddPrefix(prefix)
|
|
|
+ } else if addr.Is6() {
|
|
|
+ ipv6Builder.AddPrefix(prefix)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ ipv4, err := ipv4Builder.IPSet()
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.New("failed to build IPv4 set").Base(err)
|
|
|
+ }
|
|
|
+ ipv6, err := ipv6Builder.IPSet()
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.New("failed to build IPv6 set").Base(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var max4, max6 int
|
|
|
+
|
|
|
+ for _, p := range ipv4.Prefixes() {
|
|
|
+ if b := p.Bits(); b > max4 {
|
|
|
+ max4 = b
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for _, p := range ipv6.Prefixes() {
|
|
|
+ if b := p.Bits(); b > max6 {
|
|
|
+ max6 = b
|
|
|
}
|
|
|
- if !isFound && reverse {
|
|
|
- newIPs = append(newIPs, ip)
|
|
|
+ }
|
|
|
+
|
|
|
+ if max4 == 0 {
|
|
|
+ max4 = 0xff
|
|
|
+ }
|
|
|
+ if max6 == 0 {
|
|
|
+ max6 = 0xff
|
|
|
+ }
|
|
|
+
|
|
|
+ return &GeoIPSet{ipv4: ipv4, ipv6: ipv6, max4: uint8(max4), max6: uint8(max6)}, nil
|
|
|
+}
|
|
|
+
|
|
|
+func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) {
|
|
|
+ n := len(geoips)
|
|
|
+ if n == 0 {
|
|
|
+ return nil, errors.New("no geoip configs provided")
|
|
|
+ }
|
|
|
+
|
|
|
+ var subs []*HeuristicGeoIPMatcher
|
|
|
+ pos := make([]*GeoIP, 0, n)
|
|
|
+ neg := make([]*GeoIP, 0, n/2)
|
|
|
+
|
|
|
+ for _, geoip := range geoips {
|
|
|
+ if geoip == nil {
|
|
|
+ return nil, errors.New("geoip entry is nil")
|
|
|
+ }
|
|
|
+ if geoip.CountryCode == "" {
|
|
|
+ ipset, err := ipsetFactory.Create(geoip.Cidr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: geoip.ReverseMatch})
|
|
|
continue
|
|
|
}
|
|
|
+ if !geoip.ReverseMatch {
|
|
|
+ pos = append(pos, geoip)
|
|
|
+ } else {
|
|
|
+ neg = append(neg, geoip)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ buildIPSet := func(mergeables []*GeoIP) (*GeoIPSet, error) {
|
|
|
+ n := len(mergeables)
|
|
|
+ if n == 0 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ sort.Slice(mergeables, func(i, j int) bool {
|
|
|
+ gi, gj := mergeables[i], mergeables[j]
|
|
|
+ return gi.CountryCode < gj.CountryCode
|
|
|
+ })
|
|
|
+
|
|
|
+ var sb strings.Builder
|
|
|
+ sb.Grow(n * 3) // xx,
|
|
|
+ cidrGroups := make([][]*CIDR, 0, n)
|
|
|
+ var last *GeoIP
|
|
|
+ for i, geoip := range mergeables {
|
|
|
+ if i == 0 || (geoip.CountryCode != last.CountryCode) {
|
|
|
+ last = geoip
|
|
|
+ sb.WriteString(geoip.CountryCode)
|
|
|
+ sb.WriteString(",")
|
|
|
+ cidrGroups = append(cidrGroups, geoip.Cidr)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return ipsetFactory.GetOrCreate(sb.String(), cidrGroups)
|
|
|
+ }
|
|
|
+
|
|
|
+ ipset, err := buildIPSet(pos)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if ipset != nil {
|
|
|
+ subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: false})
|
|
|
+ }
|
|
|
+
|
|
|
+ ipset, err = buildIPSet(neg)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if ipset != nil {
|
|
|
+ subs = append(subs, &HeuristicGeoIPMatcher{ipset: ipset, reverse: true})
|
|
|
+ }
|
|
|
+
|
|
|
+ switch len(subs) {
|
|
|
+ case 0:
|
|
|
+ return nil, errors.New("no valid geoip matcher")
|
|
|
+ case 1:
|
|
|
+ return subs[0], nil
|
|
|
+ default:
|
|
|
+ return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil
|
|
|
}
|
|
|
- return newIPs
|
|
|
}
|