stride_table_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package art
  4. import (
  5. "bytes"
  6. "fmt"
  7. "math/rand"
  8. "net/netip"
  9. "runtime"
  10. "sort"
  11. "strings"
  12. "testing"
  13. "github.com/google/go-cmp/cmp"
  14. )
  15. func TestInversePrefix(t *testing.T) {
  16. t.Parallel()
  17. for i := range 256 {
  18. for len := 0; len < 9; len++ {
  19. addr := i & (0xFF << (8 - len))
  20. idx := prefixIndex(uint8(addr), len)
  21. addr2, len2 := inversePrefixIndex(idx)
  22. if addr2 != uint8(addr) || len2 != len {
  23. t.Errorf("inverse(index(%d/%d)) != %d/%d", addr, len, addr2, len2)
  24. }
  25. }
  26. }
  27. }
  28. func TestHostIndex(t *testing.T) {
  29. t.Parallel()
  30. for i := range 256 {
  31. got := hostIndex(uint8(i))
  32. want := prefixIndex(uint8(i), 8)
  33. if got != want {
  34. t.Errorf("hostIndex(%d) = %d, want %d", i, got, want)
  35. }
  36. }
  37. }
  38. func TestStrideTableInsert(t *testing.T) {
  39. t.Parallel()
  40. // Verify that strideTable's lookup results after a bunch of inserts exactly
  41. // match those of a naive implementation that just scans all prefixes on
  42. // every lookup. The naive implementation is very slow, but its behavior is
  43. // easy to verify by inspection.
  44. pfxs := shufflePrefixes(allPrefixes())[:100]
  45. slow := slowTable[int]{pfxs}
  46. fast := strideTable[int]{}
  47. if debugStrideInsert {
  48. t.Logf("slow table:\n%s", slow.String())
  49. }
  50. for _, pfx := range pfxs {
  51. fast.insert(pfx.addr, pfx.len, pfx.val)
  52. if debugStrideInsert {
  53. t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString())
  54. }
  55. }
  56. for i := range 256 {
  57. addr := uint8(i)
  58. slowVal, slowOK := slow.get(addr)
  59. fastVal, fastOK := fast.get(addr)
  60. if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
  61. t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
  62. }
  63. }
  64. }
  65. func TestStrideTableInsertShuffled(t *testing.T) {
  66. t.Parallel()
  67. // The order in which routes are inserted into a route table does not
  68. // influence the final shape of the table, as long as the same set of
  69. // prefixes is being inserted. This test verifies that strideTable behaves
  70. // this way.
  71. //
  72. // In addition to the basic shuffle test, we also check that this behavior
  73. // is maintained if all inserted routes have the same value pointer. This
  74. // shouldn't matter (the strideTable still needs to correctly account for
  75. // each inserted route, regardless of associated value), but during initial
  76. // development a subtle bug made the table corrupt itself in that setup, so
  77. // this test includes a regression test for that.
  78. routes := shufflePrefixes(allPrefixes())[:100]
  79. zero := 0
  80. rt := strideTable[int]{}
  81. // strideTable has a value interface, but internally has to keep
  82. // track of distinct routes even if they all have the same
  83. // value. rtZero uses the same value for all routes, and expects
  84. // correct behavior.
  85. rtZero := strideTable[int]{}
  86. for _, route := range routes {
  87. rt.insert(route.addr, route.len, route.val)
  88. rtZero.insert(route.addr, route.len, zero)
  89. }
  90. // Order of insertion should not affect the final shape of the stride table.
  91. routes2 := append([]slowEntry[int](nil), routes...) // dup so we can print both slices on fail
  92. for range 100 {
  93. rand.Shuffle(len(routes2), func(i, j int) { routes2[i], routes2[j] = routes2[j], routes2[i] })
  94. rt2 := strideTable[int]{}
  95. for _, route := range routes2 {
  96. rt2.insert(route.addr, route.len, route.val)
  97. }
  98. if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString()); diff != "" {
  99. t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
  100. }
  101. rtZero2 := strideTable[int]{}
  102. for _, route := range routes2 {
  103. rtZero2.insert(route.addr, route.len, zero)
  104. }
  105. if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
  106. t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
  107. }
  108. }
  109. }
  110. func TestStrideTableDelete(t *testing.T) {
  111. t.Parallel()
  112. // Compare route deletion to our reference slowTable.
  113. pfxs := shufflePrefixes(allPrefixes())[:100]
  114. slow := slowTable[int]{pfxs}
  115. fast := strideTable[int]{}
  116. if debugStrideDelete {
  117. t.Logf("slow table:\n%s", slow.String())
  118. }
  119. for _, pfx := range pfxs {
  120. fast.insert(pfx.addr, pfx.len, pfx.val)
  121. if debugStrideDelete {
  122. t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString())
  123. }
  124. }
  125. toDelete := pfxs[:50]
  126. for _, pfx := range toDelete {
  127. slow.delete(pfx.addr, pfx.len)
  128. fast.delete(pfx.addr, pfx.len)
  129. }
  130. // Sanity check that slowTable seems to have done the right thing.
  131. if cnt := len(slow.prefixes); cnt != 50 {
  132. t.Fatalf("slowTable has %d entries after deletes, want 50", cnt)
  133. }
  134. for i := range 256 {
  135. addr := uint8(i)
  136. slowVal, slowOK := slow.get(addr)
  137. fastVal, fastOK := fast.get(addr)
  138. if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
  139. t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
  140. }
  141. }
  142. }
  143. func TestStrideTableDeleteShuffle(t *testing.T) {
  144. t.Parallel()
  145. // Same as TestStrideTableInsertShuffle, the order in which prefixes are
  146. // deleted should not impact the final shape of the route table.
  147. routes := shufflePrefixes(allPrefixes())[:100]
  148. toDelete := routes[:50]
  149. zero := 0
  150. rt := strideTable[int]{}
  151. // strideTable has a value interface, but internally has to keep
  152. // track of distinct routes even if they all have the same
  153. // value. rtZero uses the same value for all routes, and expects
  154. // correct behavior.
  155. rtZero := strideTable[int]{}
  156. for _, route := range routes {
  157. rt.insert(route.addr, route.len, route.val)
  158. rtZero.insert(route.addr, route.len, zero)
  159. }
  160. for _, route := range toDelete {
  161. rt.delete(route.addr, route.len)
  162. rtZero.delete(route.addr, route.len)
  163. }
  164. // Order of deletion should not affect the final shape of the stride table.
  165. toDelete2 := append([]slowEntry[int](nil), toDelete...) // dup so we can print both slices on fail
  166. for range 100 {
  167. rand.Shuffle(len(toDelete2), func(i, j int) { toDelete2[i], toDelete2[j] = toDelete2[j], toDelete2[i] })
  168. rt2 := strideTable[int]{}
  169. for _, route := range routes {
  170. rt2.insert(route.addr, route.len, route.val)
  171. }
  172. for _, route := range toDelete2 {
  173. rt2.delete(route.addr, route.len)
  174. }
  175. if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString(), cmpDiffOpts...); diff != "" {
  176. t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
  177. }
  178. rtZero2 := strideTable[int]{}
  179. for _, route := range routes {
  180. rtZero2.insert(route.addr, route.len, zero)
  181. }
  182. for _, route := range toDelete2 {
  183. rtZero2.delete(route.addr, route.len)
  184. }
  185. if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
  186. t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
  187. }
  188. }
  189. }
  190. var strideRouteCount = []int{10, 50, 100, 200}
  191. // forCountAndOrdering runs the benchmark fn with different sets of routes.
  192. //
  193. // fn is called once for each combination of {num_routes, order}, where
  194. // num_routes is the values in strideRouteCount, and order is the order of the
  195. // routes in the list: random, largest prefix first (/0 to /8), and smallest
  196. // prefix first (/8 to /0).
  197. func forStrideCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slowEntry[int])) {
  198. routes := shufflePrefixes(allPrefixes())
  199. for _, nroutes := range strideRouteCount {
  200. b.Run(fmt.Sprint(nroutes), func(b *testing.B) {
  201. runAndRecord := func(b *testing.B) {
  202. b.ReportAllocs()
  203. var startMem, endMem runtime.MemStats
  204. runtime.ReadMemStats(&startMem)
  205. fn(b, routes)
  206. runtime.ReadMemStats(&endMem)
  207. ops := float64(b.N) * float64(len(routes))
  208. allocs := float64(endMem.Mallocs - startMem.Mallocs)
  209. bytes := float64(endMem.TotalAlloc - startMem.TotalAlloc)
  210. b.ReportMetric(roundFloat64(allocs/ops), "allocs/op")
  211. b.ReportMetric(roundFloat64(bytes/ops), "B/op")
  212. }
  213. routes := append([]slowEntry[int](nil), routes[:nroutes]...)
  214. b.Run("random_order", runAndRecord)
  215. sort.Slice(routes, func(i, j int) bool {
  216. if routes[i].len < routes[j].len {
  217. return true
  218. }
  219. return routes[i].addr < routes[j].addr
  220. })
  221. b.Run("largest_first", runAndRecord)
  222. sort.Slice(routes, func(i, j int) bool {
  223. if routes[j].len < routes[i].len {
  224. return true
  225. }
  226. return routes[j].addr < routes[i].addr
  227. })
  228. b.Run("smallest_first", runAndRecord)
  229. })
  230. }
  231. }
  232. func BenchmarkStrideTableInsertion(b *testing.B) {
  233. forStrideCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) {
  234. val := 0
  235. for range b.N {
  236. var rt strideTable[int]
  237. for _, route := range routes {
  238. rt.insert(route.addr, route.len, val)
  239. }
  240. }
  241. inserts := float64(b.N) * float64(len(routes))
  242. elapsed := float64(b.Elapsed().Nanoseconds())
  243. elapsedSec := b.Elapsed().Seconds()
  244. b.ReportMetric(elapsed/inserts, "ns/op")
  245. b.ReportMetric(inserts/elapsedSec, "routes/s")
  246. })
  247. }
