condition_geoip.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961
  1. package router
  2. import (
  3. "context"
  4. "net/netip"
  5. "sort"
  6. "strings"
  7. "sync"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/net"
  10. "go4.org/netipx"
  11. )
  // GeoIPMatcher matches IP addresses against GeoIP / CIDR rules.
  type GeoIPMatcher interface {
  	// TODO: (PERF) all net.IP -> netip.Addr
  	// Match reports whether ip matches. An invalid IP always returns false.
  	Match(ip net.IP) bool
  	// AnyMatch returns true if *any* IP is valid and matches.
  	AnyMatch(ips []net.IP) bool
  	// Matches returns true only if *all* IPs are valid and match. Any invalid IP, or non-matching valid IP, causes false.
  	Matches(ips []net.IP) bool
  	// FilterIPs splits ips into matched and unmatched IPs. Invalid IPs are silently dropped and not included in either result.
  	FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP)
  	// ToggleReverse flips the matcher's reverse (negated) mode; see the implementations for exact semantics.
  	ToggleReverse()
  	// SetReverse sets the matcher's reverse (negated) mode explicitly.
  	SetReverse(reverse bool)
  }
  25. type GeoIPSet struct {
  26. ipv4, ipv6 *netipx.IPSet
  27. max4, max6 uint8
  28. }
  29. type HeuristicGeoIPMatcher struct {
  30. ipset *GeoIPSet
  31. reverse bool
  32. }
  33. type ipBucket struct {
  34. rep netip.Addr
  35. ips []net.IP
  36. }
  37. // Match implements GeoIPMatcher.
  38. func (m *HeuristicGeoIPMatcher) Match(ip net.IP) bool {
  39. ipx, ok := netipx.FromStdIP(ip)
  40. if !ok {
  41. return false
  42. }
  43. return m.matchAddr(ipx)
  44. }
  45. func (m *HeuristicGeoIPMatcher) matchAddr(ipx netip.Addr) bool {
  46. if ipx.Is4() {
  47. return m.ipset.ipv4.Contains(ipx) != m.reverse
  48. }
  49. if ipx.Is6() {
  50. return m.ipset.ipv6.Contains(ipx) != m.reverse
  51. }
  52. return false
  53. }
  // AnyMatch implements GeoIPMatcher.
  //
  // The IPv4 set may have no prefix longer than /24 (or the IPv6 set no prefix
  // longer than /64). For such a family, only the first valid address of each
  // /24 (/64) bucket is evaluated: if it did not match, none of its bucket
  // mates can. Invalid IPs are skipped.
  func (m *HeuristicGeoIPMatcher) AnyMatch(ips []net.IP) bool {
  	switch len(ips) {
  	case 0:
  		return false
  	case 1:
  		return m.Match(ips[0])
  	}
  	coarse4, coarse6 := m.ipset.max4 <= 24, m.ipset.max6 <= 64
  	if !(coarse4 || coarse6) {
  		// No family qualifies for the bucket shortcut: check every address.
  		for _, ip := range ips {
  			ipx, ok := netipx.FromStdIP(ip)
  			if ok && m.matchAddr(ipx) {
  				return true
  			}
  		}
  		return false
  	}
  	checked := make(map[[9]byte]struct{}, len(ips))
  	for _, ip := range ips {
  		key, ok := prefixKeyFromIP(ip)
  		if !ok {
  			continue
  		}
  		bucketed := (key[0] == 4 && coarse4) || (key[0] == 6 && coarse6)
  		if bucketed {
  			if _, dup := checked[key]; dup {
  				continue // a bucket mate already failed to match
  			}
  		}
  		ipx, ok := netipx.FromStdIP(ip)
  		if !ok {
  			continue
  		}
  		if m.matchAddr(ipx) {
  			return true
  		}
  		if bucketed {
  			checked[key] = struct{}{}
  		}
  	}
  	return false
  }
  // Matches implements GeoIPMatcher.
  //
  // Every IP must be valid and match. For a family whose set has no prefix
  // longer than /24 (IPv4) or /64 (IPv6), only one representative per /24
  // (/64) bucket is evaluated. Every address of other families is evaluated
  // individually.
  func (m *HeuristicGeoIPMatcher) Matches(ips []net.IP) bool {
  	switch len(ips) {
  	case 0:
  		return false
  	case 1:
  		return m.Match(ips[0])
  	}
  	coarse4, coarse6 := m.ipset.max4 <= 24, m.ipset.max6 <= 64
  	if !coarse4 && !coarse6 {
  		for _, ip := range ips {
  			ipx, ok := netipx.FromStdIP(ip)
  			if !ok || !m.matchAddr(ipx) {
  				return false
  			}
  		}
  		return true
  	}
  	// First collect what has to be checked; any invalid IP fails immediately.
  	reps := make(map[[9]byte]netip.Addr, len(ips))
  	exact := make([]netip.Addr, 0, len(ips))
  	for _, ip := range ips {
  		key, ok := prefixKeyFromIP(ip)
  		if !ok {
  			return false
  		}
  		coarse := (key[0] == 4 && coarse4) || (key[0] == 6 && coarse6)
  		if coarse {
  			if _, have := reps[key]; have {
  				continue
  			}
  		}
  		ipx, ok := netipx.FromStdIP(ip)
  		if !ok {
  			return false
  		}
  		if coarse {
  			reps[key] = ipx
  		} else {
  			exact = append(exact, ipx)
  		}
  	}
  	for _, ipx := range reps {
  		if !m.matchAddr(ipx) {
  			return false
  		}
  	}
  	for _, ipx := range exact {
  		if !m.matchAddr(ipx) {
  			return false
  		}
  	}
  	return true
  }
  // prefixKeyFromIP returns a bucket key for the /24 (IPv4) or /64 (IPv6)
  // network containing ip. Byte 0 holds the family (4 or 6) and the following
  // bytes hold the network prefix; unused trailing bytes stay zero. Because
  // To4 is tried first, IPv4-mapped IPv6 addresses are keyed as IPv4. ok is
  // false when ip is neither a 4-byte nor a 16-byte address.
  func prefixKeyFromIP(ip net.IP) (key [9]byte, ok bool) {
  	if v4 := ip.To4(); v4 != nil {
  		key[0] = 4
  		copy(key[1:4], v4[:3]) // /24
  		return key, true
  	}
  	if v6 := ip.To16(); v6 != nil {
  		key[0] = 6
  		copy(key[1:], v6[:8]) // /64
  		return key, true
  	}
  	return key, false // illegal
  }
  180. // FilterIPs implements GeoIPMatcher.
  181. func (m *HeuristicGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
  182. n := len(ips)
  183. if n == 0 {
  184. return []net.IP{}, []net.IP{}
  185. }
  186. if n == 1 {
  187. ipx, ok := netipx.FromStdIP(ips[0])
  188. if !ok {
  189. return []net.IP{}, []net.IP{}
  190. }
  191. if m.matchAddr(ipx) {
  192. return ips, []net.IP{}
  193. }
  194. return []net.IP{}, ips
  195. }
  196. heur4 := m.ipset.max4 <= 24
  197. heur6 := m.ipset.max6 <= 64
  198. if !heur4 && !heur6 {
  199. matched = make([]net.IP, 0, n)
  200. unmatched = make([]net.IP, 0, n)
  201. for _, ip := range ips {
  202. ipx, ok := netipx.FromStdIP(ip)
  203. if !ok {
  204. continue // illegal ip, ignore
  205. }
  206. if m.matchAddr(ipx) {
  207. matched = append(matched, ip)
  208. } else {
  209. unmatched = append(unmatched, ip)
  210. }
  211. }
  212. return
  213. }
  214. buckets := make(map[[9]byte]*ipBucket, n)
  215. precise := make([]net.IP, 0, n)
  216. for _, ip := range ips {
  217. key, ok := prefixKeyFromIP(ip)
  218. if !ok {
  219. continue // illegal ip, ignore
  220. }
  221. if (key[0] == 4 && !heur4) || (key[0] == 6 && !heur6) {
  222. precise = append(precise, ip)
  223. continue
  224. }
  225. b, exists := buckets[key]
  226. if !exists {
  227. // build bucket
  228. ipx, ok := netipx.FromStdIP(ip)
  229. if !ok {
  230. continue // illegal ip, ignore
  231. }
  232. b = &ipBucket{
  233. rep: ipx,
  234. ips: make([]net.IP, 0, 4), // for dns answer
  235. }
  236. buckets[key] = b
  237. }
  238. b.ips = append(b.ips, ip)
  239. }
  240. matched = make([]net.IP, 0, n)
  241. unmatched = make([]net.IP, 0, n)
  242. for _, b := range buckets {
  243. if m.matchAddr(b.rep) {
  244. matched = append(matched, b.ips...)
  245. } else {
  246. unmatched = append(unmatched, b.ips...)
  247. }
  248. }
  249. for _, ip := range precise {
  250. ipx, ok := netipx.FromStdIP(ip)
  251. if !ok {
  252. continue // illegal ip, ignore
  253. }
  254. if m.matchAddr(ipx) {
  255. matched = append(matched, ip)
  256. } else {
  257. unmatched = append(unmatched, ip)
  258. }
  259. }
  260. return
  261. }
  262. // ToggleReverse implements GeoIPMatcher.
  263. func (m *HeuristicGeoIPMatcher) ToggleReverse() {
  264. m.reverse = !m.reverse
  265. }
  266. // SetReverse implements GeoIPMatcher.
  267. func (m *HeuristicGeoIPMatcher) SetReverse(reverse bool) {
  268. m.reverse = reverse
  269. }
  270. type GeneralMultiGeoIPMatcher struct {
  271. matchers []GeoIPMatcher
  272. }
  273. // Match implements GeoIPMatcher.
  274. func (mm *GeneralMultiGeoIPMatcher) Match(ip net.IP) bool {
  275. for _, m := range mm.matchers {
  276. if m.Match(ip) {
  277. return true
  278. }
  279. }
  280. return false
  281. }
  282. // AnyMatch implements GeoIPMatcher.
  283. func (mm *GeneralMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
  284. for _, m := range mm.matchers {
  285. if m.AnyMatch(ips) {
  286. return true
  287. }
  288. }
  289. return false
  290. }
  291. // Matches implements GeoIPMatcher.
  292. func (mm *GeneralMultiGeoIPMatcher) Matches(ips []net.IP) bool {
  293. for _, m := range mm.matchers {
  294. if m.Matches(ips) {
  295. return true
  296. }
  297. }
  298. return false
  299. }
  300. // FilterIPs implements GeoIPMatcher.
  301. func (mm *GeneralMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
  302. matched = make([]net.IP, 0, len(ips))
  303. unmatched = ips
  304. for _, m := range mm.matchers {
  305. if len(unmatched) == 0 {
  306. break
  307. }
  308. var mtch []net.IP
  309. mtch, unmatched = m.FilterIPs(unmatched)
  310. if len(mtch) > 0 {
  311. matched = append(matched, mtch...)
  312. }
  313. }
  314. return
  315. }
  316. // ToggleReverse implements GeoIPMatcher.
  317. func (mm *GeneralMultiGeoIPMatcher) ToggleReverse() {
  318. for _, m := range mm.matchers {
  319. m.ToggleReverse()
  320. }
  321. }
  322. // SetReverse implements GeoIPMatcher.
  323. func (mm *GeneralMultiGeoIPMatcher) SetReverse(reverse bool) {
  324. for _, m := range mm.matchers {
  325. m.SetReverse(reverse)
  326. }
  327. }
  328. type HeuristicMultiGeoIPMatcher struct {
  329. matchers []*HeuristicGeoIPMatcher
  330. }
  331. // Match implements GeoIPMatcher.
  332. func (mm *HeuristicMultiGeoIPMatcher) Match(ip net.IP) bool {
  333. ipx, ok := netipx.FromStdIP(ip)
  334. if !ok {
  335. return false
  336. }
  337. for _, m := range mm.matchers {
  338. if m.matchAddr(ipx) {
  339. return true
  340. }
  341. }
  342. return false
  343. }
  // AnyMatch implements GeoIPMatcher.
  //
  // It walks ips in order and, for each IP, tries the sub-matchers until one
  // matches (returning true). The first IP seen in a given /24 (IPv4) or /64
  // (IPv6) bucket is recorded in buckets and checked against every
  // sub-matcher. For later IPs of the same bucket, sub-matchers that use the
  // heuristic for that family (max4 <= 24 / max6 <= 64) are skipped: they
  // already failed on the bucket's first IP, otherwise AnyMatch would have
  // returned. Invalid IPs are skipped.
  func (mm *HeuristicMultiGeoIPMatcher) AnyMatch(ips []net.IP) bool {
  	n := len(ips)
  	if n == 0 {
  		return false
  	}
  	if n == 1 {
  		return mm.Match(ips[0])
  	}
  	buckets := make(map[[9]byte]struct{}, n)
  	for _, ip := range ips {
  		// ipx is parsed lazily, only when some sub-matcher needs an exact check.
  		var ipx netip.Addr
  		state := uint8(0) // 0 = Not initialized, 1 = Initialized, 4 = IPv4 can be skipped, 6 = IPv6 can be skipped
  		for _, m := range mm.matchers {
  			heur4 := m.ipset.max4 <= 24
  			heur6 := m.ipset.max6 <= 64
  			// The bucket key is computed once per IP, at the first sub-matcher
  			// able to use the heuristic for either family.
  			if state == 0 && (heur4 || heur6) {
  				key, ok := prefixKeyFromIP(ip)
  				if !ok {
  					// Illegal IP: skip it for all remaining sub-matchers.
  					break
  				}
  				if _, exists := buckets[key]; exists {
  					// Bucket already evaluated: key[0] is the family (4 or 6).
  					state = key[0]
  				} else {
  					buckets[key] = struct{}{}
  					state = 1
  				}
  			}
  			if (heur4 && state == 4) || (heur6 && state == 6) {
  				continue
  			}
  			if !ipx.IsValid() {
  				nipx, ok := netipx.FromStdIP(ip)
  				if !ok {
  					break
  				}
  				ipx = nipx
  			}
  			if m.matchAddr(ipx) {
  				return true
  			}
  		}
  	}
  	return false
  }
  // Matches implements GeoIPMatcher.
  //
  // It returns true when at least one sub-matcher accepts every IP in ips; any
  // invalid IP makes it return false. Parsed views of ips are built lazily and
  // shared across sub-matchers. A family whose longest prefix is at most /24
  // (IPv4) or /64 (IPv6) uses bucket representatives; otherwise every address
  // is checked.
  func (mm *HeuristicMultiGeoIPMatcher) Matches(ips []net.IP) bool {
  	switch len(ips) {
  	case 0:
  		return false
  	case 1:
  		return mm.Match(ips[0])
  	}
  	// acceptsAll reports whether m accepts every address of one family,
  	// looking only at bucket representatives when coarse is true.
  	acceptsAll := func(m *HeuristicGeoIPMatcher, coarse bool, reps map[[9]byte]netip.Addr, all []netip.Addr) bool {
  		if coarse {
  			for _, ipx := range reps {
  				if !m.matchAddr(ipx) {
  					return false
  				}
  			}
  			return true
  		}
  		for _, ipx := range all {
  			if !m.matchAddr(ipx) {
  				return false
  			}
  		}
  		return true
  	}
  	var views ipViews
  	for _, m := range mm.matchers {
  		if !views.ensureForMatcher(m, ips) {
  			return false
  		}
  		if acceptsAll(m, m.ipset.max4 <= 24, views.buckets4, views.precise4) &&
  			acceptsAll(m, m.ipset.max6 <= 64, views.buckets6, views.precise6) {
  			return true
  		}
  	}
  	return false
  }
  // ipViews lazily caches parsed forms of an IP list for
  // HeuristicMultiGeoIPMatcher.Matches. Each family has two views:
  //   - buckets: the first address seen in each /24 (IPv4) or /64 (IPv6), keyed by prefixKeyFromIP.
  //   - precise: every address of that family, in input order.
  // A nil field means that view has not been built yet.
  type ipViews struct {
  	buckets4, buckets6 map[[9]byte]netip.Addr
  	precise4, precise6 []netip.Addr
  }
  // ensureForMatcher builds the views m needs that do not exist yet. For a
  // family, the bucket view is needed when m's longest prefix is at most /24
  // (IPv4) or /64 (IPv6), and the precise view otherwise. Views already built
  // for an earlier sub-matcher are reused. It returns false as soon as an IP in
  // ips is invalid; the views may then be left partially filled.
  func (v *ipViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) bool {
  	needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
  	needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
  	needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
  	needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
  	if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
  		return true
  	}
  	if needHeur4 {
  		v.buckets4 = make(map[[9]byte]netip.Addr, len(ips))
  	}
  	if needHeur6 {
  		v.buckets6 = make(map[[9]byte]netip.Addr, len(ips))
  	}
  	if needPrec4 {
  		v.precise4 = make([]netip.Addr, 0, len(ips))
  	}
  	if needPrec6 {
  		v.precise6 = make([]netip.Addr, 0, len(ips))
  	}
  	for _, ip := range ips {
  		key, ok := prefixKeyFromIP(ip)
  		if !ok {
  			return false
  		}
  		switch key[0] {
  		case 4:
  			// needHeur4 and needPrec4 are mutually exclusive (max4 <= 24 vs > 24);
  			// ipx is only pre-filled when a new bucket representative was parsed.
  			var ipx netip.Addr
  			if needHeur4 {
  				if _, exists := v.buckets4[key]; !exists {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						return false
  					}
  					v.buckets4[key] = ipx
  				}
  			}
  			if needPrec4 {
  				if !ipx.IsValid() {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						return false
  					}
  				}
  				v.precise4 = append(v.precise4, ipx)
  			}
  		case 6:
  			var ipx netip.Addr
  			if needHeur6 {
  				if _, exists := v.buckets6[key]; !exists {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						return false
  					}
  					v.buckets6[key] = ipx
  				}
  			}
  			if needPrec6 {
  				if !ipx.IsValid() {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						return false
  					}
  				}
  				v.precise6 = append(v.precise6, ipx)
  			}
  		default:
  			return false
  		}
  	}
  	return true
  }
  // FilterIPs implements GeoIPMatcher.
  //
  // Every valid IP ends up in exactly one result: matched when at least one
  // sub-matcher matches it, unmatched otherwise. Invalid IPs are dropped.
  // Input order and duplicate entries are preserved in both results.
  //
  // Some sub-matchers have a longest prefix of at most /24 (IPv4) or /64 (IPv6)
  // for a family. For those, the verdict for the first still-unmatched IP of a
  // /24 (/64) bucket is reused for the rest of that bucket.
  //
  // A single hit flag per input position is used instead of separate bucket and
  // per-address views. Separate views held the same IP twice whenever
  // sub-matchers disagreed on heuristic eligibility for a family. That let one
  // IP be reported twice, or in both results, and the per-address maps
  // collapsed duplicate IPs.
  func (mm *HeuristicMultiGeoIPMatcher) FilterIPs(ips []net.IP) (matched []net.IP, unmatched []net.IP) {
  	n := len(ips)
  	if n == 0 {
  		return []net.IP{}, []net.IP{}
  	}
  	if n == 1 {
  		ipx, ok := netipx.FromStdIP(ips[0])
  		if !ok {
  			return []net.IP{}, []net.IP{}
  		}
  		for _, m := range mm.matchers {
  			if m.matchAddr(ipx) {
  				return ips, []net.IP{}
  			}
  		}
  		return []net.IP{}, ips
  	}
  	// Parse every IP once. An invalid IP keeps the zero netip.Addr and is
  	// skipped from then on.
  	addrs := make([]netip.Addr, n)
  	keys := make([][9]byte, n)
  	for i, ip := range ips {
  		key, ok := prefixKeyFromIP(ip)
  		if !ok {
  			continue // illegal ip, ignore
  		}
  		ipx, ok := netipx.FromStdIP(ip)
  		if !ok {
  			continue // illegal ip, ignore
  		}
  		addrs[i], keys[i] = ipx, key
  	}
  	// hit[i] records whether ips[i] has been matched by some sub-matcher.
  	hit := make([]bool, n)
  	for _, m := range mm.matchers {
  		heur4 := m.ipset.max4 <= 24
  		heur6 := m.ipset.max6 <= 64
  		var verdicts map[[9]byte]bool // bucket verdicts, valid for this sub-matcher only
  		if heur4 || heur6 {
  			verdicts = make(map[[9]byte]bool, n)
  		}
  		for i, ipx := range addrs {
  			if hit[i] || !ipx.IsValid() {
  				continue
  			}
  			key := keys[i]
  			if (key[0] == 4 && heur4) || (key[0] == 6 && heur6) {
  				v, seen := verdicts[key]
  				if !seen {
  					v = m.matchAddr(ipx)
  					verdicts[key] = v
  				}
  				hit[i] = v
  			} else {
  				hit[i] = m.matchAddr(ipx)
  			}
  		}
  	}
  	matched = make([]net.IP, 0, n)
  	unmatched = make([]net.IP, 0, n)
  	for i, ip := range ips {
  		switch {
  		case !addrs[i].IsValid():
  			// illegal ip, dropped from both results
  		case hit[i]:
  			matched = append(matched, ip)
  		default:
  			unmatched = append(unmatched, ip)
  		}
  	}
  	return matched, unmatched
  }
  // ipBucketViews lazily caches IP groupings used for filtering. Each family
  // has two views:
  //   - buckets: map each /24 (IPv4) or /64 (IPv6) key to an ipBucket holding a
  //     representative address and every input IP of that bucket, in input order.
  //   - precise: map each parsed address back to its original net.IP.
  // A nil field means that view has not been built yet.
  // NOTE(review): precise4/precise6 are keyed by address, so duplicate IPs in
  // the input collapse to one entry (the last net.IP seen wins).
  type ipBucketViews struct {
  	buckets4, buckets6 map[[9]byte]*ipBucket
  	precise4, precise6 map[netip.Addr]net.IP
  }
  // ensureForMatcher builds the views m needs that do not exist yet. For a
  // family, the bucket view is needed when m's longest prefix is at most /24
  // (IPv4) or /64 (IPv6), and the precise view otherwise. Views already built
  // for an earlier sub-matcher are reused. Invalid IPs are skipped rather than
  // reported.
  // NOTE(review): if different sub-matchers need both views of one family,
  // the same IP is held in both; callers must not double-count it.
  func (v *ipBucketViews) ensureForMatcher(m *HeuristicGeoIPMatcher, ips []net.IP) {
  	needHeur4 := m.ipset.max4 <= 24 && v.buckets4 == nil
  	needHeur6 := m.ipset.max6 <= 64 && v.buckets6 == nil
  	needPrec4 := m.ipset.max4 > 24 && v.precise4 == nil
  	needPrec6 := m.ipset.max6 > 64 && v.precise6 == nil
  	if !needHeur4 && !needHeur6 && !needPrec4 && !needPrec6 {
  		return
  	}
  	if needHeur4 {
  		v.buckets4 = make(map[[9]byte]*ipBucket, len(ips))
  	}
  	if needHeur6 {
  		v.buckets6 = make(map[[9]byte]*ipBucket, len(ips))
  	}
  	if needPrec4 {
  		v.precise4 = make(map[netip.Addr]net.IP, len(ips))
  	}
  	if needPrec6 {
  		v.precise6 = make(map[netip.Addr]net.IP, len(ips))
  	}
  	for _, ip := range ips {
  		key, ok := prefixKeyFromIP(ip)
  		if !ok {
  			continue // illegal ip, ignore
  		}
  		switch key[0] {
  		case 4:
  			// needHeur4 and needPrec4 are mutually exclusive (max4 <= 24 vs > 24).
  			var ipx netip.Addr
  			if needHeur4 {
  				b, exists := v.buckets4[key]
  				if !exists {
  					// build bucket
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						continue // illegal ip, ignore
  					}
  					b = &ipBucket{
  						rep: ipx,
  						ips: make([]net.IP, 0, 4), // for dns answer
  					}
  					v.buckets4[key] = b
  				}
  				b.ips = append(b.ips, ip)
  			}
  			if needPrec4 {
  				if !ipx.IsValid() {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						continue // illegal ip, ignore
  					}
  				}
  				v.precise4[ipx] = ip
  			}
  		case 6:
  			var ipx netip.Addr
  			if needHeur6 {
  				b, exists := v.buckets6[key]
  				if !exists {
  					// build bucket
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						continue // illegal ip, ignore
  					}
  					b = &ipBucket{
  						rep: ipx,
  						ips: make([]net.IP, 0, 4), // for dns answer
  					}
  					v.buckets6[key] = b
  				}
  				b.ips = append(b.ips, ip)
  			}
  			if needPrec6 {
  				if !ipx.IsValid() {
  					ipx, ok = netipx.FromStdIP(ip)
  					if !ok {
  						continue // illegal ip, ignore
  					}
  				}
  				v.precise6[ipx] = ip
  			}
  		}
  	}
  }
  706. // ToggleReverse implements GeoIPMatcher.
  707. func (mm *HeuristicMultiGeoIPMatcher) ToggleReverse() {
  708. for _, m := range mm.matchers {
  709. m.ToggleReverse()
  710. }
  711. }
  712. // SetReverse implements GeoIPMatcher.
  713. func (mm *HeuristicMultiGeoIPMatcher) SetReverse(reverse bool) {
  714. for _, m := range mm.matchers {
  715. m.SetReverse(reverse)
  716. }
  717. }
  718. type GeoIPSetFactory struct {
  719. sync.Mutex
  720. shared map[string]*GeoIPSet // TODO: cleanup
  721. }
  722. var ipsetFactory = GeoIPSetFactory{shared: make(map[string]*GeoIPSet)}
  723. func (f *GeoIPSetFactory) GetOrCreate(key string, cidrGroups [][]*CIDR) (*GeoIPSet, error) {
  724. f.Lock()
  725. defer f.Unlock()
  726. if ipset := f.shared[key]; ipset != nil {
  727. return ipset, nil
  728. }
  729. ipset, err := f.Create(cidrGroups...)
  730. if err == nil {
  731. f.shared[key] = ipset
  732. }
  733. return ipset, err
  734. }
  // Create builds a GeoIPSet from the union of all CIDR entries in cidrGroups.
  // Entries whose IP bytes or prefix length are invalid are logged and skipped.
  // For each family, the longest prefix length present is recorded (max4/max6).
  // A family whose longest prefix is /0, including an empty family, is recorded
  // as 0xff. Errors from building either IPSet are wrapped and returned.
  func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
  	var b4, b6 netipx.IPSetBuilder
  	for _, group := range cidrGroups {
  		for _, entry := range group {
  			raw := entry.GetIp()
  			bits := int(entry.GetPrefix())
  			addr, ok := netip.AddrFromSlice(raw)
  			if !ok {
  				errors.LogError(context.Background(), "ignore invalid IP byte slice: ", raw)
  				continue
  			}
  			prefix := netip.PrefixFrom(addr, bits)
  			if !prefix.IsValid() {
  				errors.LogError(context.Background(), "ignore created invalid prefix from addr ", addr, " and length ", bits)
  				continue
  			}
  			switch {
  			case addr.Is4():
  				b4.AddPrefix(prefix)
  			case addr.Is6():
  				b6.AddPrefix(prefix)
  			}
  		}
  	}
  	ipv4, err := b4.IPSet()
  	if err != nil {
  		return nil, errors.New("failed to build IPv4 set").Base(err)
  	}
  	ipv6, err := b6.IPSet()
  	if err != nil {
  		return nil, errors.New("failed to build IPv6 set").Base(err)
  	}
  	// longest returns the largest prefix length in set, or 0xff when that is 0.
  	longest := func(set *netipx.IPSet) uint8 {
  		best := 0
  		for _, p := range set.Prefixes() {
  			if p.Bits() > best {
  				best = p.Bits()
  			}
  		}
  		if best == 0 {
  			return 0xff
  		}
  		return uint8(best)
  	}
  	return &GeoIPSet{
  		ipv4: ipv4,
  		ipv6: ipv6,
  		max4: longest(ipv4),
  		max6: longest(ipv6),
  	}, nil
  }
  // BuildOptimizedGeoIPMatcher compiles geoips into a single GeoIPMatcher.
  //
  // Entries without a country code (inline CIDR lists) each get their own
  // HeuristicGeoIPMatcher with their own reverse flag. Entries with a country
  // code are split by ReverseMatch. Within each group:
  //   - entries are sorted by country code and deduplicated (one entry per code);
  //   - the rest are merged into one shared GeoIPSet, cached in ipsetFactory
  //     under the comma-joined codes;
  //   - the set is wrapped in one matcher (reverse=false for the positive group,
  //     reverse=true for the negative group).
  // Because of this merge, the negative matcher matches IPs outside the union
  // of its countries. One resulting matcher is returned as-is; several are
  // combined in a HeuristicMultiGeoIPMatcher.
  //
  // It fails when geoips is empty, contains a nil entry, or a set cannot be built.
  func BuildOptimizedGeoIPMatcher(geoips ...*GeoIP) (GeoIPMatcher, error) {
  	total := len(geoips)
  	if total == 0 {
  		return nil, errors.New("no geoip configs provided")
  	}
  	var subs []*HeuristicGeoIPMatcher
  	pos := make([]*GeoIP, 0, total)
  	neg := make([]*GeoIP, 0, total/2)
  	for _, g := range geoips {
  		switch {
  		case g == nil:
  			return nil, errors.New("geoip entry is nil")
  		case g.CountryCode == "":
  			set, err := ipsetFactory.Create(g.Cidr)
  			if err != nil {
  				return nil, err
  			}
  			subs = append(subs, &HeuristicGeoIPMatcher{ipset: set, reverse: g.ReverseMatch})
  		case g.ReverseMatch:
  			neg = append(neg, g)
  		default:
  			pos = append(pos, g)
  		}
  	}
  	// mergedSet builds (or fetches from the cache) the union set of one group,
  	// keeping a single entry per country code. It returns nil for an empty group.
  	mergedSet := func(group []*GeoIP) (*GeoIPSet, error) {
  		if len(group) == 0 {
  			return nil, nil
  		}
  		sort.Slice(group, func(i, j int) bool {
  			gi, gj := group[i], group[j]
  			return gi.CountryCode < gj.CountryCode
  		})
  		var key strings.Builder
  		key.Grow(len(group) * 3) // xx,
  		cidrGroups := make([][]*CIDR, 0, len(group))
  		for i, g := range group {
  			if i > 0 && g.CountryCode == group[i-1].CountryCode {
  				continue // same code as the previous (sorted) entry
  			}
  			key.WriteString(g.CountryCode)
  			key.WriteString(",")
  			cidrGroups = append(cidrGroups, g.Cidr)
  		}
  		return ipsetFactory.GetOrCreate(key.String(), cidrGroups)
  	}
  	// Index 0 is the positive group, index 1 the reversed one.
  	for idx, group := range [...][]*GeoIP{pos, neg} {
  		set, err := mergedSet(group)
  		if err != nil {
  			return nil, err
  		}
  		if set != nil {
  			subs = append(subs, &HeuristicGeoIPMatcher{ipset: set, reverse: idx == 1})
  		}
  	}
  	if len(subs) == 0 {
  		return nil, errors.New("no valid geoip matcher")
  	}
  	if len(subs) == 1 {
  		return subs[0], nil
  	}
  	return &HeuristicMultiGeoIPMatcher{matchers: subs}, nil
  }