Browse Source

util/slicesx: add package for generic slice functions, use

Now that we're using rand.Shuffle in a few locations, create a generic
shuffle function and use it instead. While we're at it, move the
interleaveSlices function to the same package for use.

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: I0b00920e5b3eea846b6cedc30bd34d978a049fd3
Andrew Dunham 3 years ago
parent
commit
73fa7dd7af

+ 2 - 2
cmd/derper/bootstrap_dns.go

@@ -8,13 +8,13 @@ import (
 	"encoding/json"
 	"expvar"
 	"log"
-	"math/rand"
 	"net"
 	"net/http"
 	"strings"
 	"time"
 
 	"tailscale.com/syncs"
+	"tailscale.com/util/slicesx"
 )
 
 const refreshTimeout = time.Minute
@@ -57,7 +57,7 @@ func refreshBootstrapDNS() {
 	// to IPv6
 	for k := range dnsEntries {
 		ips := dnsEntries[k]
-		rand.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] })
+		slicesx.Shuffle(ips)
 		dnsEntries[k] = ips
 	}
 	j, err := json.MarshalIndent(dnsEntries, "", "\t")

+ 1 - 0
cmd/derper/depaware.txt

@@ -87,6 +87,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/util/multierr                                  from tailscale.com/health
         tailscale.com/util/set                                       from tailscale.com/health
         tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
+        tailscale.com/util/slicesx                                   from tailscale.com/cmd/derper+
         tailscale.com/util/vizerror                                  from tailscale.com/tsweb
    W 💣 tailscale.com/util/winutil                                   from tailscale.com/hostinfo+
         tailscale.com/version                                        from tailscale.com/derp+

+ 1 - 0
cmd/tailscale/depaware.txt

@@ -122,6 +122,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/quarantine                                from tailscale.com/cmd/tailscale/cli
         tailscale.com/util/set                                       from tailscale.com/health+
         tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
+        tailscale.com/util/slicesx                                   from tailscale.com/net/dnscache+
      💣 tailscale.com/util/winutil                                   from tailscale.com/hostinfo+
         tailscale.com/version                                        from tailscale.com/cmd/tailscale/cli+
         tailscale.com/version/distro                                 from tailscale.com/cmd/tailscale/cli+

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -305,6 +305,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/util/racebuild                                 from tailscale.com/logpolicy
         tailscale.com/util/set                                       from tailscale.com/health+
         tailscale.com/util/singleflight                              from tailscale.com/control/controlclient+
+        tailscale.com/util/slicesx                                   from tailscale.com/net/dnscache+
         tailscale.com/util/systemd                                   from tailscale.com/control/controlclient+
         tailscale.com/util/uniq                                      from tailscale.com/wgengine/magicsock+
         tailscale.com/util/vizerror                                  from tailscale.com/tsweb

+ 2 - 16
net/dnscache/dnscache.go

@@ -24,6 +24,7 @@ import (
 	"tailscale.com/types/logger"
 	"tailscale.com/util/cloudenv"
 	"tailscale.com/util/singleflight"
+	"tailscale.com/util/slicesx"
 )
 
 var zaddr netip.Addr
@@ -577,7 +578,7 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netip.Addr) (net.Conn, e
 			iv4 = append(iv4, ip)
 		}
 	}
-	ips = interleaveSlices(iv6, iv4)
+	ips = slicesx.Interleave(iv6, iv4)
 
 	go func() {
 		for i, ip := range ips {
@@ -636,21 +637,6 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netip.Addr) (net.Conn, e
 	}
 }
 
-// interleaveSlices combines two slices of the form [a, b, c] and [x, y, z]
-// into a slice with elements interleaved; i.e. [a, x, b, y, c, z].
-func interleaveSlices[T any](a, b []T) []T {
-	var (
-		i   int
-		ret = make([]T, 0, len(a)+len(b))
-	)
-	for i = 0; i < len(a) && i < len(b); i++ {
-		ret = append(ret, a[i], b[i])
-	}
-	ret = append(ret, a[i:]...)
-	ret = append(ret, b[i:]...)
-	return ret
-}
-
 func v4addrs(aa []netip.Addr) (ret []netip.Addr) {
 	for _, a := range aa {
 		a = a.Unmap()

+ 0 - 24
net/dnscache/dnscache_test.go

@@ -141,30 +141,6 @@ func TestResolverAllHostStaticResult(t *testing.T) {
 	}
 }
 
-func TestInterleaveSlices(t *testing.T) {
-	testCases := []struct {
-		name string
-		a, b []int
-		want []int
-	}{
-		{name: "equal", a: []int{1, 3, 5}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 5, 6}},
-		{name: "short_b", a: []int{1, 3, 5}, b: []int{2, 4}, want: []int{1, 2, 3, 4, 5}},
-		{name: "short_a", a: []int{1, 3}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 6}},
-		{name: "len_1", a: []int{1}, b: []int{2, 4, 6}, want: []int{1, 2, 4, 6}},
-		{name: "nil_a", a: nil, b: []int{2, 4, 6}, want: []int{2, 4, 6}},
-		{name: "nil_all", a: nil, b: nil, want: []int{}},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			merged := interleaveSlices(tc.a, tc.b)
-			if !reflect.DeepEqual(merged, tc.want) {
-				t.Errorf("got %v; want %v", merged, tc.want)
-			}
-		})
-	}
-}
-
 func TestShouldTryBootstrap(t *testing.T) {
 	oldDebug := debug
 	t.Cleanup(func() { debug = oldDebug })

+ 4 - 4
net/dnsfallback/dnsfallback.go

@@ -14,7 +14,6 @@ import (
 	"errors"
 	"fmt"
 	"log"
-	"math/rand"
 	"net"
 	"net/http"
 	"net/netip"
@@ -31,6 +30,7 @@ import (
 	"tailscale.com/syncs"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/slicesx"
 )
 
 func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
@@ -56,8 +56,8 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
 			}
 		}
 	}
