2
0
Эх сурвалжийг харах

ipn/ipnext: remove support for unregistering extension

Updates #12614

Change-Id: I893e3ea74831deaa6f88e31bba2d95dc017e0470
Co-authored-by: Nick Khyl <[email protected]>
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 10 сар өмнө
parent
commit
25c4dc5fd7

+ 3 - 11
ipn/auditlog/extension.go

@@ -36,8 +36,6 @@ func init() {
 type extension struct {
 	logf logger.Logf
 
-	// cleanup are functions to call on shutdown.
-	cleanup []func()
 	// store is the log store shared by all loggers.
 	// It is created when the first logger is started.
 	store lazy.SyncValue[LogStore]
@@ -66,11 +64,9 @@ func (e *extension) Name() string {
 // Init implements [ipnext.Extension] by registering callbacks and providers
 // for the duration of the extension's lifetime.
 func (e *extension) Init(h ipnext.Host) error {
-	e.cleanup = []func(){
-		h.RegisterControlClientCallback(e.controlClientChanged),
-		h.Profiles().RegisterProfileStateChangeCallback(e.profileChanged),
-		h.RegisterAuditLogProvider(e.getCurrentLogger),
-	}
+	h.RegisterControlClientCallback(e.controlClientChanged)
+	h.Profiles().RegisterProfileStateChangeCallback(e.profileChanged)
+	h.RegisterAuditLogProvider(e.getCurrentLogger)
 	return nil
 }
 
@@ -190,9 +186,5 @@ func (e *extension) getCurrentLogger() ipnauth.AuditLogFunc {
 
 // Shutdown implements [ipnlocal.Extension].
 func (e *extension) Shutdown() error {
-	for _, f := range e.cleanup {
-		f()
-	}
-	e.cleanup = nil
 	return nil
 }

+ 2 - 3
ipn/desktop/extension.go

@@ -74,13 +74,12 @@ func (e *desktopSessionsExt) Name() string {
 // Init implements [ipnext.Extension].
 func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) {
 	e.host = host
-	unregisterResolver := host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile)
 	unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState)
 	if err != nil {
-		unregisterResolver()
 		return fmt.Errorf("session callback registration failed: %w", err)
 	}
-	e.cleanup = []func(){unregisterResolver, unregisterSessionCb}
+	host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile)
+	e.cleanup = []func(){unregisterSessionCb}
 	return nil
 }
 

+ 20 - 12
ipn/ipnext/ipnext.go

@@ -43,6 +43,7 @@ type Extension interface {
 	// provided the extension was initialized. For multiple extensions,
 	// Shutdown is called in the reverse order of Init.
 	// Returned errors are not fatal; they are used for logging.
+	// After a call to Shutdown, the extension will not be called again.
 	Shutdown() error
 }
 
@@ -182,9 +183,11 @@ type Host interface {
 
 	// RegisterAuditLogProvider registers an audit log provider,
 	// which returns a function to be called when an auditable action
-	// is about to be performed. The returned function unregisters the provider.
-	// It is a runtime error to register a nil provider.
-	RegisterAuditLogProvider(AuditLogProvider) (unregister func())
+	// is about to be performed.
+	//
+	// It is a runtime error to register a nil provider or call after the host
+	// has been initialized.
+	RegisterAuditLogProvider(AuditLogProvider)
 
 	// AuditLogger returns a function that calls all currently registered audit loggers.
 	// The function fails if any logger returns an error, indicating that the action
@@ -195,9 +198,11 @@ type Host interface {
 	AuditLogger() ipnauth.AuditLogFunc
 
 	// RegisterControlClientCallback registers a function to be called every time a new
-	// control client is created. The returned function unregisters the callback.
-	// It is a runtime error to register a nil callback.
-	RegisterControlClientCallback(NewControlClientCallback) (unregister func())
+	// control client is created.
+	//
+	// It is a runtime error to register a nil provider or call after the host
+	// has been initialized.
+	RegisterControlClientCallback(NewControlClientCallback)
 }
 
 // ExtensionServices provides access to the [Host]'s extension management services,
@@ -252,23 +257,26 @@ type ProfileServices interface {
 	SwitchToBestProfileAsync(reason string)
 
 	// RegisterBackgroundProfileResolver registers a function to be used when
-	// resolving the background profile. The returned function unregisters the resolver.
-	// It is a runtime error to register a nil resolver.
+	// resolving the background profile.
+	//
+	// It is a runtime error to register a nil provider or call after the host
+	// has been initialized.
 	//
 	// TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver.
 	// TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver".
 	// The concepts of the "current user", "foreground profile" and "background profile"
 	// only exist on Windows, and we're moving away from them anyway.
-	RegisterBackgroundProfileResolver(ProfileResolver) (unregister func())
+	RegisterBackgroundProfileResolver(ProfileResolver)
 
 	// RegisterProfileStateChangeCallback registers a function to be called when the current
-	// [ipn.LoginProfile] or its [ipn.Prefs] change. The returned function unregisters the callback.
+	// [ipn.LoginProfile] or its [ipn.Prefs] change.
 	//
 	// To get the initial profile or prefs, use [ProfileServices.CurrentProfileState]
 	// or [ProfileServices.CurrentPrefs] from the extension's [Extension.Init].
 	//
-	// It is a runtime error to register a nil callback.
-	RegisterProfileStateChangeCallback(ProfileStateChangeCallback) (unregister func())
+	// It is a runtime error to register a nil provider or call after the host
+	// has been initialized.
+	RegisterProfileStateChangeCallback(ProfileStateChangeCallback)
 }
 
 // ProfileStore provides read-only access to available login profiles and their preferences.

+ 68 - 109
ipn/ipnlocal/extension_host.go

@@ -7,7 +7,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"iter"
 	"maps"
 	"reflect"
 	"slices"
@@ -24,8 +23,6 @@ import (
 	"tailscale.com/tsd"
 	"tailscale.com/types/logger"
 	"tailscale.com/util/execqueue"
-	"tailscale.com/util/set"
-	"tailscale.com/util/slicesx"
 	"tailscale.com/util/testenv"
 )
 
@@ -78,6 +75,7 @@ type ExtensionHost struct {
 	// initOnce is used to ensure that the extensions are initialized only once,
 	// even if [extensionHost.Init] is called multiple times.
 	initOnce sync.Once
+	initDone atomic.Bool
 	// shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown].
 	shutdownOnce sync.Once
 
@@ -87,6 +85,24 @@ type ExtensionHost struct {
 	// doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue.
 	doEnqueueBackendOperation func(func(Backend))
 
+	// profileStateChangeCbs are callbacks that are invoked when the current login profile
+	// or its [ipn.Prefs] change, after those changes have been made. The current login profile
+	// may be changed either because of a profile switch, or because the profile information
+	// was updated by [LocalBackend.SetControlClientStatus], including when the profile
+	// is first populated and persisted.
+	profileStateChangeCbs []ipnext.ProfileStateChangeCallback
+	// backgroundProfileResolvers are registered background profile resolvers.
+	// They're used to determine the profile to use when no GUI/CLI client is connected.
+	backgroundProfileResolvers []ipnext.ProfileResolver
+	// auditLoggers are registered [AuditLogProvider]s.
+	// Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action
+	// is about to be performed. If an audit logger returns an error, the action is denied.
+	auditLoggers []ipnext.AuditLogProvider
+	// newControlClientCbs are the functions to be called when a new control client is created.
+	newControlClientCbs []ipnext.NewControlClientCallback
+
+	shuttingDown atomic.Bool
+
 	// mu protects the following fields.
 	// It must not be held when calling [LocalBackend] methods
 	// or when invoking callbacks registered by extensions.
@@ -107,22 +123,6 @@ type ExtensionHost struct {
 	// currentPrefs is a read-only view of the current profile's [ipn.Prefs]
 	// with any private keys stripped. It is always Valid.
 	currentPrefs ipn.PrefsView
-
-	// auditLoggers are registered [AuditLogProvider]s.
-	// Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action
-	// is about to be performed. If an audit logger returns an error, the action is denied.
-	auditLoggers set.HandleSet[ipnext.AuditLogProvider]
-	// backgroundProfileResolvers are registered background profile resolvers.
-	// They're used to determine the profile to use when no GUI/CLI client is connected.
-	backgroundProfileResolvers set.HandleSet[ipnext.ProfileResolver]
-	// newControlClientCbs are the functions to be called when a new control client is created.
-	newControlClientCbs set.HandleSet[ipnext.NewControlClientCallback]
-	// profileStateChangeCbs are callbacks that are invoked when the current login profile
-	// or its [ipn.Prefs] change, after those changes have been made. The current login profile
-	// may be changed either because of a profile switch, or because the profile information
-	// was updated by [LocalBackend.SetControlClientStatus], including when the profile
-	// is first populated and persisted.
-	profileStateChangeCbs set.HandleSet[ipnext.ProfileStateChangeCallback]
 }
 
 // Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost].
@@ -160,13 +160,10 @@ func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts
 		host.workQueue.Add(func() { f(b) })
 	}
 
-	var numExts int
-	var exts iter.Seq2[int, *ipnext.Definition]
-	if overrideExts == nil {
-		// Use registered extensions.
-		exts = ipnext.Extensions().All()
-		numExts = ipnext.Extensions().Len()
-	} else {
+	// Use registered extensions.
+	exts := ipnext.Extensions().All()
+	numExts := ipnext.Extensions().Len()
+	if overrideExts != nil {
 		// Use the provided, potentially empty, overrideExts
 		// instead of the registered ones.
 		exts = slices.All(overrideExts)
@@ -196,6 +193,8 @@ func (h *ExtensionHost) Init() {
 }
 
 func (h *ExtensionHost) init() {
+	defer h.initDone.Store(true)
+
 	// Initialize the extensions in the order they were registered.
 	h.mu.Lock()
 	h.activeExtensions = make([]ipnext.Extension, 0, len(h.allExtensions))
@@ -343,21 +342,21 @@ func (h *ExtensionHost) Backend() Backend {
 	return h.b
 }
 
-// RegisterProfileStateChangeCallback implements [ipnext.ProfileServices].
-func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStateChangeCallback) (unregister func()) {
-	if h == nil {
-		return func() {}
+// addFuncHook appends non-nil fn to hooks.
+func addFuncHook[F any](h *ExtensionHost, hooks *[]F, fn F) {
+	if h.initDone.Load() {
+		panic("invalid callback register after init")
 	}
-	if cb == nil {
-		panic("nil profile change callback")
+	if reflect.ValueOf(fn).IsZero() {
+		panic("nil function hook")
 	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	handle := h.profileStateChangeCbs.Add(cb)
-	return func() {
-		h.mu.Lock()
-		defer h.mu.Unlock()
-		delete(h.profileStateChangeCbs, handle)
+	*hooks = append(*hooks, fn)
+}
+
+// RegisterProfileStateChangeCallback implements [ipnext.ProfileServices].
+func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStateChangeCallback) {
+	if h != nil {
+		addFuncHook(h, &h.profileStateChangeCbs, cb)
 	}
 }
 
@@ -366,7 +365,7 @@ func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStat
 // It strips private keys from the [ipn.Prefs] before preserving
 // or passing them to the callbacks.
 func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
-	if h == nil {
+	if !h.active() {
 		return
 	}
 	h.mu.Lock()
@@ -378,10 +377,9 @@ func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs
 	// so we can provide them to the extensions later if they ask.
 	h.currentPrefs = prefs
 	h.currentProfile = profile
-	// Get the callbacks to be invoked.
-	cbs := slicesx.MapValues(h.profileStateChangeCbs)
 	h.mu.Unlock()
-	for _, cb := range cbs {
+
+	for _, cb := range h.profileStateChangeCbs {
 		cb(profile, prefs, sameNode)
 	}
 }
@@ -390,7 +388,7 @@ func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs
 // and updates the current profile and prefs in the host.
 // It strips private keys from the [ipn.Prefs] before preserving or using them.
 func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView, oldPrefs, newPrefs ipn.PrefsView) {
-	if h == nil {
+	if !h.active() {
 		return
 	}
 	h.mu.Lock()
@@ -403,28 +401,24 @@ func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView,
 	h.currentPrefs = newPrefs
 	h.currentProfile = profile
 	// Get the callbacks to be invoked.
-	stateCbs := slicesx.MapValues(h.profileStateChangeCbs)
 	h.mu.Unlock()
-	for _, cb := range stateCbs {
+
+	for _, cb := range h.profileStateChangeCbs {
 		cb(profile, newPrefs, true)
 	}
 }
 
 // RegisterBackgroundProfileResolver implements [ipnext.ProfileServices].
-func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) (unregister func()) {
-	if h == nil {
-		return func() {}
-	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	handle := h.backgroundProfileResolvers.Add(resolver)
-	return func() {
-		h.mu.Lock()
-		defer h.mu.Unlock()
-		delete(h.backgroundProfileResolvers, handle)
+func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) {
+	if h != nil {
+		addFuncHook(h, &h.backgroundProfileResolvers, resolver)
 	}
 }
 
+func (h *ExtensionHost) active() bool {
+	return h != nil && !h.shuttingDown.Load()
+}
+
 // DetermineBackgroundProfile returns a read-only view of the profile
 // used when no GUI/CLI client is connected, using background profile
 // resolvers registered by extensions.
@@ -434,7 +428,7 @@ func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.Profil
 //
 // As of 2025-02-07, this is only used on Windows.
 func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView {
-	if h == nil {
+	if !h.active() {
 		return ipn.LoginProfileView{}
 	}
 	// TODO(nickkhyl): check if the returned profile is allowed on the device,
@@ -443,10 +437,7 @@ func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore)
 
 	// Attempt to resolve the background profile using the registered
 	// background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows).
-	h.mu.Lock()
-	resolvers := slicesx.MapValues(h.backgroundProfileResolvers)
-	h.mu.Unlock()
-	for _, resolver := range resolvers {
+	for _, resolver := range h.backgroundProfileResolvers {
 		if profile := resolver(profiles); profile.Valid() {
 			return profile
 		}
@@ -458,35 +449,21 @@ func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore)
 }
 
 // RegisterControlClientCallback implements [ipnext.Host].
-func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) (unregister func()) {
-	if h == nil {
-		return func() {}
-	}
-	if cb == nil {
-		panic("nil control client callback")
-	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	handle := h.newControlClientCbs.Add(cb)
-	return func() {
-		h.mu.Lock()
-		defer h.mu.Unlock()
-		delete(h.newControlClientCbs, handle)
+func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) {
+	if h != nil {
+		addFuncHook(h, &h.newControlClientCbs, cb)
 	}
 }
 
 // NotifyNewControlClient invokes all registered control client callbacks.
 // It returns callbacks to be executed when the control client shuts down.
 func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView) (ccShutdownCbs []func()) {
-	if h == nil {
+	if !h.active() {
 		return nil
 	}
-	h.mu.Lock()
-	cbs := slicesx.MapValues(h.newControlClientCbs)
-	h.mu.Unlock()
-	if len(cbs) > 0 {
-		ccShutdownCbs = make([]func(), 0, len(cbs))
-		for _, cb := range cbs {
+	if len(h.newControlClientCbs) > 0 {
+		ccShutdownCbs = make([]func(), 0, len(h.newControlClientCbs))
+		for _, cb := range h.newControlClientCbs {
 			if shutdown := cb(cc, profile); shutdown != nil {
 				ccShutdownCbs = append(ccShutdownCbs, shutdown)
 			}
@@ -496,20 +473,9 @@ func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile
 }
 
 // RegisterAuditLogProvider implements [ipnext.Host].
-func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) (unregister func()) {
-	if h == nil {
-		return func() {}
-	}
-	if provider == nil {
-		panic("nil audit log provider")
-	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	handle := h.auditLoggers.Add(provider)
-	return func() {
-		h.mu.Lock()
-		defer h.mu.Unlock()
-		delete(h.auditLoggers, handle)
+func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) {
+	if h != nil {
+		addFuncHook(h, &h.auditLoggers, provider)
 	}
 }
 
@@ -523,20 +489,12 @@ func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvide
 // which typically includes the current profile and the audit loggers registered by extensions.
 // It must not be persisted outside of the auditable action context.
 func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc {
-	if h == nil {
+	if !h.active() {
 		return func(tailcfg.ClientAuditAction, string) error { return nil }
 	}
-
-	h.mu.Lock()
-	providers := slicesx.MapValues(h.auditLoggers)
-	h.mu.Unlock()
-
-	var loggers []ipnauth.AuditLogFunc
-	if len(providers) > 0 {
-		loggers = make([]ipnauth.AuditLogFunc, len(providers))
-		for i, provider := range providers {
-			loggers[i] = provider()
-		}
+	loggers := make([]ipnauth.AuditLogFunc, 0, len(h.auditLoggers))
+	for _, provider := range h.auditLoggers {
+		loggers = append(loggers, provider())
 	}
 	return func(action tailcfg.ClientAuditAction, details string) error {
 		// Log auditable actions to the host's log regardless of whether
@@ -567,6 +525,7 @@ func (h *ExtensionHost) Shutdown() {
 }
 
 func (h *ExtensionHost) shutdown() {
+	h.shuttingDown.Store(true)
 	// Prevent any queued but not yet started operations from running,
 	// block new operations from being enqueued, and wait for the
 	// currently executing operation (if any) to finish.

+ 5 - 41
ipn/ipnlocal/extension_host_test.go

@@ -576,30 +576,6 @@ func TestExtensionHostProfileStateChangeCallback(t *testing.T) {
 				{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
 			},
 		},
-		{
-			// Override the default InitHook used in the test to unregister the callback
-			// after the first call.
-			name: "Register/Once",
-			ext: &testExtension{
-				InitHook: func(e *testExtension) error {
-					var unregister func()
-					handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
-						makeStateChangeAppender(e)(profile, prefs, sameNode)
-						unregister()
-					}
-					unregister = e.host.Profiles().RegisterProfileStateChangeCallback(handler)
-					return nil
-				},
-			},
-			stateCalls: []stateChange{
-				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
-				{Profile: &ipn.LoginProfile{ID: "profile-2"}},
-				{Profile: &ipn.LoginProfile{ID: "profile-3"}},
-			},
-			wantChanges: []stateChange{ // only the first call is received by the callback
-				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
-			},
-		},
 		{
 			// Ensure that ipn.Prefs are passed to the callback.
 			name: "CheckPrefs",
@@ -770,7 +746,7 @@ func TestExtensionHostProfileStateChangeCallback(t *testing.T) {
 				tt.ext.InitHook = func(e *testExtension) error {
 					// Create and register the callback on init.
 					handler := makeStateChangeAppender(e)
-					e.Cleanup(e.host.Profiles().RegisterProfileStateChangeCallback(handler))
+					e.host.Profiles().RegisterProfileStateChangeCallback(handler)
 					return nil
 				}
 			}
@@ -891,14 +867,15 @@ func TestBackgroundProfileResolver(t *testing.T) {
 				}
 			}
 
-			h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true)
+			h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false)
 
 			// Register the resolvers with the host.
 			// This is typically done by the extensions themselves,
 			// but we do it here for testing purposes.
 			for _, r := range tt.resolvers {
-				t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r))
+				h.Profiles().RegisterBackgroundProfileResolver(r)
 			}
+			h.Init()
 
 			// Call the resolver to get the profile.
 			gotProfile := h.DetermineBackgroundProfile(pm)
@@ -989,7 +966,7 @@ func TestAuditLogProviders(t *testing.T) {
 					}
 				}
 				ext.InitHook = func(e *testExtension) error {
-					e.Cleanup(e.host.RegisterAuditLogProvider(provider))
+					e.host.RegisterAuditLogProvider(provider)
 					return nil
 				}
 				exts = append(exts, ext)
@@ -1168,8 +1145,6 @@ type testExtension struct {
 	// It can be accessed by tests using [setTestExtensionState],
 	// [getTestExtensionStateOk] and [getTestExtensionState].
 	state map[string]any
-	// cleanup are functions to be called on shutdown.
-	cleanup []func()
 }
 
 var _ ipnext.Extension = (*testExtension)(nil)
@@ -1212,22 +1187,11 @@ func (e *testExtension) InitCalled() bool {
 	return e.initCnt.Load() != 0
 }
 
-func (e *testExtension) Cleanup(f func()) {
-	e.mu.Lock()
-	e.cleanup = append(e.cleanup, f)
-	e.mu.Unlock()
-}
-
 // Shutdown implements [ipnext.Extension].
 func (e *testExtension) Shutdown() (err error) {
 	e.t.Helper()
 	e.mu.Lock()
-	cleanup := e.cleanup
-	e.cleanup = nil
 	e.mu.Unlock()
-	for _, f := range cleanup {
-		f()
-	}
 	if e.ShutdownHook != nil {
 		err = e.ShutdownHook(e)
 	}