smallset_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package set
  4. import (
  5. "fmt"
  6. "iter"
  7. "maps"
  8. "reflect"
  9. "slices"
  10. "testing"
  11. )
  12. func TestSmallSet(t *testing.T) {
  13. t.Parallel()
  14. wantSize := reflect.TypeFor[int64]().Size() + reflect.TypeFor[map[int]struct{}]().Size()
  15. if wantSize > 16 {
  16. t.Errorf("wantSize should be no more than 16") // it might be smaller on 32-bit systems
  17. }
  18. if size := reflect.TypeFor[SmallSet[int64]]().Size(); size != wantSize {
  19. t.Errorf("SmallSet[int64] size is %d, want %v", size, wantSize)
  20. }
  21. type op struct {
  22. add bool
  23. v int
  24. }
  25. ops := iter.Seq[op](func(yield func(op) bool) {
  26. for _, add := range []bool{false, true} {
  27. for v := range 4 {
  28. if !yield(op{add: add, v: v}) {
  29. return
  30. }
  31. }
  32. }
  33. })
  34. type setLike interface {
  35. Add(int)
  36. Delete(int)
  37. }
  38. apply := func(s setLike, o op) {
  39. if o.add {
  40. s.Add(o.v)
  41. } else {
  42. s.Delete(o.v)
  43. }
  44. }
  45. // For all combinations of 4 operations,
  46. // apply them to both a regular map and SmallSet
  47. // and make sure all the invariants hold.
  48. for op1 := range ops {
  49. for op2 := range ops {
  50. for op3 := range ops {
  51. for op4 := range ops {
  52. normal := Set[int]{}
  53. small := &SmallSet[int]{}
  54. for _, op := range []op{op1, op2, op3, op4} {
  55. apply(normal, op)
  56. apply(small, op)
  57. }
  58. name := func() string {
  59. return fmt.Sprintf("op1=%v, op2=%v, op3=%v, op4=%v", op1, op2, op3, op4)
  60. }
  61. if normal.Len() != small.Len() {
  62. t.Errorf("len mismatch after ops %s: normal=%d, small=%d", name(), normal.Len(), small.Len())
  63. }
  64. if got := small.Clone().Len(); normal.Len() != got {
  65. t.Errorf("len mismatch after ops %s: normal=%d, clone=%d", name(), normal.Len(), got)
  66. }
  67. normalEle := slices.Sorted(maps.Keys(normal))
  68. smallEle := slices.Sorted(small.Values())
  69. if !slices.Equal(normalEle, smallEle) {
  70. t.Errorf("elements mismatch after ops %s: normal=%v, small=%v", name(), normalEle, smallEle)
  71. }
  72. for e := range 5 {
  73. if normal.Contains(e) != small.Contains(e) {
  74. t.Errorf("contains(%v) mismatch after ops %s: normal=%v, small=%v", e, name(), normal.Contains(e), small.Contains(e))
  75. }
  76. }
  77. if err := small.checkInvariants(); err != nil {
  78. t.Errorf("checkInvariants failed after ops %s: %v", name(), err)
  79. }
  80. if !t.Failed() {
  81. sole, ok := small.SoleElement()
  82. if ok != (small.Len() == 1) {
  83. t.Errorf("SoleElement ok mismatch after ops %s: SoleElement ok=%v, want=%v", name(), ok, !ok)
  84. }
  85. if ok && sole != smallEle[0] {
  86. t.Errorf("SoleElement value mismatch after ops %s: SoleElement=%v, want=%v", name(), sole, smallEle[0])
  87. t.Errorf("Internals: %+v", small)
  88. }
  89. }
  90. }
  91. }
  92. }
  93. }
  94. }
  95. func (s *SmallSet[T]) checkInvariants() error {
  96. var zero T
  97. if s.m != nil && s.one != zero {
  98. return fmt.Errorf("both m and one are non-zero")
  99. }
  100. if s.m != nil {
  101. switch len(s.m) {
  102. case 0:
  103. return fmt.Errorf("m is non-nil but empty")
  104. case 1:
  105. for k := range s.m {
  106. if k != zero {
  107. return fmt.Errorf("m contains exactly 1 non-zero element, %v", k)
  108. }
  109. }
  110. }
  111. }
  112. return nil
  113. }