Преглед изворни кода

types/views: fix SliceEqualAnyOrderFunc short optimization

This was flagged by @tkhattra on the merge commit; thanks!

Updates tailscale/corp#25479

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: Ia8045640f02bd4dcc0fe7433249fd72ac6b9cf52
Andrew Dunham пре 1 година
родитељ
комит
eb299302ba
2 измењених фајлова са 64 додато и 6 уклоњено
  1. 32 6
      types/views/views.go
  2. 32 0
      types/views/views_test.go

+ 32 - 6
types/views/views.go

@@ -386,14 +386,32 @@ func SliceEqualAnyOrderFunc[T any, V comparable](a, b Slice[T], cmp func(T) V) b
 	// do the quadratic thing. We can also only check the items between
 	// diffStart and the end.
 	nRemain := a.Len() - diffStart
-	if nRemain <= 5 {
-		maxLen := a.Len() // same as b.Len()
-		for i := diffStart; i < maxLen; i++ {
-			av := cmp(a.At(i))
+	const shortOptLen = 5
+	if nRemain <= shortOptLen {
+		// These track which elements in a and b have been matched, so
+		// that we don't treat arrays with differing number of
+		// duplicate elements as equal (e.g. [1, 1, 2] and [1, 2, 2]).
+		var aMatched, bMatched [shortOptLen]bool
+
+		// Compare each element in a to each element in b
+		for i := range nRemain {
+			av := cmp(a.At(i + diffStart))
 			found := false
-			for j := diffStart; j < maxLen; j++ {
-				bv := cmp(b.At(j))
+			for j := range nRemain {
+				// Skip elements in b that have already been
+				// used to match an item in a.
+				if bMatched[j] {
+					continue
+				}
+
+				bv := cmp(b.At(j + diffStart))
 				if av == bv {
+					// Mark these elements as already
+					// matched, so that a future loop
+					// iteration (of a duplicate element)
+					// doesn't match it again.
+					aMatched[i] = true
+					bMatched[j] = true
 					found = true
 					break
 				}
@@ -402,6 +420,14 @@ func SliceEqualAnyOrderFunc[T any, V comparable](a, b Slice[T], cmp func(T) V) b
 				return false
 			}
 		}
+
+		// Verify all elements were matched exactly once.
+		for i := range nRemain {
+			if !aMatched[i] || !bMatched[i] {
+				return false
+			}
+		}
+
 		return true
 	}
 

+ 32 - 0
types/views/views_test.go

@@ -197,6 +197,38 @@ func TestSliceEqualAnyOrderFunc(t *testing.T) {
 	// Long difference; past the quadratic limit
 	longDiff := ncFrom("b", "a", "c", "d", "e", "f", "g", "h", "i", "k") // differs at end
 	c.Check(SliceEqualAnyOrderFunc(longSlice, longDiff, cmp), qt.Equals, false)
+
+	// The short slice optimization had a bug where it wouldn't handle
+	// duplicate elements; test various cases here driven by code coverage.
+	shortTestCases := []struct {
+		name   string
+		s1, s2 Slice[nc]
+		want   bool
+	}{
+		{
+			name: "duplicates_same_length",
+			s1:   ncFrom("a", "a", "b"),
+			s2:   ncFrom("a", "b", "b"),
+			want: false,
+		},
+		{
+			name: "duplicates_different_matched",
+			s1:   ncFrom("x", "y", "a", "a", "b"),
+			s2:   ncFrom("x", "y", "b", "a", "a"),
+			want: true,
+		},
+		{
+			name: "item_in_a_not_b",
+			s1:   ncFrom("x", "y", "a", "b", "c"),
+			s2:   ncFrom("x", "y", "b", "c", "q"),
+			want: false,
+		},
+	}
+	for _, tc := range shortTestCases {
+		t.Run("short_"+tc.name, func(t *testing.T) {
+			c.Check(SliceEqualAnyOrderFunc(tc.s1, tc.s2, cmp), qt.Equals, tc.want)
+		})
+	}
 }
 
 func TestSliceEqual(t *testing.T) {