Browse code

wgengine/magicsock: re-shape relayManager to use an event loop (#15935)

The event loop removes the need for growing locking complexities and
synchronization. Now we simply use channels. The event loop only runs
while there is active work to do.

relayManager remains no-op inside magicsock for the time being.
endpoints are never 'relayCapable' and therefore endpoint & Conn will
not feed CallMeMaybeVia or allocation events into it.

A number of relayManager events remain unimplemented, e.g.
CallMeMaybeVia reception and relay handshaking.

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited, 10 months ago
parent
commit
0f4f808e70

+ 1 - 0
cmd/k8s-operator/depaware.txt

@@ -872,6 +872,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
         tailscale.com/net/tsdial                                     from tailscale.com/control/controlclient+
      💣 tailscale.com/net/tshttpproxy                                from tailscale.com/clientupdate/distsign+
         tailscale.com/net/tstun                                      from tailscale.com/tsd+
+        tailscale.com/net/udprelay/endpoint                          from tailscale.com/wgengine/magicsock
         tailscale.com/omit                                           from tailscale.com/ipn/conffile
         tailscale.com/paths                                          from tailscale.com/client/local+
      💣 tailscale.com/portlist                                       from tailscale.com/ipn/ipnlocal

+ 1 - 0
tsnet/depaware.txt

@@ -303,6 +303,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware)
         tailscale.com/net/tsdial                                     from tailscale.com/control/controlclient+
      💣 tailscale.com/net/tshttpproxy                                from tailscale.com/clientupdate/distsign+
         tailscale.com/net/tstun                                      from tailscale.com/tsd+
+        tailscale.com/net/udprelay/endpoint                          from tailscale.com/wgengine/magicsock
         tailscale.com/omit                                           from tailscale.com/ipn/conffile
         tailscale.com/paths                                          from tailscale.com/client/local+
      💣 tailscale.com/portlist                                       from tailscale.com/ipn/ipnlocal

+ 9 - 0
wgengine/magicsock/endpoint.go

@@ -95,6 +95,7 @@ type endpoint struct {
 
 	expired         bool // whether the node has expired
 	isWireguardOnly bool // whether the endpoint is WireGuard only
+	relayCapable    bool // whether the node is capable of speaking via a [tailscale.com/net/udprelay.Server]
 }
 
 func (de *endpoint) setBestAddrLocked(v addrQuality) {
@@ -1249,6 +1250,13 @@ func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) {
 		// sent so our firewall ports are probably open and now
 		// would be a good time for them to connect.
 		go de.c.enqueueCallMeMaybe(derpAddr, de)
+
+		// Schedule allocation of relay endpoints. We make no considerations for
+		// current relay endpoints or best UDP path state for now, keep it
+		// simple.
+		if de.relayCapable {
+			go de.c.relayManager.allocateAndHandshakeAllServers(de)
+		}
 	}
 }
 
@@ -1863,6 +1871,7 @@ func (de *endpoint) resetLocked() {
 		}
 	}
 	de.probeUDPLifetime.resetCycleEndpointLocked()
+	de.c.relayManager.cancelOutstandingWork(de)
 }
 
 func (de *endpoint) numStopAndReset() int64 {

+ 7 - 0
wgengine/magicsock/magicsock.go

@@ -1939,6 +1939,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
 			c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString())
 			return
 		}
+		ep.mu.Lock()
+		relayCapable := ep.relayCapable
+		ep.mu.Unlock()
+		if isVia && !relayCapable {
+			c.logf("magicsock: disco: ignoring %s from %v; %v is not known to be relay capable", msgType, sender.ShortString(), sender.ShortString())
+			return
+		}
 		epDisco := ep.disco.Load()
 		if epDisco == nil {
 			return

+ 255 - 20
wgengine/magicsock/relaymanager.go

@@ -4,48 +4,283 @@
 package magicsock
 
 import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
 	"net/netip"
 	"sync"
+	"time"
 
 	"tailscale.com/disco"
+	udprelay "tailscale.com/net/udprelay/endpoint"
 	"tailscale.com/types/key"
+	"tailscale.com/util/httpm"
+	"tailscale.com/util/set"
 )
 
 // relayManager manages allocation and handshaking of
 // [tailscale.com/net/udprelay.Server] endpoints. The zero value is ready for
 // use.
 type relayManager struct {
-	mu                     sync.Mutex // guards the following fields
+	initOnce sync.Once
+
+	// ===================================================================
+	// The following fields are owned by a single goroutine, runLoop().
+	serversByAddrPort   set.Set[netip.AddrPort]
+	allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork
+
+	// ===================================================================
+	// The following chan fields serve event inputs to a single goroutine,
+	// runLoop().
+	allocateHandshakeCh chan *endpoint
+	allocateWorkDoneCh  chan relayEndpointAllocWorkDoneEvent
+	cancelWorkCh        chan *endpoint
+	newServerEndpointCh chan newRelayServerEndpointEvent
+	rxChallengeCh       chan relayHandshakeChallengeEvent
+	rxCallMeMaybeViaCh  chan *disco.CallMeMaybeVia
+
+	discoInfoMu            sync.Mutex // guards the following field
 	discoInfoByServerDisco map[key.DiscoPublic]*discoInfo
+
+	// runLoopStoppedCh is written to by runLoop() upon return, enabling event
+	// writers to restart it when they are blocked (see
+	// relayManagerInputEvent()).
+	runLoopStoppedCh chan struct{}
 }
 
