Просмотр исходного кода

tstest: add method to Replace values for tests

We have many function pointers that we replace for the duration of test and
restore it on test completion, add method to do that.

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 лет назад
Родитель
Сommit
b9ebf7cf14

+ 2 - 6
cmd/tailscale/cli/cli_test.go

@@ -1075,16 +1075,12 @@ func TestUpdatePrefs(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			if tt.sshOverTailscale {
-				old := getSSHClientEnvVar
-				getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" }
-				t.Cleanup(func() { getSSHClientEnvVar = old })
+				tstest.Replace(t, &getSSHClientEnvVar, func() string { return "100.100.100.100 1 1" })
 			} else if isSSHOverTailscale() {
 				// The test is being executed over a "real" tailscale SSH
 				// session, but sshOverTailscale is unset. Make the test appear
 				// as if it's not over tailscale SSH.
-				old := getSSHClientEnvVar
-				getSSHClientEnvVar = func() string { return "" }
-				t.Cleanup(func() { getSSHClientEnvVar = old })
+				tstest.Replace(t, &getSSHClientEnvVar, func() string { return "" })
 			}
 			if tt.env.goos == "" {
 				tt.env.goos = "linux"

+ 3 - 3
net/dnscache/dnscache_test.go

@@ -13,6 +13,8 @@ import (
 	"reflect"
 	"testing"
 	"time"
+
+	"tailscale.com/tstest"
 )
 
 var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
@@ -142,9 +144,7 @@ func TestResolverAllHostStaticResult(t *testing.T) {
 }
 
 func TestShouldTryBootstrap(t *testing.T) {
-	oldDebug := debug
-	t.Cleanup(func() { debug = oldDebug })
-	debug = func() bool { return true }
+	tstest.Replace(t, &debug, func() bool { return true })
 
 	type step struct {
 		ip  netip.Addr // IP we pretended to dial

+ 2 - 3
net/netcheck/netcheck_test.go

@@ -22,6 +22,7 @@ import (
 	"tailscale.com/net/stun"
 	"tailscale.com/net/stun/stuntest"
 	"tailscale.com/tailcfg"
+	"tailscale.com/tstest"
 )
 
 func TestHairpinSTUN(t *testing.T) {
@@ -679,9 +680,7 @@ func TestNoCaptivePortalWhenUDP(t *testing.T) {
 		}
 	})
 
-	oldTransport := noRedirectClient.Transport
-	t.Cleanup(func() { noRedirectClient.Transport = oldTransport })
-	noRedirectClient.Transport = tr
+	tstest.Replace(t, &noRedirectClient.Transport, http.RoundTripper(tr))
 
 	stunAddr, cleanup := stuntest.Serve(t)
 	defer cleanup()

+ 17 - 0
tstest/tstest.go

@@ -6,12 +6,29 @@ package tstest
 
 import (
 	"context"
+	"testing"
 	"time"
 
 	"tailscale.com/logtail/backoff"
 	"tailscale.com/types/logger"
 )
 
+// Replace replaces the value of target with val.
+// The old value is restored when the test ends.
+func Replace[T any](t *testing.T, target *T, val T) {
+	t.Helper()
+	if target == nil {
+		t.Fatalf("Replace: nil pointer")
+	}
+	old := *target
+	t.Cleanup(func() {
+		*target = old
+	})
+
+	*target = val
+	return
+}
+
 // WaitFor retries try for up to maxWait.
 // It returns nil once try returns nil the first time.
 // If maxWait passes without success, it returns try's last error.

+ 24 - 0
tstest/tstest_test.go

@@ -0,0 +1,24 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package tstest
+
+import "testing"
+
+func TestReplace(t *testing.T) {
+	before := "before"
+	done := false
+	t.Run("replace", func(t *testing.T) {
+		Replace(t, &before, "after")
+		if before != "after" {
+			t.Errorf("before = %q; want %q", before, "after")
+		}
+		done = true
+	})
+	if !done {
+		t.Fatal("subtest didn't run")
+	}
+	if before != "before" {
+		t.Errorf("before = %q; want %q", before, "before")
+	}
+}