// BenchmarkStrideTableDeletion measures the cost of deleting every route
// from a pre-populated strideTable, reporting ns/op and routes/s per
// deleted route rather than per benchmark iteration.
func BenchmarkStrideTableDeletion(b *testing.B) {
	forStrideCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) {
		val := 0
		var rt strideTable[int]
		for _, route := range routes {
			rt.insert(route.addr, route.len, val)
		}
		// Table construction above is setup, not the measured work.
		b.ResetTimer()
		for range b.N {
			// NOTE(review): rt2 := rt is a shallow struct copy; each
			// iteration only gets an independent table if strideTable
			// carries no shared pointer/slice state — confirm against
			// strideTable's definition.
			rt2 := rt
			for _, route := range routes {
				rt2.delete(route.addr, route.len)
			}
		}
		// Normalize timing to individual route deletions.
		deletes := float64(b.N) * float64(len(routes))
		elapsed := float64(b.Elapsed().Nanoseconds())
		elapsedSec := b.Elapsed().Seconds()
		b.ReportMetric(elapsed/deletes, "ns/op")
		b.ReportMetric(deletes/elapsedSec, "routes/s")
	})
}
  269. var writeSink int
  270. func BenchmarkStrideTableGet(b *testing.B) {
  271. // No need to forCountAndOrdering here, route lookup time is independent of
  272. // the route count.
  273. routes := shufflePrefixes(allPrefixes())[:100]
  274. var rt strideTable[int]
  275. for _, route := range routes {
  276. rt.insert(route.addr, route.len, route.val)
  277. }
  278. b.ResetTimer()
  279. for i := range b.N {
  280. writeSink, _ = rt.get(uint8(i))
  281. }
  282. gets := float64(b.N)
  283. elapsedSec := b.Elapsed().Seconds()
  284. b.ReportMetric(gets/elapsedSec, "routes/s")
  285. }
