Browse Source

ipn/{ipnext,ipnlocal}: add a SafeBackend interface

Updates #12614

Change-Id: I197e673666e86ea74c19e3935ed71aec269b6c94
Co-authored-by: Nick Khyl <[email protected]>
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 10 months ago
parent
commit
3d8533b5d0

+ 1 - 2
feature/relayserver/relayserver.go

@@ -20,7 +20,6 @@ import (
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/net/udprelay"
 	"tailscale.com/tailcfg"
-	"tailscale.com/tsd"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/ptr"
@@ -40,7 +39,7 @@ func init() {
 // newExtension is an [ipnext.NewExtensionFn] that creates a new relay server
 // extension. It is registered with [ipnext.RegisterExtension] if the package is
 // imported.
-func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) {
+func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) {
 	return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil
 }
 

+ 7 - 9
feature/taildrop/ext.go

@@ -7,7 +7,6 @@ import (
 	"tailscale.com/ipn/ipnext"
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/taildrop"
-	"tailscale.com/tsd"
 	"tailscale.com/types/logger"
 )
 
@@ -15,7 +14,7 @@ func init() {
 	ipnext.RegisterExtension("taildrop", newExtension)
 }
 
-func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) {
+func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) {
 	return &extension{
 		logf: logger.WithPrefix(logf, "taildrop: "),
 	}, nil
@@ -23,7 +22,7 @@ func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) {
 
 type extension struct {
 	logf logger.Logf
-	lb   *ipnlocal.LocalBackend
+	sb   ipnext.SafeBackend
 	mgr  *taildrop.Manager
 }
 
@@ -32,11 +31,6 @@ func (e *extension) Name() string {
 }
 
 func (e *extension) Init(h ipnext.Host) error {
-	type I interface {
-		Backend() ipnlocal.Backend
-	}
-	e.lb = h.(I).Backend().(*ipnlocal.LocalBackend)
-
 	// TODO(bradfitz): move init of taildrop.Manager from ipnlocal/peerapi.go to
 	// here
 	e.mgr = nil
@@ -45,7 +39,11 @@ func (e *extension) Init(h ipnext.Host) error {
 }
 
 func (e *extension) Shutdown() error {
-	if mgr, err := e.lb.TaildropManager(); err == nil {
+	lb, ok := e.sb.(*ipnlocal.LocalBackend)
+	if !ok {
+		return nil
+	}
+	if mgr, err := lb.TaildropManager(); err == nil {
 		mgr.Shutdown()
 	} else {
 		e.logf("taildrop: failed to shutdown taildrop manager: %v", err)

+ 1 - 2
ipn/auditlog/extension.go

@@ -16,7 +16,6 @@ import (
 	"tailscale.com/ipn/ipnauth"
 	"tailscale.com/ipn/ipnext"
 	"tailscale.com/tailcfg"
-	"tailscale.com/tsd"
 	"tailscale.com/types/lazy"
 	"tailscale.com/types/logger"
 )
@@ -52,7 +51,7 @@ type extension struct {
 
 // newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension.
 // It is registered with [ipnext.RegisterExtension] if the package is imported.
-func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) {
+func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) {
 	return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil
 }
 

+ 1 - 2
ipn/desktop/extension.go

@@ -17,7 +17,6 @@ import (
 	"tailscale.com/feature"
 	"tailscale.com/ipn"
 	"tailscale.com/ipn/ipnext"
-	"tailscale.com/tsd"
 	"tailscale.com/types/logger"
 	"tailscale.com/util/syspolicy"
 )
@@ -53,7 +52,7 @@ type desktopSessionsExt struct {
 // newDesktopSessionsExt returns a new [desktopSessionsExt],
 // or an error if a [SessionManager] cannot be created.
 // It is registered with [ipnext.RegisterExtension] if the package is imported.
-func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (ipnext.Extension, error) {
+func newDesktopSessionsExt(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) {
 	logf = logger.WithPrefix(logf, featureName+": ")
 	sm, err := NewSessionManager(logf)
 	if err != nil {

+ 19 - 5
ipn/ipnext/ipnext.go

@@ -13,6 +13,7 @@ import (
 	"tailscale.com/ipn"
 	"tailscale.com/ipn/ipnauth"
 	"tailscale.com/tsd"
+	"tailscale.com/tstime"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/views"
 	"tailscale.com/util/mak"
@@ -52,7 +53,7 @@ type Extension interface {
 // If the extension should be skipped at runtime, it must return either [SkipExtension]
 // or a wrapped [SkipExtension]. Any other error returned is fatal and will prevent
 // the LocalBackend from starting.
-type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error)
+type NewExtensionFn func(logger.Logf, SafeBackend) (Extension, error)
 
 // SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension
 // should be skipped rather than prevent the LocalBackend from starting.
@@ -78,8 +79,8 @@ func (d *Definition) Name() string {
 }
 
 // MakeExtension instantiates the extension.
-func (d *Definition) MakeExtension(logf logger.Logf, sys *tsd.System) (Extension, error) {
-	ext, err := d.newFn(logf, sys)
+func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension, error) {
+	ext, err := d.newFn(logf, sb)
 	if err != nil {
 		return nil, err
 	}
@@ -130,7 +131,7 @@ func Extensions() views.Slice[*Definition] {
 func DefinitionForTest(ext Extension) *Definition {
 	return &Definition{
 		name:  ext.Name(),
-		newFn: func(logger.Logf, *tsd.System) (Extension, error) { return ext, nil },
+		newFn: func(logger.Logf, SafeBackend) (Extension, error) { return ext, nil },
 	}
 }
 
@@ -140,7 +141,7 @@ func DefinitionForTest(ext Extension) *Definition {
 func DefinitionWithErrForTest(name string, err error) *Definition {
 	return &Definition{
 		name:  name,
-		newFn: func(logger.Logf, *tsd.System) (Extension, error) { return nil, err },
+		newFn: func(logger.Logf, SafeBackend) (Extension, error) { return nil, err },
 	}
 }
 
@@ -203,6 +204,19 @@ type Host interface {
 	// It is a runtime error to register a nil provider or call after the host
 	// has been initialized.
 	RegisterControlClientCallback(NewControlClientCallback)
+
+	// SendNotifyAsync sends a notification to the IPN bus,
+	// typically to the GUI client.
+	SendNotifyAsync(ipn.Notify)
+}
+
+// SafeBackend is a subset of the [ipnlocal.LocalBackend] type's methods that
+// are safe to call from extension hooks at any time (even hooks called while
+// LocalBackend's internal mutex is held).
+type SafeBackend interface {
+	Sys() *tsd.System
+	Clock() tstime.Clock
+	TailscaleVarRoot() string
 }
 
 // ExtensionServices provides access to the [Host]'s extension management services,

+ 27 - 9
ipn/ipnlocal/extension_host.go

@@ -20,7 +20,6 @@ import (
 	"tailscale.com/ipn/ipnauth"
 	"tailscale.com/ipn/ipnext"
 	"tailscale.com/tailcfg"
-	"tailscale.com/tsd"
 	"tailscale.com/types/logger"
 	"tailscale.com/util/execqueue"
 	"tailscale.com/util/testenv"
@@ -131,15 +130,32 @@ type Backend interface {
 	// SwitchToBestProfile switches to the best profile for the current state of the system.
 	// The reason indicates why the profile is being switched.
 	SwitchToBestProfile(reason string)
+
+	SendNotify(ipn.Notify)
+	ipnext.SafeBackend
 }
 
 // NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend.
 // The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called.
 // It returns an error if instantiating any extension fails.
+func NewExtensionHost(logf logger.Logf, b Backend) (*ExtensionHost, error) {
+	return newExtensionHost(logf, b)
+}
+
+func NewExtensionHostForTest(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (*ExtensionHost, error) {
+	if !testenv.InTest() {
+		panic("use outside of test")
+	}
+	return newExtensionHost(logf, b, overrideExts...)
+}
+
+// newExtensionHost is the shared implementation of [NewExtensionHost] and
+// [NewExtensionHostForTest].
 //
-// If overrideExts is non-nil, the registered extensions are ignored and the provided extensions are used instead.
-// Overriding extensions is primarily used for testing.
-func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) {
+// If overrideExts is non-nil, the registered extensions are ignored and the
+// provided extensions are used instead. Overriding extensions is primarily used
+// for testing.
+func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) {
 	host := &ExtensionHost{
 		b:         b,
 		logf:      logger.WithPrefix(logf, "ipnext: "),
@@ -172,7 +188,7 @@ func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts
 
 	host.allExtensions = make([]ipnext.Extension, 0, numExts)
 	for _, d := range exts {
-		ext, err := d.MakeExtension(logf, sys)
+		ext, err := d.MakeExtension(logf, b)
 		if errors.Is(err, ipnext.SkipExtension) {
 			// The extension wants to be skipped.
 			host.logf("%q: %v", d.Name(), err)
@@ -334,12 +350,14 @@ func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) {
 	})
 }
 
-// Backend returns the [Backend] used by the extension host.
-func (h *ExtensionHost) Backend() Backend {
+// SendNotifyAsync implements [ipnext.Host].
+func (h *ExtensionHost) SendNotifyAsync(n ipn.Notify) {
 	if h == nil {
-		return nil
+		return
 	}
-	return h.b
+	h.enqueueBackendOperation(func(b Backend) {
+		b.SendNotify(n)
+	})
 }
 
 // addFuncHook appends non-nil fn to hooks.

+ 12 - 2
ipn/ipnlocal/extension_host_test.go

@@ -27,7 +27,9 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/tsd"
 	"tailscale.com/tstest"
+	"tailscale.com/tstime"
 	"tailscale.com/types/key"
+	"tailscale.com/types/lazy"
 	"tailscale.com/types/persist"
 	"tailscale.com/util/must"
 )
@@ -284,7 +286,7 @@ func TestNewExtensionHost(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Parallel()
 			logf := tstest.WhileTestRunningLogger(t)
-			h, err := NewExtensionHost(logf, tsd.NewSystem(), &testBackend{}, tt.defs...)
+			h, err := NewExtensionHostForTest(logf, &testBackend{}, tt.defs...)
 			if gotErr := err != nil; gotErr != tt.wantErr {
 				t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr)
 			}
@@ -1095,7 +1097,7 @@ func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initia
 		}
 		defs[i] = ipnext.DefinitionForTest(ext)
 	}
-	h, err := NewExtensionHost(logf, tsd.NewSystem(), b, defs...)
+	h, err := NewExtensionHostForTest(logf, b, defs...)
 	if err != nil {
 		t.Fatalf("NewExtensionHost: %v", err)
 	}
@@ -1320,6 +1322,7 @@ func (q *testExecQueue) Wait(context.Context) error { return nil }
 // testBackend implements [ipnext.Backend] for testing purposes
 // by calling the provided hooks when its methods are called.
 type testBackend struct {
+	lazySys                 lazy.SyncValue[*tsd.System]
 	switchToBestProfileHook func(reason string)
 
 	// mu protects the backend state.
@@ -1328,6 +1331,13 @@ type testBackend struct {
 	mu sync.Mutex
 }
 
+func (b *testBackend) Clock() tstime.Clock { return tstime.StdClock{} }
+func (b *testBackend) Sys() *tsd.System {
+	return b.lazySys.Get(tsd.NewSystem)
+}
+func (b *testBackend) SendNotify(ipn.Notify)    { panic("not implemented") }
+func (b *testBackend) TailscaleVarRoot() string { panic("not implemented") }
+
 func (b *testBackend) SwitchToBestProfile(reason string) {
 	b.mu.Lock()
 	defer b.mu.Unlock()

+ 8 - 1
ipn/ipnlocal/local.go

@@ -525,7 +525,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		}
 	}
 
-	if b.extHost, err = NewExtensionHost(logf, sys, b); err != nil {
+	if b.extHost, err = NewExtensionHost(logf, b); err != nil {
 		return nil, fmt.Errorf("failed to create extension host: %w", err)
 	}
 	b.pm.SetExtensionHost(b.extHost)
@@ -589,6 +589,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 }
 
 func (b *LocalBackend) Clock() tstime.Clock { return b.clock }
+func (b *LocalBackend) Sys() *tsd.System    { return b.sys }
 
 // FindExtensionByName returns an active extension with the given name,
 // or nil if no such extension exists.
@@ -3187,6 +3188,12 @@ func (b *LocalBackend) send(n ipn.Notify) {
 	b.sendTo(n, allClients)
 }
 
+// SendNotify sends a notification to the IPN bus,
+// typically to the GUI client.
+func (b *LocalBackend) SendNotify(n ipn.Notify) {
+	b.send(n)
+}
+
 // notificationTarget describes a notification recipient.
 // A zero value is valid and indicate that the notification
 // should be broadcast to all active [watchSession]s.