flowtrack_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package flowtrack
  4. import (
  5. "encoding/json"
  6. "net/netip"
  7. "testing"
  8. "tailscale.com/tstest"
  9. "tailscale.com/types/ipproto"
  10. )
  11. func TestCache(t *testing.T) {
  12. c := &Cache[int]{MaxEntries: 2}
  13. k1 := MakeTuple(0, netip.MustParseAddrPort("1.1.1.1:1"), netip.MustParseAddrPort("1.1.1.1:1"))
  14. k2 := MakeTuple(0, netip.MustParseAddrPort("1.1.1.1:1"), netip.MustParseAddrPort("2.2.2.2:2"))
  15. k3 := MakeTuple(0, netip.MustParseAddrPort("1.1.1.1:1"), netip.MustParseAddrPort("3.3.3.3:3"))
  16. k4 := MakeTuple(0, netip.MustParseAddrPort("1.1.1.1:1"), netip.MustParseAddrPort("4.4.4.4:4"))
  17. wantLen := func(want int) {
  18. t.Helper()
  19. if got := c.Len(); got != want {
  20. t.Fatalf("Len = %d; want %d", got, want)
  21. }
  22. }
  23. wantVal := func(key Tuple, want int) {
  24. t.Helper()
  25. got, ok := c.Get(key)
  26. if !ok {
  27. t.Fatalf("Get(%q) failed; want value %v", key, want)
  28. }
  29. if *got != want {
  30. t.Fatalf("Get(%q) = %v; want %v", key, got, want)
  31. }
  32. }
  33. wantMissing := func(key Tuple) {
  34. t.Helper()
  35. if got, ok := c.Get(key); ok {
  36. t.Fatalf("Get(%q) = %v; want absent from cache", key, got)
  37. }
  38. }
  39. wantLen(0)
  40. c.RemoveOldest() // shouldn't panic
  41. c.Remove(k4) // shouldn't panic
  42. c.Add(k1, 1)
  43. wantLen(1)
  44. c.Add(k2, 2)
  45. wantLen(2)
  46. c.Add(k3, 3)
  47. wantLen(2) // hit the max
  48. wantMissing(k1)
  49. c.Remove(k1)
  50. wantLen(2) // no change; k1 should've been the deleted one per LRU
  51. wantVal(k3, 3)
  52. wantVal(k2, 2)
  53. c.Remove(k2)
  54. wantLen(1)
  55. wantMissing(k2)
  56. c.Add(k3, 30)
  57. wantVal(k3, 30)
  58. wantLen(1)
  59. err := tstest.MinAllocsPerRun(t, 0, func() {
  60. got, ok := c.Get(k3)
  61. if !ok {
  62. t.Fatal("missing k3")
  63. }
  64. if *got != 30 {
  65. t.Fatalf("got = %d; want 30", got)
  66. }
  67. })
  68. if err != nil {
  69. t.Error(err)
  70. }
  71. }
  72. func BenchmarkMapKeys(b *testing.B) {
  73. b.Run("typed", func(b *testing.B) {
  74. c := &Cache[struct{}]{MaxEntries: 1000}
  75. var t Tuple
  76. for proto := range 20 {
  77. t = Tuple{proto: ipproto.Proto(proto), src: netip.MustParseAddr("1.1.1.1").As16(), srcPort: 1, dst: netip.MustParseAddr("1.1.1.1").As16(), dstPort: 1}
  78. c.Add(t, struct{}{})
  79. }
  80. for i := 0; i < b.N; i++ {
  81. _, ok := c.Get(t)
  82. if !ok {
  83. b.Fatal("missing key")
  84. }
  85. }
  86. })
  87. }
  88. func TestStringJSON(t *testing.T) {
  89. v := MakeTuple(123,
  90. netip.MustParseAddrPort("1.2.3.4:5"),
  91. netip.MustParseAddrPort("6.7.8.9:10"))
  92. if got, want := v.String(), "(IPProto-123 1.2.3.4:5 => 6.7.8.9:10)"; got != want {
  93. t.Errorf("String = %q; want %q", got, want)
  94. }
  95. got, err := json.Marshal(v)
  96. if err != nil {
  97. t.Fatal(err)
  98. }
  99. const want = `{"proto":123,"src":"1.2.3.4:5","dst":"6.7.8.9:10"}`
  100. if string(got) != want {
  101. t.Errorf("Marshal = %q; want %q", got, want)
  102. }
  103. var back Tuple
  104. if err := json.Unmarshal(got, &back); err != nil {
  105. t.Fatal(err)
  106. }
  107. if back != v {
  108. t.Errorf("back = %v; want %v", back, v)
  109. }
  110. }