Browse Source

feature/relayserver: use eventbus.Monitor to simplify lifecycle management (#17234)

Instead of using separate channels to manage the lifecycle of the eventbus
client, use the recently-added eventbus.Monitor, which handles signaling the
processing loop to stop and waiting for it to complete.  This allows us to
simplify some of the setup and cleanup code in the relay server.

Updates #15160

Change-Id: Ia1a47ce2e5a31bc8f546dca4c56c3141a40d67af
Signed-off-by: M. J. Fromberger <[email protected]>
M. J. Fromberger 5 months ago
parent
commit
3c32f87624
2 changed files with 70 additions and 77 deletions
  1. 64 71
      feature/relayserver/relayserver.go
  2. 6 6
      feature/relayserver/relayserver_test.go

+ 64 - 71
feature/relayserver/relayserver.go

@@ -82,11 +82,11 @@ type extension struct {
 	logf logger.Logf
 	logf logger.Logf
 	bus  *eventbus.Bus
 	bus  *eventbus.Bus
 
 
-	mu                            sync.Mutex // guards the following fields
-	shutdown                      bool
+	mu       sync.Mutex // guards the following fields
+	shutdown bool
+
 	port                          *int                             // ipn.Prefs.RelayServerPort, nil if disabled
 	port                          *int                             // ipn.Prefs.RelayServerPort, nil if disabled
-	disconnectFromBusCh           chan struct{}                    // non-nil if consumeEventbusTopics is running, closed to signal it to return
-	busDoneCh                     chan struct{}                    // non-nil if consumeEventbusTopics is running, closed when it returns
+	eventSubs                     *eventbus.Monitor                // nil if not connected to eventbus
 	debugSessionsCh               chan chan []status.ServerSession // non-nil if consumeEventbusTopics is running
 	debugSessionsCh               chan chan []status.ServerSession // non-nil if consumeEventbusTopics is running
 	hasNodeAttrDisableRelayServer bool                             // tailcfg.NodeAttrDisableRelayServer
 	hasNodeAttrDisableRelayServer bool                             // tailcfg.NodeAttrDisableRelayServer
 }
 }
@@ -119,15 +119,13 @@ func (e *extension) handleBusLifetimeLocked() {
 	if !busShouldBeRunning {
 	if !busShouldBeRunning {
 		e.disconnectFromBusLocked()
 		e.disconnectFromBusLocked()
 		return
 		return
-	}
-	if e.busDoneCh != nil {
+	} else if e.eventSubs != nil {
 		return // already running
 		return // already running
 	}
 	}
-	port := *e.port
-	e.disconnectFromBusCh = make(chan struct{})
-	e.busDoneCh = make(chan struct{})
+
+	ec := e.bus.Client("relayserver.extension")
 	e.debugSessionsCh = make(chan chan []status.ServerSession)
 	e.debugSessionsCh = make(chan chan []status.ServerSession)
-	go e.consumeEventbusTopics(port)
+	e.eventSubs = ptr.To(ec.Monitor(e.consumeEventbusTopics(ec, *e.port)))
 }
 }
 
 
 func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) {
 func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) {
@@ -175,77 +173,72 @@ var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) {
 
 
 // consumeEventbusTopics serves endpoint allocation requests over the eventbus.
 // consumeEventbusTopics serves endpoint allocation requests over the eventbus.
 // It also serves [relayServer] debug information on a channel.
 // It also serves [relayServer] debug information on a channel.
-// consumeEventbusTopics must never acquire [extension.mu], which can be held by
-// other goroutines while waiting to receive on [extension.busDoneCh] or the
+// consumeEventbusTopics must never acquire [extension.mu], which can be held
+// by other goroutines while waiting to receive on [extension.eventSubs] or the
 // inner [extension.debugSessionsCh] channel.
 // inner [extension.debugSessionsCh] channel.
-func (e *extension) consumeEventbusTopics(port int) {
-	defer close(e.busDoneCh)
+func (e *extension) consumeEventbusTopics(ec *eventbus.Client, port int) func(*eventbus.Client) {
+	reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](ec)
+	respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](ec)
+	debugSessionsCh := e.debugSessionsCh
 
 
-	eventClient := e.bus.Client("relayserver.extension")
-	reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](eventClient)
-	respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](eventClient)
-	defer eventClient.Close()
-
-	var rs relayServer // lazily initialized
-	defer func() {
-		if rs != nil {
-			rs.Close()
-		}
-	}()
-	for {
-		select {
-		case <-e.disconnectFromBusCh:
-			return
-		case <-eventClient.Done():
-			return
-		case respCh := <-e.debugSessionsCh:
-			if rs == nil {
-				// Don't initialize the server simply for a debug request.
-				respCh <- nil
-				continue
+	return func(ec *eventbus.Client) {
+		var rs relayServer // lazily initialized
+		defer func() {
+			if rs != nil {
+				rs.Close()
 			}
 			}
-			sessions := rs.GetSessions()
-			respCh <- sessions
-		case req := <-reqSub.Events():
-			if rs == nil {
-				var err error
-				rs, err = udprelay.NewServer(e.logf, port, overrideAddrs())
+		}()
+		for {
+			select {
+			case <-ec.Done():
+				return
+			case respCh := <-debugSessionsCh:
+				if rs == nil {
+					// Don't initialize the server simply for a debug request.
+					respCh <- nil
+					continue
+				}
+				sessions := rs.GetSessions()
+				respCh <- sessions
+			case req := <-reqSub.Events():
+				if rs == nil {
+					var err error
+					rs, err = udprelay.NewServer(e.logf, port, overrideAddrs())
+					if err != nil {
+						e.logf("error initializing server: %v", err)
+						continue
+					}
+				}
+				se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1])
 				if err != nil {
 				if err != nil {
-					e.logf("error initializing server: %v", err)
+					e.logf("error allocating endpoint: %v", err)
 					continue
 					continue
 				}
 				}
-			}
-			se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1])
-			if err != nil {
-				e.logf("error allocating endpoint: %v", err)
-				continue
-			}
-			respPub.Publish(magicsock.UDPRelayAllocResp{
-				ReqRxFromNodeKey:  req.RxFromNodeKey,
-				ReqRxFromDiscoKey: req.RxFromDiscoKey,
-				Message: &disco.AllocateUDPRelayEndpointResponse{
-					Generation: req.Message.Generation,
-					UDPRelayEndpoint: disco.UDPRelayEndpoint{
-						ServerDisco:         se.ServerDisco,
-						ClientDisco:         se.ClientDisco,
-						LamportID:           se.LamportID,
-						VNI:                 se.VNI,
-						BindLifetime:        se.BindLifetime.Duration,
-						SteadyStateLifetime: se.SteadyStateLifetime.Duration,
-						AddrPorts:           se.AddrPorts,
+				respPub.Publish(magicsock.UDPRelayAllocResp{
+					ReqRxFromNodeKey:  req.RxFromNodeKey,
+					ReqRxFromDiscoKey: req.RxFromDiscoKey,
+					Message: &disco.AllocateUDPRelayEndpointResponse{
+						Generation: req.Message.Generation,
+						UDPRelayEndpoint: disco.UDPRelayEndpoint{
+							ServerDisco:         se.ServerDisco,
+							ClientDisco:         se.ClientDisco,
+							LamportID:           se.LamportID,
+							VNI:                 se.VNI,
+							BindLifetime:        se.BindLifetime.Duration,
+							SteadyStateLifetime: se.SteadyStateLifetime.Duration,
+							AddrPorts:           se.AddrPorts,
+						},
 					},
 					},
-				},
-			})
+				})
+			}
 		}
 		}
 	}
 	}
 }
 }
 
 
 func (e *extension) disconnectFromBusLocked() {
 func (e *extension) disconnectFromBusLocked() {
-	if e.busDoneCh != nil {
-		close(e.disconnectFromBusCh)
-		<-e.busDoneCh
-		e.busDoneCh = nil
-		e.disconnectFromBusCh = nil
+	if e.eventSubs != nil {
+		e.eventSubs.Close()
+		e.eventSubs = nil
 		e.debugSessionsCh = nil
 		e.debugSessionsCh = nil
 	}
 	}
 }
 }
