Browse Source

util/set: add SmallSet.SoleElement, fix bug, add more tests

This adds SmallSet.SoleElement, which I need in another repo for
efficiency. I added tests, but those tests failed because Add(1) +
Add(1) was promoting the first Add's sole element to a map of one
item. So fix that, and add more tests.

Updates tailscale/corp#29093

Change-Id: Iadd5ad08afe39721ee5449343095e389214d8389
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 9 months ago
parent
commit
ef49e75b10
2 changed files with 54 additions and 5 deletions
  1. 19 5
      util/set/smallset.go
  2. 35 0
      util/set/smallset_test.go

+ 19 - 5
util/set/smallset.go

@@ -50,6 +50,15 @@ func (s SmallSet[T]) Contains(e T) bool {
 	return e != zero && s.one == e
 }
 
+// SoleElement returns the single value in the set, if the set has exactly one
+// element.
+//
+// If the set is empty or has more than one element, ok will be false and e will
+// be the zero value of T.
+func (s SmallSet[T]) SoleElement() (e T, ok bool) {
+	return s.one, s.Len() == 1
+}
+
 // Add adds e to the set.
 //
 // When storing a SmallSet in a map as a value type, it is important to
@@ -61,10 +70,15 @@ func (s *SmallSet[T]) Add(e T) {
 		s.m.Add(e)
 		return
 	}
-	// Size zero to one non-zero element.
-	if s.one == zero && e != zero {
-		s.one = e
-		return
+	// Non-zero elements can go into s.one.
+	if e != zero {
+		if s.one == zero {
+			s.one = e // Len 0 to Len 1
+			return
+		}
+		if s.one == e {
+			return // dup
+		}
 	}
 	// Need to make a multi map, either
 	// because we now have two items, or
@@ -73,7 +87,7 @@ func (s *SmallSet[T]) Add(e T) {
 	if s.one != zero {
 		s.m.Add(s.one) // move single item to multi
 	}
-	s.m.Add(e) // add new item
+	s.m.Add(e) // add new item, possibly zero
 	s.one = zero
 }
 

+ 35 - 0
util/set/smallset_test.go

@@ -84,8 +84,43 @@ func TestSmallSet(t *testing.T) {
 							t.Errorf("contains(%v) mismatch after ops %s: normal=%v, small=%v", e, name(), normal.Contains(e), small.Contains(e))
 						}
 					}
+
+					if err := small.checkInvariants(); err != nil {
+						t.Errorf("checkInvariants failed after ops %s: %v", name(), err)
+					}
+
+					if !t.Failed() {
+						sole, ok := small.SoleElement()
+						if ok != (small.Len() == 1) {
+							t.Errorf("SoleElement ok mismatch after ops %s: SoleElement ok=%v, want=%v", name(), ok, !ok)
+						}
+						if ok && sole != smallEle[0] {
+							t.Errorf("SoleElement value mismatch after ops %s: SoleElement=%v, want=%v", name(), sole, smallEle[0])
+							t.Errorf("Internals: %+v", small)
+						}
+					}
+				}
+			}
+		}
+	}
+}
+
+func (s *SmallSet[T]) checkInvariants() error {
+	var zero T
+	if s.m != nil && s.one != zero {
+		return fmt.Errorf("both m and one are non-zero")
+	}
+	if s.m != nil {
+		switch len(s.m) {
+		case 0:
+			return fmt.Errorf("m is non-nil but empty")
+		case 1:
+			for k := range s.m {
+				if k != zero {
+					return fmt.Errorf("m contains exactly 1 non-zero element, %v", k)
 				}
 			}
 		}
 	}
+	return nil
 }