|
|
@@ -0,0 +1,1139 @@
|
|
|
+// Copyright (c) Tailscale Inc & AUTHORS
|
|
|
+// SPDX-License-Identifier: BSD-3-Clause
|
|
|
+
|
|
|
+package ipnlocal
|
|
|
+
|
|
|
+import (
|
|
|
+ "cmp"
|
|
|
+ "context"
|
|
|
+ "errors"
|
|
|
+ "net/netip"
|
|
|
+ "reflect"
|
|
|
+ "slices"
|
|
|
+ "strconv"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ deepcmp "github.com/google/go-cmp/cmp"
|
|
|
+ "github.com/google/go-cmp/cmp/cmpopts"
|
|
|
+
|
|
|
+ "tailscale.com/health"
|
|
|
+ "tailscale.com/ipn"
|
|
|
+ "tailscale.com/ipn/ipnauth"
|
|
|
+ "tailscale.com/ipn/ipnext"
|
|
|
+ "tailscale.com/ipn/store/mem"
|
|
|
+ "tailscale.com/tailcfg"
|
|
|
+ "tailscale.com/tsd"
|
|
|
+ "tailscale.com/tstest"
|
|
|
+ "tailscale.com/types/key"
|
|
|
+ "tailscale.com/types/persist"
|
|
|
+ "tailscale.com/util/must"
|
|
|
+)
|
|
|
+
|
|
|
+// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes
|
|
|
+// and shuts down extensions.
|
|
|
+func TestExtensionInitShutdown(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ // As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors
|
|
|
+ // as extension initialization and shutdown errors are not fatal.
|
|
|
+ // If these methods are updated to return errors, this test should also be updated.
|
|
|
+ // The conversions below will fail to compile if their signatures change, reminding us to update the test.
|
|
|
+ _ = (func(*ExtensionHost))((*ExtensionHost).Init)
|
|
|
+ _ = (func(*ExtensionHost))((*ExtensionHost).Shutdown)
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ nilHost bool
|
|
|
+ exts []*testExtension
|
|
|
+ wantInit []string
|
|
|
+ wantShutdown []string
|
|
|
+ skipInit bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "nil-host",
|
|
|
+ nilHost: true,
|
|
|
+ exts: []*testExtension{},
|
|
|
+ wantInit: []string{},
|
|
|
+ wantShutdown: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "empty-extensions",
|
|
|
+ exts: []*testExtension{},
|
|
|
+ wantInit: []string{},
|
|
|
+ wantShutdown: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "single-extension",
|
|
|
+ exts: []*testExtension{{name: "A"}},
|
|
|
+ wantInit: []string{"A"},
|
|
|
+ wantShutdown: []string{"A"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/all-ok",
|
|
|
+ exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{"C", "B", "A"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/no-init-no-shutdown",
|
|
|
+ exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
|
|
+ wantInit: []string{},
|
|
|
+ wantShutdown: []string{},
|
|
|
+ skipInit: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/init-failed/first",
|
|
|
+ exts: []*testExtension{{
|
|
|
+ name: "A",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }, {
|
|
|
+ name: "B",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }, {
|
|
|
+ name: "C",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{"C", "B"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/init-failed/second",
|
|
|
+ exts: []*testExtension{{
|
|
|
+ name: "A",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }, {
|
|
|
+ name: "B",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }, {
|
|
|
+ name: "C",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{"C", "A"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/init-failed/third",
|
|
|
+ exts: []*testExtension{{
|
|
|
+ name: "A",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }, {
|
|
|
+ name: "B",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }, {
|
|
|
+ name: "C",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{"B", "A"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/init-failed/all",
|
|
|
+ exts: []*testExtension{{
|
|
|
+ name: "A",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }, {
|
|
|
+ name: "B",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }, {
|
|
|
+ name: "C",
|
|
|
+ InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|
|
+ }},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "multiple-extensions/init-skipped",
|
|
|
+ exts: []*testExtension{{
|
|
|
+ name: "A",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }, {
|
|
|
+ name: "B",
|
|
|
+ InitHook: func(*testExtension) error { return ipnext.SkipExtension },
|
|
|
+ }, {
|
|
|
+ name: "C",
|
|
|
+ InitHook: func(*testExtension) error { return nil },
|
|
|
+ }},
|
|
|
+ wantInit: []string{"A", "B", "C"},
|
|
|
+ wantShutdown: []string{"C", "A"},
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ // Configure all extensions to append their names
|
|
|
+ // to the gotInit and gotShutdown slices
|
|
|
+ // during initialization and shutdown,
|
|
|
+ // so we can check that they are called in the right order
|
|
|
+ // and that shutdown is not unless init succeeded.
|
|
|
+ var gotInit, gotShutdown []string
|
|
|
+ for _, ext := range tt.exts {
|
|
|
+ oldInitHook := ext.InitHook
|
|
|
+ ext.InitHook = func(e *testExtension) error {
|
|
|
+ gotInit = append(gotInit, e.name)
|
|
|
+ if oldInitHook == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return oldInitHook(e)
|
|
|
+ }
|
|
|
+ ext.ShutdownHook = func(e *testExtension) error {
|
|
|
+ gotShutdown = append(gotShutdown, e.name)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var h *ExtensionHost
|
|
|
+ if !tt.nilHost {
|
|
|
+ h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !tt.skipInit {
|
|
|
+ h.Init()
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check that the extensions were initialized in the right order.
|
|
|
+ if !slices.Equal(gotInit, tt.wantInit) {
|
|
|
+ t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Calling Init again on the host should be a no-op.
|
|
|
+ // The [testExtension.Init] method fails the test if called more than once,
|
|
|
+ // regardless of which test is running, so we don't need to check it here.
|
|
|
+ // Similarly, calling Shutdown again on the host should be a no-op as well.
|
|
|
+ // It is verified by the [testExtension.Shutdown] method itself.
|
|
|
+ if !tt.skipInit {
|
|
|
+ h.Init()
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extensions should not be shut down before the host is shut down,
|
|
|
+ // even if they are not initialized successfully.
|
|
|
+ for _, ext := range tt.exts {
|
|
|
+ if gotShutdown := ext.ShutdownCalled(); gotShutdown {
|
|
|
+ t.Errorf("%q: Extension shutdown called before host shutdown", ext.name)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ h.Shutdown()
|
|
|
+ // Check that the extensions were shut down in the right order,
|
|
|
+ // and that they were not shut down if they were not initialized successfully.
|
|
|
+ if !slices.Equal(gotShutdown, tt.wantShutdown) {
|
|
|
+ t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown)
|
|
|
+ }
|
|
|
+
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestNewExtensionHost tests that [NewExtensionHost] correctly creates
|
|
|
+// an [ExtensionHost], instantiates the extensions and handles errors
|
|
|
+// if an extension cannot be created.
|
|
|
+func TestNewExtensionHost(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ defs []*ipnext.Definition
|
|
|
+ wantErr bool
|
|
|
+ wantExts []string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "no-exts",
|
|
|
+ defs: []*ipnext.Definition{},
|
|
|
+ wantErr: false,
|
|
|
+ wantExts: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "exts-ok",
|
|
|
+ defs: []*ipnext.Definition{
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "B"}),
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|
|
+ },
|
|
|
+ wantErr: false,
|
|
|
+ wantExts: []string{"A", "B", "C"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "exts-skipped",
|
|
|
+ defs: []*ipnext.Definition{
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|
|
+ ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension),
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|
|
+ },
|
|
|
+ wantErr: false, // extension B is skipped, that's ok
|
|
|
+ wantExts: []string{"A", "C"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "exts-fail",
|
|
|
+ defs: []*ipnext.Definition{
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|
|
+ ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")),
|
|
|
+ ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|
|
+ },
|
|
|
+ wantErr: true, // extension B failed to create, that's not ok
|
|
|
+ wantExts: []string{},
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ logf := tstest.WhileTestRunningLogger(t)
|
|
|
+ h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...)
|
|
|
+ if gotErr := err != nil; gotErr != tt.wantErr {
|
|
|
+ t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var gotExts []string
|
|
|
+ for _, ext := range h.allExtensions {
|
|
|
+ gotExts = append(gotExts, ext.Name())
|
|
|
+ }
|
|
|
+
|
|
|
+ if !slices.Equal(gotExts, tt.wantExts) {
|
|
|
+ t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
|
|
|
+// backend operations and executes them asynchronously in the order they were received.
|
|
|
+// It also checks that operations requested before the host and all extensions are initialized
|
|
|
+// are not executed immediately but rather after the host and extensions are initialized.
|
|
|
+func TestExtensionHostEnqueueBackendOperation(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ preInitCalls []string // before host init
|
|
|
+ extInitCalls []string // from [Extension.Init]; "" means no call
|
|
|
+ wantInitCalls []string // what we expect to be called after host init
|
|
|
+ postInitCalls []string // after host init
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "no-calls",
|
|
|
+ preInitCalls: []string{},
|
|
|
+ extInitCalls: []string{},
|
|
|
+ wantInitCalls: []string{},
|
|
|
+ postInitCalls: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "pre-init-calls",
|
|
|
+ preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|
|
+ extInitCalls: []string{},
|
|
|
+ wantInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|
|
+ postInitCalls: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "init-calls",
|
|
|
+ preInitCalls: []string{},
|
|
|
+ extInitCalls: []string{"init-1", "init-2"},
|
|
|
+ wantInitCalls: []string{"init-1", "init-2"},
|
|
|
+ postInitCalls: []string{},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "post-init-calls",
|
|
|
+ preInitCalls: []string{},
|
|
|
+ extInitCalls: []string{},
|
|
|
+ wantInitCalls: []string{},
|
|
|
+ postInitCalls: []string{"post-init-1", "post-init-2"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "mixed-calls",
|
|
|
+ preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|
|
+ extInitCalls: []string{"init-1", "", "init-2"},
|
|
|
+ wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"},
|
|
|
+ postInitCalls: []string{"post-init-1", "post-init-2"},
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ var gotCalls []string
|
|
|
+ var h *ExtensionHost
|
|
|
+ b := &testBackend{
|
|
|
+ switchToBestProfileHook: func(reason string) {
|
|
|
+ gotCalls = append(gotCalls, reason)
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ exts := make([]*testExtension, len(tt.extInitCalls))
|
|
|
+ for i, reason := range tt.extInitCalls {
|
|
|
+ exts[i] = &testExtension{}
|
|
|
+ if reason != "" {
|
|
|
+ exts[i].InitHook = func(e *testExtension) error {
|
|
|
+ e.host.Profiles().SwitchToBestProfileAsync(reason)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ h = newExtensionHostForTest(t, b, false, exts...)
|
|
|
+ wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue].
|
|
|
+
|
|
|
+ // Issue some pre-init calls. They should be deferred and not
|
|
|
+ // added to the queue until the host is initialized.
|
|
|
+ for _, call := range tt.preInitCalls {
|
|
|
+ h.Profiles().SwitchToBestProfileAsync(call)
|
|
|
+ }
|
|
|
+
|
|
|
+ // The queue should be empty before the host is initialized.
|
|
|
+ wq.Drain()
|
|
|
+ if len(gotCalls) != 0 {
|
|
|
+ t.Errorf("Pre-init calls: got %v; want (none)", gotCalls)
|
|
|
+ }
|
|
|
+ gotCalls = nil
|
|
|
+
|
|
|
+ // Initialize the host and all extensions.
|
|
|
+ // The extensions will make their calls during initialization.
|
|
|
+ h.Init()
|
|
|
+
|
|
|
+ // Calls made before or during initialization should now be enqueued and running.
|
|
|
+ wq.Drain()
|
|
|
+ if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
|
|
+ t.Errorf("Init calls: (+got -want): %v", diff)
|
|
|
+ }
|
|
|
+ gotCalls = nil
|
|
|
+
|
|
|
+ // Let's make some more calls, as if extensions were making them in a response
|
|
|
+ // to external events.
|
|
|
+ for _, call := range tt.postInitCalls {
|
|
|
+ h.Profiles().SwitchToBestProfileAsync(call)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Any calls made after initialization should be enqueued and running.
|
|
|
+ wq.Drain()
|
|
|
+ if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
|
|
+ t.Errorf("Init calls: (+got -want): %v", diff)
|
|
|
+ }
|
|
|
+ gotCalls = nil
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration,
|
|
|
+// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called
|
|
|
+// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback.
|
|
|
+func TestExtensionHostProfileChangeCallback(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ type profileChange struct {
|
|
|
+ Profile *ipn.LoginProfile
|
|
|
+ Prefs *ipn.Prefs
|
|
|
+ SameNode bool
|
|
|
+ }
|
|
|
+ // newProfileChange creates a new profile change with deep copies of the profile and prefs.
|
|
|
+ newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange {
|
|
|
+ return profileChange{
|
|
|
+ Profile: profile.AsStruct(),
|
|
|
+ Prefs: prefs.AsStruct(),
|
|
|
+ SameNode: sameNode,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // makeProfileChangeAppender returns a callback that appends profile changes to the extension's state.
|
|
|
+ makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback {
|
|
|
+ return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
|
|
+ UpdateExtState(e, "changes", func(changes []profileChange) []profileChange {
|
|
|
+ return append(changes, newProfileChange(profile, prefs, sameNode))
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // getProfileChanges returns the profile changes stored in the extension's state.
|
|
|
+ getProfileChanges := func(e *testExtension) []profileChange {
|
|
|
+ changes, _ := GetExtStateOk[[]profileChange](e, "changes")
|
|
|
+ return changes
|
|
|
+ }
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ ext *testExtension
|
|
|
+ calls []profileChange
|
|
|
+ wantCalls []profileChange
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ // Register the callback for the lifetime of the extension.
|
|
|
+ name: "Register/Lifetime",
|
|
|
+ ext: &testExtension{},
|
|
|
+ calls: []profileChange{
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
|
|
+ },
|
|
|
+ wantCalls: []profileChange{ // all calls are received by the callback
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ // Override the default InitHook used in the test to unregister the callback
|
|
|
+ // after the first call.
|
|
|
+ name: "Register/Once",
|
|
|
+ ext: &testExtension{
|
|
|
+ InitHook: func(e *testExtension) error {
|
|
|
+ var unregister func()
|
|
|
+ handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
|
|
+ makeProfileChangeAppender(e)(profile, prefs, sameNode)
|
|
|
+ unregister()
|
|
|
+ }
|
|
|
+ unregister = e.host.Profiles().RegisterProfileChangeCallback(handler)
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ },
|
|
|
+ calls: []profileChange{
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|
|
+ },
|
|
|
+ wantCalls: []profileChange{ // only the first call is received by the callback
|
|
|
+ {Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ // Ensure that ipn.Prefs are passed to the callback.
|
|
|
+ name: "CheckPrefs",
|
|
|
+ ext: &testExtension{},
|
|
|
+ calls: []profileChange{{
|
|
|
+ Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|
|
+ Prefs: &ipn.Prefs{
|
|
|
+ WantRunning: true,
|
|
|
+ LoggedOut: false,
|
|
|
+ AdvertiseRoutes: []netip.Prefix{
|
|
|
+ netip.MustParsePrefix("192.168.1.0/24"),
|
|
|
+ netip.MustParsePrefix("192.168.2.0/24"),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }},
|
|
|
+ wantCalls: []profileChange{{
|
|
|
+ Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|
|
+ Prefs: &ipn.Prefs{
|
|
|
+ WantRunning: true,
|
|
|
+ LoggedOut: false,
|
|
|
+ AdvertiseRoutes: []netip.Prefix{
|
|
|
+ netip.MustParsePrefix("192.168.1.0/24"),
|
|
|
+ netip.MustParsePrefix("192.168.2.0/24"),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ // Ensure that private keys are stripped from persist.Persist shared with extensions.
|
|
|
+ name: "StripPrivateKeys",
|
|
|
+ ext: &testExtension{},
|
|
|
+ calls: []profileChange{{
|
|
|
+ Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|
|
+ Prefs: &ipn.Prefs{
|
|
|
+ Persist: &persist.Persist{
|
|
|
+ NodeID: "12345",
|
|
|
+ PrivateNodeKey: key.NewNode(),
|
|
|
+ OldPrivateNodeKey: key.NewNode(),
|
|
|
+ NetworkLockKey: key.NewNLPrivate(),
|
|
|
+ UserProfile: tailcfg.UserProfile{
|
|
|
+ ID: 12345,
|
|
|
+ LoginName: "[email protected]",
|
|
|
+ DisplayName: "Test User",
|
|
|
+ ProfilePicURL: "https://example.com/profile.png",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }},
|
|
|
+ wantCalls: []profileChange{{
|
|
|
+ Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|
|
+ Prefs: &ipn.Prefs{
|
|
|
+ Persist: &persist.Persist{
|
|
|
+ NodeID: "12345",
|
|
|
+ PrivateNodeKey: key.NodePrivate{}, // stripped
|
|
|
+ OldPrivateNodeKey: key.NodePrivate{}, // stripped
|
|
|
+ NetworkLockKey: key.NLPrivate{}, // stripped
|
|
|
+ UserProfile: tailcfg.UserProfile{
|
|
|
+ ID: 12345,
|
|
|
+ LoginName: "[email protected]",
|
|
|
+ DisplayName: "Test User",
|
|
|
+ ProfilePicURL: "https://example.com/profile.png",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }},
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ // Use the default InitHook if not provided by the test.
|
|
|
+ if tt.ext.InitHook == nil {
|
|
|
+ tt.ext.InitHook = func(e *testExtension) error {
|
|
|
+ // Create and register the callback on init.
|
|
|
+ handler := makeProfileChangeAppender(e)
|
|
|
+ e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler))
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext)
|
|
|
+ for _, call := range tt.calls {
|
|
|
+ h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode)
|
|
|
+ }
|
|
|
+ opts := []deepcmp.Option{
|
|
|
+ cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}),
|
|
|
+ }
|
|
|
+ if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" {
|
|
|
+ t.Errorf("ProfileChange callbacks: (-want +got): %v", diff)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestBackgroundProfileResolver tests that the background profile resolvers
|
|
|
+// are correctly registered, unregistered and invoked by the [ExtensionHost].
|
|
|
+func TestBackgroundProfileResolver(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ profiles []ipn.LoginProfile // the first one is the current profile
|
|
|
+ resolvers []ipnext.ProfileResolver
|
|
|
+ wantProfile *ipn.LoginProfile
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "No-Profiles/No-Resolvers",
|
|
|
+ profiles: nil,
|
|
|
+ resolvers: nil,
|
|
|
+ wantProfile: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ // TODO(nickkhyl): update this test as we change "background profile resolvers"
|
|
|
+ // to just "profile resolvers". The wantProfile should be the current profile by default.
|
|
|
+ name: "Has-Profiles/No-Resolvers",
|
|
|
+ profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
|
|
+ resolvers: nil,
|
|
|
+ wantProfile: nil,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Has-Profiles/Single-Resolver",
|
|
|
+ profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
|
|
+ resolvers: []ipnext.ProfileResolver{
|
|
|
+ func(ps ipnext.ProfileStore) ipn.LoginProfileView {
|
|
|
+ return ps.CurrentProfile()
|
|
|
+ },
|
|
|
+ },
|
|
|
+ wantProfile: &ipn.LoginProfile{ID: "profile-1"},
|
|
|
+ },
|
|
|
+ // TODO(nickkhyl): add more tests for multiple resolvers and different profiles
|
|
|
+ // once we change "background profile resolvers" to just "profile resolvers"
|
|
|
+ // and add proper conflict resolution logic.
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ // Create a new profile manager and add the profiles to it.
|
|
|
+ // We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface.
|
|
|
+ pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
|
|
|
+ for i, p := range tt.profiles {
|
|
|
+ // Generate a unique ID and key for each profile,
|
|
|
+ // unless the profile already has them set
|
|
|
+ // or is an empty, unnamed profile.
|
|
|
+ if p.Name != "" {
|
|
|
+ if p.ID == "" {
|
|
|
+ p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i))
|
|
|
+ }
|
|
|
+ if p.Key == "" {
|
|
|
+ p.Key = "key-" + ipn.StateKey(p.ID)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ pv := p.View()
|
|
|
+ pm.knownProfiles[p.ID] = pv
|
|
|
+ if i == 0 {
|
|
|
+ // Set the first profile as the current one.
|
|
|
+ // A profileManager starts with an empty profile,
|
|
|
+ // so it's okay if the list of profiles is empty.
|
|
|
+ pm.SwitchToProfile(pv)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true)
|
|
|
+
|
|
|
+ // Register the resolvers with the host.
|
|
|
+ // This is typically done by the extensions themselves,
|
|
|
+ // but we do it here for testing purposes.
|
|
|
+ for _, r := range tt.resolvers {
|
|
|
+ t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r))
|
|
|
+ }
|
|
|
+
|
|
|
+ // Call the resolver to get the profile.
|
|
|
+ gotProfile := h.DetermineBackgroundProfile(pm)
|
|
|
+ if !gotProfile.Equals(tt.wantProfile.View()) {
|
|
|
+ t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestAuditLogProviders tests that the [ExtensionHost] correctly handles
|
|
|
+// the registration and invocation of audit log providers. It verifies that
|
|
|
+// the audit loggers are called with the correct actions and details,
|
|
|
+// and that any errors returned by the providers are properly propagated.
|
|
|
+func TestAuditLogProviders(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ auditLoggers []ipnauth.AuditLogFunc // each represents an extension
|
|
|
+ actions []tailcfg.ClientAuditAction
|
|
|
+ wantErr bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "No-Providers",
|
|
|
+ auditLoggers: nil,
|
|
|
+ actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Single-Provider/Ok",
|
|
|
+ auditLoggers: []ipnauth.AuditLogFunc{
|
|
|
+ func(tailcfg.ClientAuditAction, string) error { return nil },
|
|
|
+ },
|
|
|
+ actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Single-Provider/Err",
|
|
|
+ auditLoggers: []ipnauth.AuditLogFunc{
|
|
|
+ func(tailcfg.ClientAuditAction, string) error {
|
|
|
+ return errors.New("failed to log")
|
|
|
+ },
|
|
|
+ },
|
|
|
+ actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|
|
+ wantErr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Many-Providers/Ok",
|
|
|
+ auditLoggers: []ipnauth.AuditLogFunc{
|
|
|
+ func(tailcfg.ClientAuditAction, string) error { return nil },
|
|
|
+ func(tailcfg.ClientAuditAction, string) error { return nil },
|
|
|
+ },
|
|
|
+ actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Many-Providers/Err",
|
|
|
+ auditLoggers: []ipnauth.AuditLogFunc{
|
|
|
+ func(tailcfg.ClientAuditAction, string) error {
|
|
|
+ return errors.New("failed to log")
|
|
|
+ },
|
|
|
+ func(tailcfg.ClientAuditAction, string) error {
|
|
|
+ return nil // all good
|
|
|
+ },
|
|
|
+ func(tailcfg.ClientAuditAction, string) error {
|
|
|
+ return errors.New("also failed to log")
|
|
|
+ },
|
|
|
+ },
|
|
|
+ actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|
|
+ wantErr: true, // some providers failed to log, so that's an error
|
|
|
+ },
|
|
|
+ }
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ // Create extensions that register the audit log providers.
|
|
|
+ // Each extension/provider will append auditable actions to its state,
|
|
|
+ // then call the test's auditLogger function.
|
|
|
+ var exts []*testExtension
|
|
|
+ for _, auditLogger := range tt.auditLoggers {
|
|
|
+ ext := &testExtension{}
|
|
|
+ provider := func() ipnauth.AuditLogFunc {
|
|
|
+ return func(action tailcfg.ClientAuditAction, details string) error {
|
|
|
+ UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction {
|
|
|
+ return append(actions, action)
|
|
|
+ })
|
|
|
+ return auditLogger(action, details)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ext.InitHook = func(e *testExtension) error {
|
|
|
+ e.Cleanup(e.host.RegisterAuditLogProvider(provider))
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ exts = append(exts, ext)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Initialize the host and the extensions.
|
|
|
+ h := newExtensionHostForTest(t, &testBackend{}, true, exts...)
|
|
|
+
|
|
|
+ // Use [ExtensionHost.AuditLogger] to log actions.
|
|
|
+ for _, action := range tt.actions {
|
|
|
+ err := h.AuditLogger()(action, "Test details")
|
|
|
+ if gotErr := err != nil; gotErr != tt.wantErr {
|
|
|
+ t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check that the actions were logged correctly by each provider.
|
|
|
+ for _, ext := range exts {
|
|
|
+ gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions")
|
|
|
+ if !slices.Equal(gotActions, tt.actions) {
|
|
|
+ t.Errorf("Actions: got %v; want %v", gotActions, tt.actions)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// TestNilExtensionHostMethodCall tests that calling exported methods
|
|
|
+// on a nil [ExtensionHost] does not panic. We should treat it as a valid
|
|
|
+// value since it's used in various tests that instantiate [LocalBackend]
|
|
|
+// manually without calling [NewLocalBackend]. It also verifies that if
|
|
|
+// a method returns a single func value (e.g., a cleanup function),
|
|
|
+// it should not be nil. This is a basic sanity check to ensure that
|
|
|
+// typical method calls on a nil receiver work as expected.
|
|
|
+// It does not replace the need for more thorough testing of specific methods.
|
|
|
+func TestNilExtensionHostMethodCall(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ var h *ExtensionHost
|
|
|
+ typ := reflect.TypeOf(h)
|
|
|
+ for i := range typ.NumMethod() {
|
|
|
+ m := typ.Method(i)
|
|
|
+ if strings.HasSuffix(m.Name, "ForTest") {
|
|
|
+ // Skip methods that are only for testing.
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ t.Run(m.Name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+ // Calling the method on the nil receiver should not panic.
|
|
|
+ ret := checkMethodCallWithZeroArgs(t, m, h)
|
|
|
+ if len(ret) == 1 && ret[0].Kind() == reflect.Func {
|
|
|
+ // If the method returns a single func, such as a cleanup function,
|
|
|
+ // it should not be nil.
|
|
|
+ fn := ret[0]
|
|
|
+ if fn.IsNil() {
|
|
|
+ t.Fatalf("(%T).%s returned a nil func", h, m.Name)
|
|
|
+ }
|
|
|
+ // We expect it to be a no-op and calling it should not panic.
|
|
|
+ args := makeZeroArgsFor(fn)
|
|
|
+ func() {
|
|
|
+ defer func() {
|
|
|
+ if e := recover(); e != nil {
|
|
|
+ t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ fn.Call(args)
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// checkMethodCallWithZeroArgs calls the method m on the receiver r
|
|
|
+// with zero values for all its arguments, except the receiver itself.
|
|
|
+// It returns the result of the method call, or fails the test if the call panics.
|
|
|
+func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value {
|
|
|
+ t.Helper()
|
|
|
+ args := makeZeroArgsFor(m.Func)
|
|
|
+ // The first arg is the receiver.
|
|
|
+ args[0] = reflect.ValueOf(r)
|
|
|
+ // Calling the method should not panic.
|
|
|
+ defer func() {
|
|
|
+ if e := recover(); e != nil {
|
|
|
+ t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ return m.Func.Call(args)
|
|
|
+}
|
|
|
+
|
|
|
+func makeZeroArgsFor(fn reflect.Value) []reflect.Value {
|
|
|
+ args := make([]reflect.Value, fn.Type().NumIn())
|
|
|
+ for i := range args {
|
|
|
+ args[i] = reflect.Zero(fn.Type().In(i))
|
|
|
+ }
|
|
|
+ return args
|
|
|
+}
|
|
|
+
|
|
|
+// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions.
|
|
|
+// It associates each extension that either is or embeds a [testExtension] with the test
|
|
|
+// and assigns a name if one isn’t already set.
|
|
|
+//
|
|
|
+// If the host cannot be created, it fails the test.
|
|
|
+//
|
|
|
+// The host is initialized if the initialize parameter is true.
|
|
|
+// It is shut down automatically when the test ends.
|
|
|
+func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
+ // testExtensionIface is a subset of the methods implemented by [testExtension] that are used here.
|
|
|
+ // We use testExtensionIface in type assertions instead of using the [testExtension] type directly,
|
|
|
+ // which supports scenarios where an extension type embeds a [testExtension].
|
|
|
+ type testExtensionIface interface {
|
|
|
+ Name() string
|
|
|
+ setName(string)
|
|
|
+ setT(*testing.T)
|
|
|
+ checkShutdown()
|
|
|
+ }
|
|
|
+
|
|
|
+ logf := tstest.WhileTestRunningLogger(t)
|
|
|
+ defs := make([]*ipnext.Definition, len(exts))
|
|
|
+ for i, ext := range exts {
|
|
|
+ if ext, ok := any(ext).(testExtensionIface); ok {
|
|
|
+ ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i)))
|
|
|
+ ext.setT(t)
|
|
|
+ }
|
|
|
+ defs[i] = ipnext.DefinitionForTest(ext)
|
|
|
+ }
|
|
|
+ h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("NewExtensionHost: %v", err)
|
|
|
+ }
|
|
|
+ // Replace doEnqueueBackendOperation with the one that's marked as a helper,
|
|
|
+ // so that we'll have better output if [testExecQueue.Add] fails a test.
|
|
|
+ h.doEnqueueBackendOperation = func(f func(Backend)) {
|
|
|
+ t.Helper()
|
|
|
+ h.workQueue.Add(func() { f(b) })
|
|
|
+ }
|
|
|
+ for _, ext := range exts {
|
|
|
+ if ext, ok := any(ext).(testExtensionIface); ok {
|
|
|
+ t.Cleanup(ext.checkShutdown)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ t.Cleanup(h.Shutdown)
|
|
|
+ if initialize {
|
|
|
+ h.Init()
|
|
|
+ }
|
|
|
+ return h
|
|
|
+}
|
|
|
+
|
|
|
+// testExtension is an [ipnext.Extension] that:
|
|
|
+// - Calls the provided init and shutdown callbacks
|
|
|
+// when [Init] and [Shutdown] are called.
|
|
|
+// - Ensures that [Init] and [Shutdown] are called at most once,
|
|
|
+// that [Shutdown] is called after [Init], but is not called if [Init] fails
|
|
|
+// and is called before the test ends if [Init] succeeds.
|
|
|
+//
|
|
|
+// Typically, [testExtension]s are created and passed to [newExtensionHostForTest]
|
|
|
+// when creating an [ExtensionHost] for testing.
|
|
|
+type testExtension struct {
|
|
|
+ t *testing.T // test that created the extension
|
|
|
+ name string // name of the extension, used for logging
|
|
|
+
|
|
|
+ host ipnext.Host // or nil if not initialized
|
|
|
+
|
|
|
+ // InitHook and ShutdownHook are optional hooks that can be set by tests.
|
|
|
+ InitHook, ShutdownHook func(*testExtension) error
|
|
|
+
|
|
|
+ // initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown
|
|
|
+ // are called at most once and in the correct order.
|
|
|
+ initCnt, initOkCnt, shutdownCnt atomic.Int32
|
|
|
+
|
|
|
+ // mu protects the following fields.
|
|
|
+ mu sync.Mutex
|
|
|
+ // state is the optional state used by tests.
|
|
|
+ // It can be accessed by tests using [setTestExtensionState],
|
|
|
+ // [getTestExtensionStateOk] and [getTestExtensionState].
|
|
|
+ state map[string]any
|
|
|
+ // cleanup are functions to be called on shutdown.
|
|
|
+ cleanup []func()
|
|
|
+}
|
|
|
+
|
|
|
+var _ ipnext.Extension = (*testExtension)(nil)
|
|
|
+
|
|
|
+func (e *testExtension) setT(t *testing.T) {
|
|
|
+ e.t = t
|
|
|
+}
|
|
|
+
|
|
|
+func (e *testExtension) setName(name string) {
|
|
|
+ e.name = name
|
|
|
+}
|
|
|
+
|
|
|
+// Name implements [ipnext.Extension].
|
|
|
+func (e *testExtension) Name() string {
|
|
|
+ return e.name
|
|
|
+}
|
|
|
+
|
|
|
+// Init implements [ipnext.Extension].
|
|
|
+func (e *testExtension) Init(host ipnext.Host) (err error) {
|
|
|
+ e.t.Helper()
|
|
|
+ e.host = host
|
|
|
+ if e.initCnt.Add(1) == 1 {
|
|
|
+ e.mu.Lock()
|
|
|
+ e.state = make(map[string]any)
|
|
|
+ e.mu.Unlock()
|
|
|
+ } else {
|
|
|
+ e.t.Errorf("%q: Init called more than once", e.name)
|
|
|
+ }
|
|
|
+ if e.InitHook != nil {
|
|
|
+ err = e.InitHook(e)
|
|
|
+ }
|
|
|
+ if err == nil {
|
|
|
+ e.initOkCnt.Add(1)
|
|
|
+ }
|
|
|
+ return err // may be nil or non-nil
|
|
|
+}
|
|
|
+
|
|
|
+// InitCalled reports whether the Init method was called on the receiver.
|
|
|
+func (e *testExtension) InitCalled() bool {
|
|
|
+ return e.initCnt.Load() != 0
|
|
|
+}
|
|
|
+
|
|
|
+func (e *testExtension) Cleanup(f func()) {
|
|
|
+ e.mu.Lock()
|
|
|
+ e.cleanup = append(e.cleanup, f)
|
|
|
+ e.mu.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
+// Shutdown implements [ipnext.Extension].
|
|
|
+func (e *testExtension) Shutdown() (err error) {
|
|
|
+ e.t.Helper()
|
|
|
+ e.mu.Lock()
|
|
|
+ cleanup := e.cleanup
|
|
|
+ e.cleanup = nil
|
|
|
+ e.mu.Unlock()
|
|
|
+ for _, f := range cleanup {
|
|
|
+ f()
|
|
|
+ }
|
|
|
+ if e.ShutdownHook != nil {
|
|
|
+ err = e.ShutdownHook(e)
|
|
|
+ }
|
|
|
+ if e.shutdownCnt.Add(1) != 1 {
|
|
|
+ e.t.Errorf("%q: Shutdown called more than once", e.name)
|
|
|
+ }
|
|
|
+ if e.initCnt.Load() == 0 {
|
|
|
+ e.t.Errorf("%q: Shutdown called without Init", e.name)
|
|
|
+ } else if e.initOkCnt.Load() == 0 {
|
|
|
+ e.t.Errorf("%q: Shutdown called despite failed Init", e.name)
|
|
|
+ }
|
|
|
+ e.host = nil
|
|
|
+ return err // may be nil or non-nil
|
|
|
+}
|
|
|
+
|
|
|
+func (e *testExtension) checkShutdown() {
|
|
|
+ e.t.Helper()
|
|
|
+ if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 {
|
|
|
+ e.t.Errorf("%q: Shutdown has not been called before test end", e.name)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ShutdownCalled reports whether the Shutdown method was called on the receiver.
|
|
|
+func (e *testExtension) ShutdownCalled() bool {
|
|
|
+ return e.shutdownCnt.Load() != 0
|
|
|
+}
|
|
|
+
|
|
|
+// SetExtState sets a keyed state on [testExtension] to the given value.
|
|
|
+// Tests use it to propagate test-specific state throughout the extension lifecycle
|
|
|
+// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks)
|
|
|
+func SetExtState[T any](e *testExtension, key string, value T) {
|
|
|
+ e.mu.Lock()
|
|
|
+ defer e.mu.Unlock()
|
|
|
+ e.state[key] = value
|
|
|
+}
|
|
|
+
|
|
|
+// UpdateExtState updates a keyed state of the extension using the provided update function.
|
|
|
+func UpdateExtState[T any](e *testExtension, key string, update func(T) T) {
|
|
|
+ e.mu.Lock()
|
|
|
+ defer e.mu.Unlock()
|
|
|
+ old, _ := e.state[key].(T)
|
|
|
+ new := update(old)
|
|
|
+ e.state[key] = new
|
|
|
+}
|
|
|
+
|
|
|
+// GetExtState returns the value of the keyed state of the extension.
|
|
|
+// It returns a zero value of T if the state is not set or is of a different type.
|
|
|
+func GetExtState[T any](e *testExtension, key string) T {
|
|
|
+ v, _ := GetExtStateOk[T](e, key)
|
|
|
+ return v
|
|
|
+}
|
|
|
+
|
|
|
+// GetExtStateOk is like [getExtState], but also reports whether the state
|
|
|
+// with the given key exists and is of the expected type.
|
|
|
+func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) {
|
|
|
+ e.mu.Lock()
|
|
|
+ defer e.mu.Unlock()
|
|
|
+ v, ok := e.state[key].(T)
|
|
|
+ return v, ok
|
|
|
+}
|
|
|
+
|
|
|
+// testExecQueue is a test implementation of [execQueue]
|
|
|
+// that defers execution of the enqueued funcs until
|
|
|
+// [testExecQueue.Drain] is called, and fails the test if
|
|
|
+// if [execQueue.Add] is called before the host is initialized.
|
|
|
+//
|
|
|
+// It is typically used by calling [ExtensionHost.SetWorkQueueForTest].
|
|
|
+type testExecQueue struct {
|
|
|
+ t *testing.T // test that created the queue
|
|
|
+ h *ExtensionHost // host to own the queue
|
|
|
+
|
|
|
+ mu sync.Mutex
|
|
|
+ queue []func()
|
|
|
+}
|
|
|
+
|
|
|
+var _ execQueue = (*testExecQueue)(nil)
|
|
|
+
|
|
|
+// SetWorkQueueForTest is a helper function that creates a new [testExecQueue]
|
|
|
+// and sets it as the work queue for the specified [ExtensionHost],
|
|
|
+// returning the new queue.
|
|
|
+//
|
|
|
+// It fails the test if the host is already initialized.
|
|
|
+func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue {
|
|
|
+ t.Helper()
|
|
|
+ if h.initialized.Load() {
|
|
|
+ t.Fatalf("UseTestWorkQueue: host is already initialized")
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ q := &testExecQueue{t: t, h: h}
|
|
|
+ h.workQueue = q
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+// Add implements [execQueue].
|
|
|
+func (q *testExecQueue) Add(f func()) {
|
|
|
+ q.t.Helper()
|
|
|
+
|
|
|
+ if !q.h.initialized.Load() {
|
|
|
+ q.t.Fatal("ExecQueue.Add must not be called until the host is initialized")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ q.mu.Lock()
|
|
|
+ q.queue = append(q.queue, f)
|
|
|
+ q.mu.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
+// Drain executes all queued functions in the order they were added.
|
|
|
+func (q *testExecQueue) Drain() {
|
|
|
+ q.mu.Lock()
|
|
|
+ queue := q.queue
|
|
|
+ q.queue = nil
|
|
|
+ q.mu.Unlock()
|
|
|
+
|
|
|
+ for _, f := range queue {
|
|
|
+ f()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Shutdown implements [execQueue].
|
|
|
+func (q *testExecQueue) Shutdown() {}
|
|
|
+
|
|
|
+// Wait implements [execQueue].
|
|
|
+func (q *testExecQueue) Wait(context.Context) error { return nil }
|
|
|
+
|
|
|
+// testBackend implements [ipnext.Backend] for testing purposes
|
|
|
+// by calling the provided hooks when its methods are called.
|
|
|
+type testBackend struct {
|
|
|
+ switchToBestProfileHook func(reason string)
|
|
|
+
|
|
|
+ // mu protects the backend state.
|
|
|
+ // It is acquired on entry to the exported methods of the backend
|
|
|
+ // and released on exit, mimicking the behavior of the [LocalBackend].
|
|
|
+ mu sync.Mutex
|
|
|
+}
|
|
|
+
|
|
|
+func (b *testBackend) SwitchToBestProfile(reason string) {
|
|
|
+ b.mu.Lock()
|
|
|
+ defer b.mu.Unlock()
|
|
|
+ if b.switchToBestProfileHook != nil {
|
|
|
+ b.switchToBestProfileHook(reason)
|
|
|
+ }
|
|
|
+}
|