Преглед изворни кода

ipn/ipnlocal, util/goroutines: track goroutines for tests, shutdown

Updates #14520
Updates #14517 (in that I pulled this out of there)

Change-Id: Ibc28162816e083fcadf550586c06805c76e378fc
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick пре 1 година
родитељ
комит
07aae18bca
3 измењених фајлова са 118 додато и 14 уклоњено
  1. 51 13
      ipn/ipnlocal/local.go
  2. 1 1
      util/goroutines/goroutines.go
  3. 66 0
      util/goroutines/tracker.go

+ 51 - 13
ipn/ipnlocal/local.go

@@ -96,6 +96,7 @@ import (
 	"tailscale.com/types/views"
 	"tailscale.com/util/deephash"
 	"tailscale.com/util/dnsname"
+	"tailscale.com/util/goroutines"
 	"tailscale.com/util/httpm"
 	"tailscale.com/util/mak"
 	"tailscale.com/util/multierr"
@@ -178,7 +179,7 @@ type watchSession struct {
 // state machine generates events back out to zero or more components.
 type LocalBackend struct {
 	// Elements that are thread-safe or constant after construction.
-	ctx                      context.Context    // canceled by Close
+	ctx                      context.Context    // canceled by [LocalBackend.Shutdown]
 	ctxCancel                context.CancelFunc // cancels ctx
 	logf                     logger.Logf        // general logging
 	keyLogf                  logger.Logf        // for printing list of peers on change
@@ -231,6 +232,10 @@ type LocalBackend struct {
 	shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool]
 	numClientStatusCalls         atomic.Uint32
 
+	// goTracker accounts for all goroutines started by LocalBackend, primarily
+	// for testing and graceful shutdown purposes.
+	goTracker goroutines.Tracker
+
 	// The mutex protects the following elements.
 	mu             sync.Mutex
 	conf           *conffile.Config // latest parsed config, or nil if not in declarative mode
@@ -866,7 +871,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) {
 			// TODO(raggi,tailscale/corp#22574): authReconfig should be refactored such that we can call the
 			// necessary operations here and avoid the need for asynchronous behavior that is racy and hard
 			// to test here, and do less extra work in these conditions.
-			go b.authReconfig()
+			b.goTracker.Go(b.authReconfig)
 		}
 	}
 
@@ -879,7 +884,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) {
 		want := b.netMap.GetAddresses().Len()
 		if len(b.peerAPIListeners) < want {
 			b.logf("linkChange: peerAPIListeners too low; trying again")
-			go b.initPeerAPIListener()
+			b.goTracker.Go(b.initPeerAPIListener)
 		}
 	}
 }
@@ -1004,6 +1009,33 @@ func (b *LocalBackend) Shutdown() {
 	b.ctxCancel()
 	b.e.Close()
 	<-b.e.Done()
+	b.awaitNoGoroutinesInTest()
+}
+
+func (b *LocalBackend) awaitNoGoroutinesInTest() {
+	if !testenv.InTest() {
+		return
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
+	defer cancel()
+
+	ch := make(chan bool, 1)
+	defer b.goTracker.AddDoneCallback(func() { ch <- true })()
+
+	for {
+		n := b.goTracker.RunningGoroutines()
+		if n == 0 {
+			return
+		}
+		select {
+		case <-ctx.Done():
+			// TODO(bradfitz): pass down some TB-like failer interface from
+			// tests, without depending on testing from here?
+			// But this is fine in tests too:
+			panic(fmt.Sprintf("timeout waiting for %d goroutines to stop", n))
+		case <-ch:
+		}
+	}
 }
 
 func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView {
@@ -2152,7 +2184,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
 
 	if b.portpoll != nil {
 		b.portpollOnce.Do(func() {
-			go b.readPoller()
+			b.goTracker.Go(b.readPoller)
 		})
 	}
 
@@ -2366,7 +2398,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
 	b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf))
 
 	if b.sshServer != nil {
-		go b.sshServer.OnPolicyChange()
+		b.goTracker.Go(b.sshServer.OnPolicyChange)
 	}
 }
 
@@ -2843,7 +2875,7 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A
 	// request every 2 seconds.
 	// TODO(bradfitz): plumb this further and only send a Notify on change.
 	if mask&ipn.NotifyWatchEngineUpdates != 0 {
-		go b.pollRequestEngineStatus(ctx)
+		b.goTracker.Go(func() { b.pollRequestEngineStatus(ctx) })
 	}
 
 	// TODO(marwan-at-work): streaming background logs?
@@ -3850,7 +3882,7 @@ func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlock
 	if mp.EggSet {
 		mp.EggSet = false
 		b.egg = true
-		go b.doSetHostinfoFilterServices()
+		b.goTracker.Go(b.doSetHostinfoFilterServices)
 	}
 	p0 := b.pm.CurrentPrefs()
 	p1 := b.pm.CurrentPrefs().AsStruct()