-	rand.Shuffle(len(cands4), func(i, j int) { cands4[i], cands4[j] = cands4[j], cands4[i] })
-	rand.Shuffle(len(cands6), func(i, j int) { cands6[i], cands6[j] = cands6[j], cands6[i] })
+	slicesx.Shuffle(cands4)
+	slicesx.Shuffle(cands6)
 
 	const maxCands = 6
 	var cands []nameIP // up to maxCands alternating v4/v6 as long as we have both
@@ -87,7 +87,7 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
 			continue
 		}
 		if ips := dm[host]; len(ips) > 0 {
-			rand.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] })
+			slicesx.Shuffle(ips)
 			logf("bootstrapDNS(%q, %q) for %q = %v", cand.dnsName, cand.ip, host, ips)
 			return ips, nil
 		}

+ 44 - 0
util/slicesx/slicesx.go

@@ -0,0 +1,44 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package slicesx contains some helpful generic slice functions.
+package slicesx
+
+import "math/rand"
+
+// Interleave combines two slices of the form [a, b, c] and [x, y, z] into a
+// slice with elements interleaved; i.e. [a, x, b, y, c, z].
+func Interleave[S ~[]T, T any](a, b S) S {
+	// Avoid allocating an empty slice.
+	if a == nil && b == nil {
+		return nil
+	}
+
+	var (
+		i   int
+		ret = make([]T, 0, len(a)+len(b))
+	)
+	for i = 0; i < len(a) && i < len(b); i++ {
+		ret = append(ret, a[i], b[i])
+	}
+	ret = append(ret, a[i:]...)
+	ret = append(ret, b[i:]...)
+	return ret
+}
+
+// Shuffle randomly shuffles a slice in-place, similar to rand.Shuffle.
+func Shuffle[S ~[]T, T any](s S) {
+	// TODO(andrew): use a pooled Rand?
+
+	// This is the same Fisher-Yates shuffle implementation as rand.Shuffle
+	n := len(s)
+	i := n - 1
+	for ; i > 1<<31-1-1; i-- {
+		j := int(rand.Int63n(int64(i + 1)))
+		s[i], s[j] = s[j], s[i]
+	}
+	for ; i > 0; i-- {
+		j := int(rand.Int31n(int32(i + 1)))
+		s[i], s[j] = s[j], s[i]
+	}
+}

+ 66 - 0
util/slicesx/slicesx_test.go

@@ -0,0 +1,66 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package slicesx
+
+import (
+	"reflect"
+	"testing"
+
+	"golang.org/x/exp/slices"
+)
+
+func TestInterleave(t *testing.T) {
+	testCases := []struct {
+		name string
+		a, b []int
+		want []int
+	}{
+		{name: "equal", a: []int{1, 3, 5}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 5, 6}},
+		{name: "short_b", a: []int{1, 3, 5}, b: []int{2, 4}, want: []int{1, 2, 3, 4, 5}},
+		{name: "short_a", a: []int{1, 3}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 6}},
+		{name: "len_1", a: []int{1}, b: []int{2, 4, 6}, want: []int{1, 2, 4, 6}},
+		{name: "nil_a", a: nil, b: []int{2, 4, 6}, want: []int{2, 4, 6}},
+		{name: "nil_all", a: nil, b: nil, want: nil},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			merged := Interleave(tc.a, tc.b)
+			if !reflect.DeepEqual(merged, tc.want) {
+				t.Errorf("got %v; want %v", merged, tc.want)
+			}
+		})
+	}
+}
+
+func BenchmarkInterleave(b *testing.B) {
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Interleave(
+			[]int{1, 2, 3},
+			[]int{9, 8, 7},
+		)
+	}
+}
+func TestShuffle(t *testing.T) {
+	var sl []int
+	for i := 0; i < 100; i++ {
+		sl = append(sl, i)
+	}
+
+	var wasShuffled bool
+	for try := 0; try < 10; try++ {
+		shuffled := slices.Clone(sl)
+		Shuffle(shuffled)
+		if !reflect.DeepEqual(shuffled, sl) {
+			wasShuffled = true
+			break
+		}
+	}
+
+	if !wasShuffled {
+		t.Errorf("expected shuffle after 10 tries")
+	}
+}