ipset.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Package ipset provides code for creating efficient IP-in-set lookup functions
  4. // with different implementations depending on the set.
  5. package ipset
  6. import (
  7. "net/netip"
  8. "github.com/gaissmai/bart"
  9. "tailscale.com/types/views"
  10. "tailscale.com/util/set"
  11. )
  12. // FalseContainsIPFunc is shorthand for NewContainsIPFunc(views.Slice[netip.Prefix]{}).
  13. func FalseContainsIPFunc() func(ip netip.Addr) bool {
  14. return emptySet
  15. }
  16. func emptySet(ip netip.Addr) bool { return false }
  17. func bartLookup(t *bart.Table[struct{}]) func(netip.Addr) bool {
  18. return func(ip netip.Addr) bool {
  19. _, ok := t.Lookup(ip)
  20. return ok
  21. }
  22. }
  23. func prefixContainsLoop(addrs []netip.Prefix) func(netip.Addr) bool {
  24. return func(ip netip.Addr) bool {
  25. for _, p := range addrs {
  26. if p.Contains(ip) {
  27. return true
  28. }
  29. }
  30. return false
  31. }
  32. }
  33. func oneIP(ip1 netip.Addr) func(netip.Addr) bool {
  34. return func(ip netip.Addr) bool { return ip == ip1 }
  35. }
  36. func twoIP(ip1, ip2 netip.Addr) func(netip.Addr) bool {
  37. return func(ip netip.Addr) bool { return ip == ip1 || ip == ip2 }
  38. }
  39. func ipInMap(m set.Set[netip.Addr]) func(netip.Addr) bool {
  40. return func(ip netip.Addr) bool {
  41. _, ok := m[ip]
  42. return ok
  43. }
  44. }
  45. // pathForTest is a test hook for NewContainsIPFunc, to test that it took the
  46. // right construction path.
  47. var pathForTest = func(string) {}
  48. // NewContainsIPFunc returns a func that reports whether ip is in addrs.
  49. //
  50. // The returned func is optimized for the length of contents of addrs.
  51. func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool {
  52. // Specialize the three common cases: no address, just IPv4
  53. // (or just IPv6), and both IPv4 and IPv6.
  54. if addrs.Len() == 0 {
  55. pathForTest("empty")
  56. return emptySet
  57. }
  58. // If any addr is a prefix with more than a single IP, then do either a
  59. // linear scan or a bart table, depending on the number of addrs.
  60. if addrs.ContainsFunc(func(p netip.Prefix) bool { return !p.IsSingleIP() }) {
  61. if addrs.Len() == 1 {
  62. pathForTest("one-prefix")
  63. return addrs.At(0).Contains
  64. }
  65. if addrs.Len() <= 6 {
  66. // Small enough to do a linear search.
  67. pathForTest("linear-contains")
  68. return prefixContainsLoop(addrs.AsSlice())
  69. }
  70. pathForTest("bart")
  71. // Built a bart table.
  72. t := &bart.Table[struct{}]{}
  73. for _, p := range addrs.All() {
  74. t.Insert(p, struct{}{})
  75. }
  76. return bartLookup(t)
  77. }
  78. // Fast paths for 1 and 2 IPs:
  79. if addrs.Len() == 1 {
  80. pathForTest("one-ip")
  81. return oneIP(addrs.At(0).Addr())
  82. }
  83. if addrs.Len() == 2 {
  84. pathForTest("two-ip")
  85. return twoIP(addrs.At(0).Addr(), addrs.At(1).Addr())
  86. }
  87. // General case:
  88. pathForTest("ip-map")
  89. m := set.Set[netip.Addr]{}
  90. for _, p := range addrs.All() {
  91. m.Add(p.Addr())
  92. }
  93. return ipInMap(m)
  94. }