@@ -3943,7 +3975,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce)
 
 	if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() {
 		if b.sshServer != nil {
-			go b.sshServer.Shutdown()
+			b.goTracker.Go(b.sshServer.Shutdown)
 			b.sshServer = nil
 		}
 	}
@@ -4285,8 +4317,14 @@ func (b *LocalBackend) authReconfig() {
 	dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS())
 	// If the current node is an app connector, ensure the app connector machine is started
 	b.reconfigAppConnectorLocked(nm, prefs)
+	closing := b.shutdownCalled
 	b.mu.Unlock()
 
+	if closing {
+		b.logf("[v1] authReconfig: skipping because in shutdown")
+		return
+	}
+
 	if blocked {
 		b.logf("[v1] authReconfig: blocked, skipping.")
 		return
@@ -4751,7 +4789,7 @@ func (b *LocalBackend) initPeerAPIListener() {
 		b.peerAPIListeners = append(b.peerAPIListeners, pln)
 	}
 
-	go b.doSetHostinfoFilterServices()
+	b.goTracker.Go(b.doSetHostinfoFilterServices)
 }
 
 // magicDNSRootDomains returns the subset of nm.DNS.Domains that are the search domains for MagicDNS.
@@ -5020,7 +5058,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
 		// can be shut down if we transition away from Running.
 		if b.captiveCancel == nil {
 			b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx)
-			go b.checkCaptivePortalLoop(b.captiveCtx)
+			b.goTracker.Go(func() { b.checkCaptivePortalLoop(b.captiveCtx) })
 		}
 	} else if oldState == ipn.Running {
 		// Transitioning away from running.
@@ -5272,7 +5310,7 @@ func (b *LocalBackend) requestEngineStatusAndWait() {
 	b.statusLock.Lock()
 	defer b.statusLock.Unlock()
 
-	go b.e.RequestStatus()
+	b.goTracker.Go(b.e.RequestStatus)
 	b.logf("requestEngineStatusAndWait: waiting...")
 	b.statusChanged.Wait() // temporarily releases lock while waiting
 	b.logf("requestEngineStatusAndWait: got status update.")
@@ -5383,7 +5421,7 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) {
 	shouldRun := !nm.HasCap(tailcfg.NodeAttrDisableWebClient)
 	wasRunning := b.webClientAtomicBool.Swap(shouldRun)
 	if wasRunning && !shouldRun {
-		go b.webClientShutdown() // stop web client
+		b.goTracker.Go(b.webClientShutdown) // stop web client
 	}
 }
 
@@ -5900,7 +5938,7 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn.
 	if wire := b.wantIngressLocked(); b.hostinfo != nil && b.hostinfo.WireIngress != wire {
 		b.logf("Hostinfo.WireIngress changed to %v", wire)
 		b.hostinfo.WireIngress = wire
-		go b.doSetHostinfoFilterServices()
+		b.goTracker.Go(b.doSetHostinfoFilterServices)
 	}
 
 	b.setTCPPortsIntercepted(handlePorts)

+ 1 - 1
util/goroutines/goroutines.go

@@ -1,7 +1,7 @@
 // Copyright (c) Tailscale Inc & AUTHORS
 // SPDX-License-Identifier: BSD-3-Clause
 
-// The goroutines package contains utilities for getting active goroutines.
+// The goroutines package contains utilities for tracking and getting active goroutines.
 package goroutines
 
 import (

+ 66 - 0
util/goroutines/tracker.go

@@ -0,0 +1,66 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package goroutines
+
+import (
+	"sync"
+	"sync/atomic"
+
+	"tailscale.com/util/set"
+)
+
+// Tracker tracks a set of goroutines.
+type Tracker struct {
+	started atomic.Int64 // counter
+	running atomic.Int64 // gauge
+
+	mu     sync.Mutex
+	onDone set.HandleSet[func()]
+}
+
+func (t *Tracker) Go(f func()) {
+	t.started.Add(1)
+	t.running.Add(1)
+	go t.goAndDecr(f)
+}
+
+func (t *Tracker) goAndDecr(f func()) {
+	defer t.decr()
+	f()
+}
+
+func (t *Tracker) decr() {
+	t.running.Add(-1)
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	for _, f := range t.onDone {
+		go f()
+	}
+}
+
+// AddDoneCallback adds a callback to be called in a new goroutine
+// whenever a goroutine managed by t (excluding the untracked goroutines
+// that run these callbacks) finishes. It returns a function to remove the callback.
+func (t *Tracker) AddDoneCallback(f func()) (remove func()) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.onDone == nil {
+		t.onDone = set.HandleSet[func()]{}
+	}
+	h := t.onDone.Add(f)
+	return func() {
+		t.mu.Lock()
+		defer t.mu.Unlock()
+		delete(t.onDone, h)
+	}
+}
+
+func (t *Tracker) RunningGoroutines() int64 {
+	return t.running.Load()
+}
+
+func (t *Tracker) StartedGoroutines() int64 {
+	return t.started.Load()
+}