Browse Source

ipn/ipnlocal: make GetExt work earlier, before extension init

Taildrop wasn't working on iOS since #15971 because GetExt didn't work
until after init, but that PR moved Init until after Start.

This makes GetExt work before LocalBackend.Start (ExtensionHost.Init).

Updates #15812

Change-Id: I6e87257cd97a20f86083a746d39df223e5b6791b
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 9 months ago
parent
commit
30a89ad378
3 changed files with 59 additions and 8 deletions
  1. 1 1
      ipn/ipnext/ipnext.go
  2. 21 7
      ipn/ipnlocal/extension_host.go
  3. 37 0
      ipn/ipnlocal/extension_host_test.go

+ 1 - 1
ipn/ipnext/ipnext.go

@@ -114,7 +114,7 @@ func RegisterExtension(name string, newExt NewExtensionFn) {
 		panic(fmt.Sprintf("ipnext: newExt is nil: %q", name))
 	}
 	if extensions.Contains(name) {
-		panic(fmt.Sprintf("ipnext: duplicate extensions: %q", name))
+		panic(fmt.Sprintf("ipnext: duplicate extension name %q", name))
 	}
 	extensions.Set(name, &Definition{name, newExt})
 }

+ 21 - 7
ipn/ipnlocal/extension_host.go

@@ -22,6 +22,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/logger"
 	"tailscale.com/util/execqueue"
+	"tailscale.com/util/mak"
 	"tailscale.com/util/testenv"
 )
 
@@ -97,7 +98,8 @@ type ExtensionHost struct {
 	initialized atomic.Bool
 	// activeExtensions is a subset of allExtensions that have been initialized and are ready to use.
 	activeExtensions []ipnext.Extension
-	// extensionsByName are the activeExtensions indexed by their names.
+	// extensionsByName are the extensions indexed by their names.
+	// They are not necessarily initialized (in activeExtensions) yet.
 	extensionsByName map[string]ipnext.Extension
 	// postInitWorkQueue is a queue of functions to be executed
 	// by the workQueue after all extensions have been initialized.
@@ -184,6 +186,24 @@ func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Defin
 			return nil, fmt.Errorf("failed to create %q extension: %v", d.Name(), err)
 		}
 		host.allExtensions = append(host.allExtensions, ext)
+
+		if d.Name() != ext.Name() {
+			return nil, fmt.Errorf("extension name %q does not match the registered name %q", ext.Name(), d.Name())
+		}
+
+		if _, ok := host.extensionsByName[ext.Name()]; ok {
+			return nil, fmt.Errorf("duplicate extension name %q", ext.Name())
+		} else {
+			mak.Set(&host.extensionsByName, ext.Name(), ext)
+		}
+
+		typ := reflect.TypeOf(ext)
+		if _, ok := host.extByType.Load(typ); ok {
+			if _, ok := ext.(interface{ PermitDoubleRegister() }); !ok {
+				return nil, fmt.Errorf("duplicate extension type %T", ext)
+			}
+		}
+		host.extByType.Store(typ, ext)
 	}
 	return host, nil
 }
@@ -215,10 +235,6 @@ func (h *ExtensionHost) init() {
 	defer h.initDone.Store(true)
 
 	// Initialize the extensions in the order they were registered.
-	h.mu.Lock()
-	h.activeExtensions = make([]ipnext.Extension, 0, len(h.allExtensions))
-	h.extensionsByName = make(map[string]ipnext.Extension, len(h.allExtensions))
-	h.mu.Unlock()
 	for _, ext := range h.allExtensions {
 		// Do not hold the lock while calling [ipnext.Extension.Init].
 		// Extensions call back into the host to register their callbacks,
@@ -240,8 +256,6 @@ func (h *ExtensionHost) init() {
 		// We'd like to make them visible to other extensions that are initialized later.
 		h.mu.Lock()
 		h.activeExtensions = append(h.activeExtensions, ext)
-		h.extensionsByName[ext.Name()] = ext
-		h.extByType.Store(reflect.TypeOf(ext), ext)
 		h.mu.Unlock()
 	}
 

+ 37 - 0
ipn/ipnlocal/extension_host_test.go

@@ -30,6 +30,7 @@ import (
 	"tailscale.com/tstime"
 	"tailscale.com/types/key"
 	"tailscale.com/types/lazy"
+	"tailscale.com/types/logger"
 	"tailscale.com/types/persist"
 	"tailscale.com/util/must"
 )
@@ -1042,6 +1043,38 @@ func TestNilExtensionHostMethodCall(t *testing.T) {
 	}
 }
 
+// extBeforeStartExtension is a test extension used by TestGetExtBeforeStart.
+// It is registered with the [ipnext.RegisterExtension].
+type extBeforeStartExtension struct{}
+
+func init() {
+	ipnext.RegisterExtension("ext-before-start", mkExtBeforeStartExtension)
+}
+
+func mkExtBeforeStartExtension(logger.Logf, ipnext.SafeBackend) (ipnext.Extension, error) {
+	return extBeforeStartExtension{}, nil
+}
+
+func (extBeforeStartExtension) Name() string { return "ext-before-start" }
+func (extBeforeStartExtension) Init(ipnext.Host) error {
+	return nil
+}
+func (extBeforeStartExtension) Shutdown() error {
+	return nil
+}
+
+// TestGetExtBeforeStart verifies that an extension registered via
+// RegisterExtension can be retrieved with GetExt before the host is started
+// (via LocalBackend.Start)
+func TestGetExtBeforeStart(t *testing.T) {
+	lb := newTestBackend(t)
+	// Now call GetExt without calling Start on the LocalBackend.
+	_, ok := GetExt[extBeforeStartExtension](lb)
+	if !ok {
+		t.Fatal("didn't find extension")
+	}
+}
+
 // checkMethodCallWithZeroArgs calls the method m on the receiver r
 // with zero values for all its arguments, except the receiver itself.
 // It returns the result of the method call, or fails the test if the call panics.
@@ -1151,6 +1184,10 @@ type testExtension struct {
 
 var _ ipnext.Extension = (*testExtension)(nil)
 
+// PermitDoubleRegister is a sentinel method whose existence tells the
+// ExtensionHost to permit it to be registered multiple times.
+func (*testExtension) PermitDoubleRegister() {}
+
 func (e *testExtension) setT(t *testing.T) {
 	e.t = t
 }