@@ -270,7 +263,7 @@ func (e *extension) serverStatus() status.ServerStatus {
 		UDPPort:  nil,
 		UDPPort:  nil,
 		Sessions: nil,
 		Sessions: nil,
 	}
 	}
-	if e.port == nil || e.busDoneCh == nil {
+	if e.port == nil || e.eventSubs == nil {
 		return st
 		return st
 	}
 	}
 	st.UDPPort = ptr.To(*e.port)
 	st.UDPPort = ptr.To(*e.port)
@@ -281,7 +274,7 @@ func (e *extension) serverStatus() status.ServerStatus {
 		resp := <-ch
 		resp := <-ch
 		st.Sessions = resp
 		st.Sessions = resp
 		return st
 		return st
-	case <-e.busDoneCh:
+	case <-e.eventSubs.Done():
 		return st
 		return st
 	}
 	}
 }
 }

+ 6 - 6
feature/relayserver/relayserver_test.go

@@ -101,8 +101,8 @@ func Test_extension_profileStateChanged(t *testing.T) {
 			}
 			}
 			defer e.disconnectFromBusLocked()
 			defer e.disconnectFromBusLocked()
 			e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode)
 			e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode)
-			if tt.wantBusRunning != (e.busDoneCh != nil) {
-				t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil)
+			if tt.wantBusRunning != (e.eventSubs != nil) {
+				t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil)
 			}
 			}
 			if (tt.wantPort == nil) != (e.port == nil) {
 			if (tt.wantPort == nil) != (e.port == nil) {
 				t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil)
 				t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil)
@@ -118,7 +118,7 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) {
 		name                          string
 		name                          string
 		shutdown                      bool
 		shutdown                      bool
 		port                          *int
 		port                          *int
-		busDoneCh                     chan struct{}
+		eventSubs                     *eventbus.Monitor
 		hasNodeAttrDisableRelayServer bool
 		hasNodeAttrDisableRelayServer bool
 		wantBusRunning                bool
 		wantBusRunning                bool
 	}{
 	}{
@@ -157,13 +157,13 @@ func Test_extension_handleBusLifetimeLocked(t *testing.T) {
 				bus:                           eventbus.New(),
 				bus:                           eventbus.New(),
 				shutdown:                      tt.shutdown,
 				shutdown:                      tt.shutdown,
 				port:                          tt.port,
 				port:                          tt.port,
-				busDoneCh:                     tt.busDoneCh,
+				eventSubs:                     tt.eventSubs,
 				hasNodeAttrDisableRelayServer: tt.hasNodeAttrDisableRelayServer,
 				hasNodeAttrDisableRelayServer: tt.hasNodeAttrDisableRelayServer,
 			}
 			}
 			e.handleBusLifetimeLocked()
 			e.handleBusLifetimeLocked()
 			defer e.disconnectFromBusLocked()
 			defer e.disconnectFromBusLocked()
-			if tt.wantBusRunning != (e.busDoneCh != nil) {
-				t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil)
+			if tt.wantBusRunning != (e.eventSubs != nil) {
+				t.Errorf("wantBusRunning: %v != (e.eventSubs != nil): %v", tt.wantBusRunning, e.eventSubs != nil)
 			}
 			}
 		})
 		})
 	}
 	}