// slowTable is an 8-bit routing table implemented as a set of prefixes that are
// explicitly scanned in full for every route lookup. It is very slow, but also
// reasonably easy to verify by inspection, and so a good comparison target for
// strideTable.
type slowTable[T any] struct {
	prefixes []slowEntry[T] // unordered; get scans every entry linearly
}
// slowEntry is a single route in a slowTable: an addr/len prefix plus its
// associated value.
type slowEntry[T any] struct {
	addr uint8 // prefix address; NOTE(review): callers appear to zero the host bits — confirm
	len  int   // prefix length in bits, 0 through 8
	val  T
}
  298. func (t *slowTable[T]) String() string {
  299. pfxs := append([]slowEntry[T](nil), t.prefixes...)
  300. sort.Slice(pfxs, func(i, j int) bool {
  301. if pfxs[i].len != pfxs[j].len {
  302. return pfxs[i].len < pfxs[j].len
  303. }
  304. return pfxs[i].addr < pfxs[j].addr
  305. })
  306. var ret bytes.Buffer
  307. for _, pfx := range pfxs {
  308. fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), pfx.val)
  309. }
  310. return ret.String()
  311. }
  312. func (t *slowTable[T]) delete(addr uint8, prefixLen int) {
  313. pfx := make([]slowEntry[T], 0, len(t.prefixes))
  314. for _, e := range t.prefixes {
  315. if e.addr == addr && e.len == prefixLen {
  316. continue
  317. }
  318. pfx = append(pfx, e)
  319. }
  320. t.prefixes = pfx
  321. }
  322. func (t *slowTable[T]) get(addr uint8) (ret T, ok bool) {
  323. var curLen = -1
  324. for _, e := range t.prefixes {
  325. if addr&pfxMask(e.len) == e.addr && e.len >= curLen {
  326. ret = e.val
  327. curLen = e.len
  328. }
  329. }
  330. return ret, curLen != -1
  331. }
  332. func pfxMask(pfxLen int) uint8 {
  333. return 0xFF << (8 - pfxLen)
  334. }
  335. func allPrefixes() []slowEntry[int] {
  336. ret := make([]slowEntry[int], 0, lastHostIndex)
  337. for i := 1; i < lastHostIndex+1; i++ {
  338. a, ln := inversePrefixIndex(i)
  339. ret = append(ret, slowEntry[int]{a, ln, i})
  340. }
  341. return ret
  342. }
  343. func shufflePrefixes(pfxs []slowEntry[int]) []slowEntry[int] {
  344. rand.Shuffle(len(pfxs), func(i, j int) { pfxs[i], pfxs[j] = pfxs[j], pfxs[i] })
  345. return pfxs
  346. }
  347. func formatSlowEntriesShort[T any](ents []slowEntry[T]) string {
  348. var ret []string
  349. for _, ent := range ents {
  350. ret = append(ret, fmt.Sprintf("%d/%d", ent.addr, ent.len))
  351. }
  352. return "[" + strings.Join(ret, " ") + "]"
  353. }
// cmpDiffOpts configures cmp.Diff for table comparisons in this file;
// netip.Prefix values are compared with == rather than by reflecting over
// their unexported fields.
var cmpDiffOpts = []cmp.Option{
	cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }),
}
  357. func getsEqual[T comparable](a T, aOK bool, b T, bOK bool) bool {
  358. if !aOK && !bOK {
  359. return true
  360. }
  361. if aOK != bOK {
  362. return false
  363. }
  364. return a == b
  365. }