-func (h *relayManager) initLocked() {
-	if h.discoInfoByServerDisco != nil {
-		return
+type newRelayServerEndpointEvent struct {
+	ep *endpoint
+	se udprelay.ServerEndpoint
+}
+
+type relayEndpointAllocWorkDoneEvent struct {
+	ep   *endpoint
+	work *relayEndpointAllocWork
+}
+
+// activeWork returns true if there is outstanding allocation or handshaking
+// work, otherwise it returns false.
+func (r *relayManager) activeWork() bool {
+	return len(r.allocWorkByEndpoint) > 0
+	// TODO(jwhited): consider handshaking work
+}
+
+// runLoop is a form of event loop. It ensures exclusive access to most of
+// [relayManager] state.
+func (r *relayManager) runLoop() {
+	defer func() {
+		r.runLoopStoppedCh <- struct{}{}
+	}()
+
+	for {
+		select {
+		case ep := <-r.allocateHandshakeCh:
+			r.cancelAndClearWork(ep)
+			r.allocateAllServersForEndpoint(ep)
+			if !r.activeWork() {
+				return
+			}
+		case msg := <-r.allocateWorkDoneCh:
+			work, ok := r.allocWorkByEndpoint[msg.ep]
+			if ok && work == msg.work {
+				// Verify the work in the map is the same as the one that we're
+				// cleaning up. New events on r.allocateHandshakeCh can
+				// overwrite pre-existing keys.
+				delete(r.allocWorkByEndpoint, msg.ep)
+			}
+			if !r.activeWork() {
+				return
+			}
+		case ep := <-r.cancelWorkCh:
+			r.cancelAndClearWork(ep)
+			if !r.activeWork() {
+				return
+			}
+		case newEndpoint := <-r.newServerEndpointCh:
+			_ = newEndpoint
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		case challenge := <-r.rxChallengeCh:
+			_ = challenge
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		case via := <-r.rxCallMeMaybeViaCh:
+			_ = via
+			// TODO(jwhited): implement
+			if !r.activeWork() {
+				return
+			}
+		}
 	}
-	h.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo)
+}
+
+type relayHandshakeChallengeEvent struct {
+	challenge [32]byte
+	disco     key.DiscoPublic
+	from      netip.AddrPort
+	vni       uint32
+	at        time.Time
+}
+
+// relayEndpointAllocWork serves to track in-progress relay endpoint allocation
+// for an [*endpoint]. This structure is immutable once initialized.
+type relayEndpointAllocWork struct {
+	// ep is the [*endpoint] associated with the work
+	ep *endpoint
+	// cancel() will signal all associated goroutines to return
+	cancel context.CancelFunc
+	// wg.Wait() will return once all associated goroutines have returned
+	wg *sync.WaitGroup
+}
+
+// init initializes [relayManager] if it is not already initialized.
+func (r *relayManager) init() {
+	r.initOnce.Do(func() {
+		r.discoInfoByServerDisco = make(map[key.DiscoPublic]*discoInfo)
+		r.allocWorkByEndpoint = make(map[*endpoint]*relayEndpointAllocWork)
+		r.allocateHandshakeCh = make(chan *endpoint)
+		r.allocateWorkDoneCh = make(chan relayEndpointAllocWorkDoneEvent)
+		r.cancelWorkCh = make(chan *endpoint)
+		r.newServerEndpointCh = make(chan newRelayServerEndpointEvent)
+		r.rxChallengeCh = make(chan relayHandshakeChallengeEvent)
+		r.rxCallMeMaybeViaCh = make(chan *disco.CallMeMaybeVia)
+		r.runLoopStoppedCh = make(chan struct{}, 1)
+		go r.runLoop()
+	})
 }
 
 // discoInfo returns a [*discoInfo] for 'serverDisco' if there is an
 // active/ongoing handshake with it, otherwise it returns nil, false.
-func (h *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	di, ok := h.discoInfoByServerDisco[serverDisco]
+func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
+	r.discoInfoMu.Lock()
+	defer r.discoInfoMu.Unlock()
+	di, ok := r.discoInfoByServerDisco[serverDisco]
 	return di, ok
 }
 
-func (h *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	// TODO(jwhited): implement
+func (r *relayManager) handleCallMeMaybeVia(dm *disco.CallMeMaybeVia) {
+	relayManagerInputEvent(r, nil, &r.rxCallMeMaybeViaCh, dm)
+}
+
+func (r *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
+	relayManagerInputEvent(r, nil, &r.rxChallengeCh, relayHandshakeChallengeEvent{challenge: dm.Challenge, disco: di.discoKey, from: src, vni: vni, at: time.Now()})
+}
+
+// relayManagerInputEvent initializes [relayManager] if necessary, starts
+// relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'.
+//
+// [relayManager] initialization will make `*eventCh`, so it must be passed as
+// a pointer to a channel.
+//
+// 'ctx' can be used for returning when runLoop is waiting for the caller to
+// return, i.e. the calling goroutine was birthed by runLoop and is cancelable
+// via 'ctx'. 'ctx' may be nil.
+func relayManagerInputEvent[T any](r *relayManager, ctx context.Context, eventCh *chan T, event T) {
+	r.init()
+	var ctxDoneCh <-chan struct{}
+	if ctx != nil {
+		ctxDoneCh = ctx.Done()
+	}
+	for {
+		select {
+		case <-ctxDoneCh:
+			return
+		case *eventCh <- event:
+			return
+		case <-r.runLoopStoppedCh:
+			go r.runLoop()
+		}
+	}
+}
+
+// allocateAndHandshakeAllServers kicks off allocation and handshaking of relay
+// endpoints for 'ep' on all known relay servers, canceling any existing
+// in-progress work.
+func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.allocateHandshakeCh, ep)
+}
+
+// cancelOutstandingWork cancels all outstanding allocation & handshaking work
+// for 'ep'.
+func (r *relayManager) cancelOutstandingWork(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep)
 }
 
