set_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package set
  4. import (
  5. "encoding/json"
  6. "slices"
  7. "testing"
  8. )
  9. func TestSet(t *testing.T) {
  10. s := Set[int]{}
  11. s.Add(1)
  12. s.Add(2)
  13. if !s.Contains(1) {
  14. t.Error("missing 1")
  15. }
  16. if !s.Contains(2) {
  17. t.Error("missing 2")
  18. }
  19. if s.Contains(3) {
  20. t.Error("shouldn't have 3")
  21. }
  22. if s.Len() != 2 {
  23. t.Errorf("wrong len %d; want 2", s.Len())
  24. }
  25. more := []int{3, 4}
  26. s.AddSlice(more)
  27. if !s.Contains(3) {
  28. t.Error("missing 3")
  29. }
  30. if !s.Contains(4) {
  31. t.Error("missing 4")
  32. }
  33. if s.Contains(5) {
  34. t.Error("shouldn't have 5")
  35. }
  36. if s.Len() != 4 {
  37. t.Errorf("wrong len %d; want 4", s.Len())
  38. }
  39. es := s.Slice()
  40. if len(es) != 4 {
  41. t.Errorf("slice has wrong len %d; want 4", len(es))
  42. }
  43. for _, e := range []int{1, 2, 3, 4} {
  44. if !slices.Contains(es, e) {
  45. t.Errorf("slice missing %d (%#v)", e, es)
  46. }
  47. }
  48. }
  49. func TestSetOf(t *testing.T) {
  50. s := Of(1, 2, 3, 4, 4, 1)
  51. if s.Len() != 4 {
  52. t.Errorf("wrong len %d; want 4", s.Len())
  53. }
  54. for _, n := range []int{1, 2, 3, 4} {
  55. if !s.Contains(n) {
  56. t.Errorf("should contain %d", n)
  57. }
  58. }
  59. }
  60. func TestEqual(t *testing.T) {
  61. type test struct {
  62. name string
  63. a Set[int]
  64. b Set[int]
  65. expected bool
  66. }
  67. tests := []test{
  68. {
  69. "equal",
  70. Of(1, 2, 3, 4),
  71. Of(1, 2, 3, 4),
  72. true,
  73. },
  74. {
  75. "not equal",
  76. Of(1, 2, 3, 4),
  77. Of(1, 2, 3, 5),
  78. false,
  79. },
  80. {
  81. "different lengths",
  82. Of(1, 2, 3, 4, 5),
  83. Of(1, 2, 3, 5),
  84. false,
  85. },
  86. }
  87. for _, tt := range tests {
  88. if tt.a.Equal(tt.b) != tt.expected {
  89. t.Errorf("%s: failed", tt.name)
  90. }
  91. }
  92. }
  93. func TestClone(t *testing.T) {
  94. s := Of(1, 2, 3, 4, 4, 1)
  95. if s.Len() != 4 {
  96. t.Errorf("wrong len %d; want 4", s.Len())
  97. }
  98. s2 := s.Clone()
  99. if !s.Equal(s2) {
  100. t.Error("clone not equal to original")
  101. }
  102. s.Add(100)
  103. if s.Equal(s2) {
  104. t.Error("clone is not distinct from original")
  105. }
  106. }
  107. func TestSetJSONRoundTrip(t *testing.T) {
  108. tests := []struct {
  109. desc string
  110. strings Set[string]
  111. ints Set[int]
  112. }{
  113. {"empty", make(Set[string]), make(Set[int])},
  114. {"nil", nil, nil},
  115. {"one-item", Of("one"), Of(1)},
  116. {"multiple-items", Of("one", "two", "three"), Of(1, 2, 3)},
  117. }
  118. for _, tt := range tests {
  119. t.Run(tt.desc, func(t *testing.T) {
  120. t.Run("strings", func(t *testing.T) {
  121. buf, err := json.Marshal(tt.strings)
  122. if err != nil {
  123. t.Fatalf("json.Marshal: %v", err)
  124. }
  125. t.Logf("marshaled: %s", buf)
  126. var s Set[string]
  127. if err := json.Unmarshal(buf, &s); err != nil {
  128. t.Fatalf("json.Unmarshal: %v", err)
  129. }
  130. if !s.Equal(tt.strings) {
  131. t.Errorf("set changed after JSON marshal/unmarshal, before: %v, after: %v", tt.strings, s)
  132. }
  133. })
  134. t.Run("ints", func(t *testing.T) {
  135. buf, err := json.Marshal(tt.ints)
  136. if err != nil {
  137. t.Fatalf("json.Marshal: %v", err)
  138. }
  139. t.Logf("marshaled: %s", buf)
  140. var s Set[int]
  141. if err := json.Unmarshal(buf, &s); err != nil {
  142. t.Fatalf("json.Unmarshal: %v", err)
  143. }
  144. if !s.Equal(tt.ints) {
  145. t.Errorf("set changed after JSON marshal/unmarshal, before: %v, after: %v", tt.ints, s)
  146. }
  147. })
  148. })
  149. }
  150. }
  151. func TestMake(t *testing.T) {
  152. var s Set[int]
  153. s.Make()
  154. s.Add(1)
  155. if !s.Contains(1) {
  156. t.Error("missing 1")
  157. }
  158. }