Browse Source

ipn/ipnlocal: signal nodeBackend readiness and shutdown

We update LocalBackend to shut down the current nodeBackend
when switching to a different node, and to mark the new node's
nodeBackend as ready when the switch completes.

Updates tailscale/corp#28014
Updates tailscale/corp#29543
Updates #12614

Signed-off-by: Nick Khyl <[email protected]>
Nick Khyl 9 months ago
parent
commit
733bfaeffe
3 changed files with 230 additions and 16 deletions
  1. 32 11
      ipn/ipnlocal/local.go
  2. 77 5
      ipn/ipnlocal/node_backend.go
  3. 121 0
      ipn/ipnlocal/node_backend_test.go

+ 32 - 11
ipn/ipnlocal/local.go

@@ -168,6 +168,17 @@ type watchSession struct {
 
 var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detected")
 
+var (
+	// errShutdown indicates that the [LocalBackend.Shutdown] was called.
+	errShutdown = errors.New("shutting down")
+
+	// errNodeContextChanged indicates that [LocalBackend] has switched
+	// to a different [localNodeContext], usually due to a profile change.
+	// It is used as a context cancellation cause for the old context
+	// and can be returned when an operation is performed on it.
+	errNodeContextChanged = errors.New("profile changed")
+)
+
 // LocalBackend is the glue between the major pieces of the Tailscale
 // network software: the cloud control plane (via controlclient), the
 // network data plane (via wgengine), and the user-facing UIs and CLIs
@@ -180,11 +191,11 @@ var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detecte
 // state machine generates events back out to zero or more components.
 type LocalBackend struct {
 	// Elements that are thread-safe or constant after construction.
-	ctx                      context.Context    // canceled by [LocalBackend.Shutdown]
-	ctxCancel                context.CancelFunc // cancels ctx
-	logf                     logger.Logf        // general logging
-	keyLogf                  logger.Logf        // for printing list of peers on change
-	statsLogf                logger.Logf        // for printing peers stats on change
+	ctx                      context.Context         // canceled by [LocalBackend.Shutdown]
+	ctxCancel                context.CancelCauseFunc // cancels ctx
+	logf                     logger.Logf             // general logging
+	keyLogf                  logger.Logf             // for printing list of peers on change
+	statsLogf                logger.Logf             // for printing peers stats on change
 	sys                      *tsd.System
 	health                   *health.Tracker // always non-nil
 	metrics                  metrics
@@ -463,7 +474,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 
 	envknob.LogCurrent(logf)
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancelCause(context.Background())
 	clock := tstime.StdClock{}
 
 	// Until we transition to a Running state, use a canceled context for
@@ -503,7 +514,10 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
 	}
-	b.currentNodeAtomic.Store(newNodeBackend())
+	nb := newNodeBackend(ctx)
+	b.currentNodeAtomic.Store(nb)
+	nb.ready()
+
 	mConn.SetNetInfoCallback(b.setNetInfo)
 
 	if sys.InitialConfig != nil {
@@ -586,8 +600,10 @@ func (b *LocalBackend) currentNode() *nodeBackend {
 		return v
 	}
 	// Auto-init one in tests for LocalBackend created without the NewLocalBackend constructor...
-	v := newNodeBackend()
-	b.currentNodeAtomic.CompareAndSwap(nil, v)
+	v := newNodeBackend(cmp.Or(b.ctx, context.Background()))
+	if b.currentNodeAtomic.CompareAndSwap(nil, v) {
+		v.ready()
+	}
 	return b.currentNodeAtomic.Load()
 }
 
@@ -1089,8 +1105,9 @@ func (b *LocalBackend) Shutdown() {
 	if cc != nil {
 		cc.Shutdown()
 	}
+	b.ctxCancel(errShutdown)
+	b.currentNode().shutdown(errShutdown)
 	extHost.Shutdown()
-	b.ctxCancel()
 	b.e.Close()
 	<-b.e.Done()
 	b.awaitNoGoroutinesInTest()
@@ -6992,7 +7009,11 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err
 		// down, so no need to do any work.
 		return nil
 	}
-	b.currentNodeAtomic.Store(newNodeBackend())
+	newNode := newNodeBackend(b.ctx)
+	if oldNode := b.currentNodeAtomic.Swap(newNode); oldNode != nil {
+		oldNode.shutdown(errNodeContextChanged)
+	}
+	defer newNode.ready()
 	b.setNetMapLocked(nil) // Reset netmap.
 	b.updateFilterLocked(ipn.PrefsView{})
 	// Reset the NetworkMap in the engine

+ 77 - 5
ipn/ipnlocal/node_backend.go

@@ -5,6 +5,7 @@ package ipnlocal
 
 import (
 	"cmp"
+	"context"
 	"net/netip"
 	"slices"
 	"sync"
@@ -39,7 +40,7 @@ import (
 // Two pointers to different [nodeBackend] instances represent different local nodes.
 // However, there's currently a bug where a new [nodeBackend] might not be created
 // during an implicit node switch (see tailscale/corp#28014).
-
+//
 // In the future, we might want to include at least the following in this struct (in addition to the current fields).
 // However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions,
 // peer API handlers, etc.).
@@ -61,6 +62,9 @@ import (
 // Even if they're tied to the local node, instead of moving them here, we should extract the entire feature
 // into a separate package and have it install proper hooks.
 type nodeBackend struct {
+	ctx       context.Context         // canceled by [nodeBackend.shutdown]
+	ctxCancel context.CancelCauseFunc // cancels ctx
+
 	// filterAtomic is a stateful packet filter. Immutable once created, but can be
 	// replaced with a new one.
 	filterAtomic atomic.Pointer[filter.Filter]
@@ -68,6 +72,9 @@ type nodeBackend struct {
 	// TODO(nickkhyl): maybe use sync.RWMutex?
 	mu sync.Mutex // protects the following fields
 
+	shutdownOnce sync.Once     // guards calling [nodeBackend.shutdown]
+	readyCh      chan struct{} // closed by [nodeBackend.ready]; nil after shutdown
+
 	// NetMap is the most recently set full netmap from the controlclient.
 	// It can't be mutated in place once set. Because it can't be mutated in place,
 	// delta updates from the control server don't apply to it. Instead, use
@@ -88,12 +95,24 @@ type nodeBackend struct {
 	nodeByAddr map[netip.Addr]tailcfg.NodeID
 }
 
-func newNodeBackend() *nodeBackend {
-	cn := &nodeBackend{}
+func newNodeBackend(ctx context.Context) *nodeBackend {
+	ctx, ctxCancel := context.WithCancelCause(ctx)
+	nb := &nodeBackend{
+		ctx:       ctx,
+		ctxCancel: ctxCancel,
+		readyCh:   make(chan struct{}),
+	}
 	// Default filter blocks everything and logs nothing.
 	noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{})
-	cn.filterAtomic.Store(noneFilter)
-	return cn
+	nb.filterAtomic.Store(noneFilter)
+	return nb
+}
+
+// Context returns a context that is canceled when the [nodeBackend] shuts down,
+// either because [LocalBackend] is switching to a different [nodeBackend]
+// or is shutting down itself.
+func (nb *nodeBackend) Context() context.Context {
+	return nb.ctx
 }
 
 func (nb *nodeBackend) Self() tailcfg.NodeView {
@@ -475,6 +494,59 @@ func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (doh
 	return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID)
 }
 
+// ready signals that [LocalBackend] has completed the switch to this [nodeBackend]
+// and any pending calls to [nodeBackend.Wait] must be unblocked.
+func (nb *nodeBackend) ready() {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	if nb.readyCh != nil {
+		close(nb.readyCh)
+	}
+}
+
+// Wait blocks until [LocalBackend] completes the switch to this [nodeBackend]
+// and calls [nodeBackend.ready]. It returns an error if the provided context
+// is canceled or if the [nodeBackend] shuts down or is already shut down.
+//
+// It must not be called with the [LocalBackend]'s internal mutex held as [LocalBackend]
+// may need to acquire it to complete the switch.
+//
+// TODO(nickkhyl): Relax this restriction once [LocalBackend]'s state machine
+// runs in its own goroutine, or if we decide that waiting for the state machine
+// restart to finish isn't necessary for [LocalBackend] to consider the switch complete.
+// We mostly need this because of [LocalBackend.Start] acquiring b.mu and the fact that
+// methods like [LocalBackend.SwitchProfile] must report any errors returned by it.
+// Perhaps we could report those errors asynchronously as [health.Warnable]s?
+func (nb *nodeBackend) Wait(ctx context.Context) error {
+	nb.mu.Lock()
+	readyCh := nb.readyCh
+	nb.mu.Unlock()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-nb.ctx.Done():
+		return context.Cause(nb.ctx)
+	case <-readyCh:
+		return nil
+	}
+}
+
+// shutdown shuts down the [nodeBackend] and cancels its context
+// with the provided cause.
+func (nb *nodeBackend) shutdown(cause error) {
+	nb.shutdownOnce.Do(func() {
+		nb.doShutdown(cause)
+	})
+}
+
+func (nb *nodeBackend) doShutdown(cause error) {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	nb.ctxCancel(cause)
+	nb.readyCh = nil
+}
+
 // dnsConfigForNetmap returns a *dns.Config for the given netmap,
 // prefs, client OS version, and cloud hosting environment.
 //

+ 121 - 0
ipn/ipnlocal/node_backend_test.go

@@ -0,0 +1,121 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnlocal
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+func TestNodeBackendReadiness(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	// The node backend is not ready until [nodeBackend.ready] is called,
+	// and [nodeBackend.Wait] should fail with [context.DeadlineExceeded].
+	ctx, cancelCtx := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancelCtx()
+	if err := nb.Wait(ctx); err != ctx.Err() {
+		t.Fatalf("Wait: got %v; want %v", err, ctx.Err())
+	}
+
+	// Start a goroutine to wait for the node backend to become ready.
+	waitDone := make(chan struct{})
+	go func() {
+		if err := nb.Wait(context.Background()); err != nil {
+			t.Errorf("Wait: got %v; want nil", err)
+		}
+		close(waitDone)
+	}()
+
+	// Call [nodeBackend.ready] to indicate that the node backend is now ready.
+	go nb.ready()
+
+	// Once the backend is called, [nodeBackend.Wait] should return immediately without error.
+	if err := nb.Wait(context.Background()); err != nil {
+		t.Fatalf("Wait: got %v; want nil", err)
+	}
+	// And any pending waiters should also be unblocked.
+	<-waitDone
+}
+
+func TestNodeBackendShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	shutdownCause := errors.New("test shutdown")
+
+	// Start a goroutine to wait for the node backend to become ready.
+	// This test expects it to block until the node backend shuts down
+	// and then return the specified shutdown cause.
+	waitDone := make(chan struct{})
+	go func() {
+		if err := nb.Wait(context.Background()); err != shutdownCause {
+			t.Errorf("Wait: got %v; want %v", err, shutdownCause)
+		}
+		close(waitDone)
+	}()
+
+	// Call [nodeBackend.shutdown] to indicate that the node backend is shutting down.
+	nb.shutdown(shutdownCause)
+
+	// Calling it again is fine, but should not change the shutdown cause.
+	nb.shutdown(errors.New("test shutdown again"))
+
+	// After shutdown, [nodeBackend.Wait] should return with the specified shutdown cause.
+	if err := nb.Wait(context.Background()); err != shutdownCause {
+		t.Fatalf("Wait: got %v; want %v", err, shutdownCause)
+	}
+	// The context associated with the node backend should also be cancelled
+	// and its cancellation cause should match the shutdown cause.
+	if err := nb.Context().Err(); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled)
+	}
+	if cause := context.Cause(nb.Context()); cause != shutdownCause {
+		t.Fatalf("Cause: got %v; want %v", cause, shutdownCause)
+	}
+	// And any pending waiters should also be unblocked.
+	<-waitDone
+}
+
+func TestNodeBackendReadyAfterShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	shutdownCause := errors.New("test shutdown")
+	nb.shutdown(shutdownCause)
+	nb.ready() // Calling ready after shutdown is a no-op, but should not panic, etc.
+	if err := nb.Wait(context.Background()); err != shutdownCause {
+		t.Fatalf("Wait: got %v; want %v", err, shutdownCause)
+	}
+}
+
+func TestNodeBackendParentContextCancellation(t *testing.T) {
+	ctx, cancelCtx := context.WithCancel(context.Background())
+	nb := newNodeBackend(ctx)
+
+	cancelCtx()
+
+	// Cancelling the parent context should cause [nodeBackend.Wait]
+	// to return with [context.Canceled].
+	if err := nb.Wait(context.Background()); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Wait: got %v; want %v", err, context.Canceled)
+	}
+
+	// And the node backend's context should also be cancelled.
+	if err := nb.Context().Err(); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled)
+	}
+}
+
+func TestNodeBackendConcurrentReadyAndShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	// Calling [nodeBackend.ready] and [nodeBackend.shutdown] concurrently
+	// should not cause issues, and [nodeBackend.Wait] should unblock,
+	// but the result of [nodeBackend.Wait] is intentionally undefined.
+	go nb.ready()
+	go nb.shutdown(errors.New("test shutdown"))
+
+	nb.Wait(context.Background())
+}