-func (h *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	h.initLocked()
-	// TODO(jwhited): implement
+// cancelAndClearWork cancels & clears any outstanding work for 'ep'.
+func (r *relayManager) cancelAndClearWork(ep *endpoint) {
+	allocWork, ok := r.allocWorkByEndpoint[ep]
+	if ok {
+		allocWork.cancel()
+		allocWork.wg.Wait()
+		delete(r.allocWorkByEndpoint, ep)
+	}
+	// TODO(jwhited): cancel & clear handshake work
+}
+
+func (r *relayManager) allocateAllServersForEndpoint(ep *endpoint) {
+	if len(r.serversByAddrPort) == 0 {
+		return
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	started := &relayEndpointAllocWork{ep: ep, cancel: cancel, wg: &sync.WaitGroup{}}
+	for k := range r.serversByAddrPort {
+		started.wg.Add(1)
+		go r.allocateEndpoint(ctx, started.wg, k, ep)
+	}
+	r.allocWorkByEndpoint[ep] = started
+	go func() {
+		started.wg.Wait()
+		started.cancel()
+		relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{ep: ep, work: started})
+	}()
+}
+
+func (r *relayManager) allocateEndpoint(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) {
+	// TODO(jwhited): introduce client metrics counters for notable failures
+	defer wg.Done()
+	var b bytes.Buffer
+	remoteDisco := ep.disco.Load()
+	if remoteDisco == nil {
+		return
+	}
+	type allocateRelayEndpointReq struct {
+		DiscoKeys []key.DiscoPublic
+	}
+	a := &allocateRelayEndpointReq{
+		DiscoKeys: []key.DiscoPublic{ep.c.discoPublic, remoteDisco.key},
+	}
+	err := json.NewEncoder(&b).Encode(a)
+	if err != nil {
+		return
+	}
+	const reqTimeout = time.Second * 10
+	reqCtx, cancel := context.WithTimeout(ctx, reqTimeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(reqCtx, httpm.POST, "http://"+server.String()+"/relay/endpoint", &b)
+	if err != nil {
+		return
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return
+	}
+	var se udprelay.ServerEndpoint
+	err = json.NewDecoder(io.LimitReader(resp.Body, 4096)).Decode(&se)
+	if err != nil {
+		return
+	}
+	relayManagerInputEvent(r, ctx, &r.newServerEndpointCh, newRelayServerEndpointEvent{
+		ep: ep,
+		se: se,
+	})
 }

+ 29 - 0
wgengine/magicsock/relaymanager_test.go

@@ -0,0 +1,29 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+	"net/netip"
+	"testing"
+
+	"tailscale.com/disco"
+)
+
+func TestRelayManagerInitAndIdle(t *testing.T) {
+	rm := relayManager{}
+	rm.allocateAndHandshakeAllServers(&endpoint{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.cancelOutstandingWork(&endpoint{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.handleCallMeMaybeVia(&disco.CallMeMaybeVia{})
+	<-rm.runLoopStoppedCh
+
+	rm = relayManager{}
+	rm.handleBindUDPRelayEndpointChallenge(&disco.BindUDPRelayEndpointChallenge{}, &discoInfo{}, netip.AddrPort{}, 0)
+	<-rm.runLoopStoppedCh
+}