Jelajahi Sumber

net/udprelay: replace map+sync.Mutex with sync.Map for VNI lookup

This commit also introduces a sync.Mutex for guarding mutable fields
on serverEndpoint, now that they are no longer guarded by the sync.Mutex
in Server.

These changes reduce lock contention and, as a result, increase aggregate
throughput under high flow count load. A benchmark on Linux with AWS
c8gn instances showed a ~30% increase in aggregate throughput (37Gb/s
vs 28Gb/s) for 12 tailscaled flows.

Updates tailscale/corp#35264

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited 2 bulan lalu
induk
melakukan
a663639bea
2 mengubah file dengan 93 tambahan dan 82 penghapusan
  1. 81 69
      net/udprelay/server.go
  2. 12 13
      net/udprelay/server_test.go

+ 81 - 69
net/udprelay/server.go

@@ -77,8 +77,8 @@ type Server struct {
 	closeCh             chan struct{}
 	netChecker          *netcheck.Client
 
-	mu                  sync.Mutex           // guards the following fields
-	macSecrets          [][blake2s.Size]byte // [0] is most recent, max 2 elements
+	mu                  sync.Mutex                      // guards the following fields
+	macSecrets          views.Slice[[blake2s.Size]byte] // [0] is most recent, max 2 elements
 	macSecretRotatedAt  mono.Time
 	derpMap             *tailcfg.DERPMap
 	onlyStaticAddrPorts bool                        // no dynamic addr port discovery when set
@@ -87,8 +87,11 @@ type Server struct {
 	closed              bool
 	lamportID           uint64
 	nextVNI             uint32
-	byVNI               map[uint32]*serverEndpoint
-	byDisco             map[key.SortedPairOfDiscoPublic]*serverEndpoint
+	// serverEndpointByVNI is consistent with serverEndpointByDisco while mu is
+	// held, i.e. mu must be held around write ops. Read ops in performance
+	// sensitive paths, e.g. packet forwarding, do not need to acquire mu.
+	serverEndpointByVNI   sync.Map // key is uint32 (Geneve VNI), value is [*serverEndpoint]
+	serverEndpointByDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
 }
 
 const macSecretRotationInterval = time.Minute * 2
@@ -100,23 +103,23 @@ const (
 )
 
 // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
-// serverEndpoint methods are not thread-safe.
 type serverEndpoint struct {
 	// discoPubKeys contains the key.DiscoPublic of the served clients. The
 	// indexing of this array aligns with the following fields, e.g.
 	// discoSharedSecrets[0] is the shared secret to use when sealing
 	// Disco protocol messages for transmission towards discoPubKeys[0].
-	discoPubKeys         key.SortedPairOfDiscoPublic
-	discoSharedSecrets   [2]key.DiscoShared
+	discoPubKeys       key.SortedPairOfDiscoPublic
+	discoSharedSecrets [2]key.DiscoShared
+	lamportID          uint64
+	vni                uint32
+	allocatedAt        mono.Time
+
+	mu                   sync.Mutex        // guards the following fields
 	inProgressGeneration [2]uint32         // or zero if a handshake has never started, or has just completed
 	boundAddrPorts       [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg
 	lastSeen             [2]mono.Time
 	packetsRx            [2]uint64 // num packets received from/sent by each client after they are bound
 	bytesRx              [2]uint64 // num bytes received from/sent by each client after they are bound
-
-	lamportID   uint64
-	vni         uint32
-	allocatedAt mono.Time
 }
 
 func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) {
@@ -141,7 +144,10 @@ func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg di
 	return out, nil
 }
 
-func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) {
+func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
 	if senderIndex != 0 && senderIndex != 1 {
 		return nil, netip.AddrPort{}
 	}
@@ -186,7 +192,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
 		}
 		reply = append(reply, disco.Magic...)
 		reply = serverDisco.AppendTo(reply)
-		mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon)
+		mac, err := blakeMACFromBindMsg(macSecrets.At(0), from, m.BindUDPRelayEndpointCommon)
 		if err != nil {
 			return nil, netip.AddrPort{}
 		}
@@ -206,7 +212,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
 			// silently drop
 			return nil, netip.AddrPort{}
 		}
-		for _, macSecret := range macSecrets {
+		for _, macSecret := range macSecrets.All() {
 			mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon)
 			if err != nil {
 				// silently drop
@@ -230,7 +236,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
 	}
 }
 
-func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) {
+func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
 	senderRaw, isDiscoMsg := disco.Source(b)
 	if !isDiscoMsg {
 		// Not a Disco message
@@ -265,7 +271,9 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by
 }
 
 func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mono.Time) (write []byte, to netip.AddrPort) {
-	if !e.isBound() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if !e.isBoundLocked() {
 		// not a control packet, but serverEndpoint isn't bound
 		return nil, netip.AddrPort{}
 	}
@@ -287,7 +295,9 @@ func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mon
 }
 
 func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
-	if !e.isBound() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if !e.isBoundLocked() {
 		if now.Sub(e.allocatedAt) > bindLifetime {
 			return true
 		}
@@ -299,9 +309,9 @@ func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifet
 	return false
 }
 
-// isBound returns true if both clients have completed a 3-way handshake,
+// isBoundLocked returns true if both clients have completed a 3-way handshake,
 // otherwise false.
-func (e *serverEndpoint) isBound() bool {
+func (e *serverEndpoint) isBoundLocked() bool {
 	return e.boundAddrPorts[0].IsValid() &&
 		e.boundAddrPorts[1].IsValid()
 }
@@ -313,15 +323,14 @@ func (e *serverEndpoint) isBound() bool {
 // used.
 func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) {
 	s = &Server{
-		logf:                logf,
-		disco:               key.NewDisco(),
-		bindLifetime:        defaultBindLifetime,
-		steadyStateLifetime: defaultSteadyStateLifetime,
-		closeCh:             make(chan struct{}),
-		onlyStaticAddrPorts: onlyStaticAddrPorts,
-		byDisco:             make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
-		nextVNI:             minVNI,
-		byVNI:               make(map[uint32]*serverEndpoint),
+		logf:                  logf,
+		disco:                 key.NewDisco(),
+		bindLifetime:          defaultBindLifetime,
+		steadyStateLifetime:   defaultSteadyStateLifetime,
+		closeCh:               make(chan struct{}),
+		onlyStaticAddrPorts:   onlyStaticAddrPorts,
+		serverEndpointByDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
+		nextVNI:               minVNI,
 	}
 	s.discoPublic = s.disco.Public()
 
@@ -640,8 +649,8 @@ func (s *Server) Close() error {
 		// acquire s.mu.
 		s.mu.Lock()
 		defer s.mu.Unlock()
-		clear(s.byVNI)
-		clear(s.byDisco)
+		s.serverEndpointByVNI.Clear()
+		clear(s.serverEndpointByDisco)
 		s.closed = true
 		s.bus.Close()
 	})
@@ -659,10 +668,10 @@ func (s *Server) endpointGCLoop() {
 		// holding s.mu for the duration. Keep it simple (and slow) for now.
 		s.mu.Lock()
 		defer s.mu.Unlock()
-		for k, v := range s.byDisco {
+		for k, v := range s.serverEndpointByDisco {
 			if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
-				delete(s.byDisco, k)
-				delete(s.byVNI, v.vni)
+				delete(s.serverEndpointByDisco, k)
+				s.serverEndpointByVNI.Delete(v.vni)
 			}
 		}
 	}
@@ -690,12 +699,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
 	if err != nil {
 		return nil, netip.AddrPort{}
 	}
-	// TODO: consider performance implications of holding s.mu for the remainder
-	// of this method, which does a bunch of disco/crypto work depending. Keep
-	// it simple (and slow) for now.
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	e, ok := s.byVNI[gh.VNI.Get()]
+	e, ok := s.serverEndpointByVNI.Load(gh.VNI.Get())
 	if !ok {
 		// unknown VNI
 		return nil, netip.AddrPort{}
@@ -708,27 +712,36 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
 			return nil, netip.AddrPort{}
 		}
 		msg := b[packet.GeneveFixedHeaderLength:]
-		s.maybeRotateMACSecretLocked(now)
-		return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets, now)
+		secrets := s.getMACSecrets(now)
+		return e.(*serverEndpoint).handleSealedDiscoControlMsg(from, msg, s.discoPublic, secrets, now)
 	}
-	return e.handleDataPacket(from, b, now)
+	return e.(*serverEndpoint).handleDataPacket(from, b, now)
+}
+
+func (s *Server) getMACSecrets(now mono.Time) views.Slice[[blake2s.Size]byte] {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.maybeRotateMACSecretLocked(now)
+	return s.macSecrets
 }
 
 func (s *Server) maybeRotateMACSecretLocked(now mono.Time) {
 	if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval {
 		return
 	}
-	switch len(s.macSecrets) {
+	secrets := s.macSecrets.AsSlice()
+	switch len(secrets) {
 	case 0:
-		s.macSecrets = make([][blake2s.Size]byte, 1, 2)
+		secrets = make([][blake2s.Size]byte, 1, 2)
 	case 1:
-		s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{})
+		secrets = append(secrets, [blake2s.Size]byte{})
 		fallthrough
 	case 2:
-		s.macSecrets[1] = s.macSecrets[0]
+		secrets[1] = secrets[0]
 	}
-	rand.Read(s.macSecrets[0][:])
+	rand.Read(secrets[0][:])
 	s.macSecretRotatedAt = now
+	s.macSecrets = views.SliceOf(secrets)
 	return
 }
 
@@ -838,7 +851,7 @@ func (s *Server) getNextVNILocked() (uint32, error) {
 		} else {
 			s.nextVNI++
 		}
-		_, ok := s.byVNI[vni]
+		_, ok := s.serverEndpointByVNI.Load(vni)
 		if !ok {
 			return vni, nil
 		}
@@ -877,7 +890,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 	}
 
 	pair := key.NewSortedPairOfDiscoPublic(discoA, discoB)
-	e, ok := s.byDisco[pair]
+	e, ok := s.serverEndpointByDisco[pair]
 	if ok {
 		// Return the existing allocation. Clients can resolve duplicate
 		// [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID].
@@ -915,8 +928,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 	e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0])
 	e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1])
 
-	s.byDisco[pair] = e
-	s.byVNI[e.vni] = e
+	s.serverEndpointByDisco[pair] = e
+	s.serverEndpointByVNI.Store(e.vni, e)
 
 	s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString())
 	return endpoint.ServerEndpoint{
@@ -930,19 +943,19 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 	}, nil
 }
 
-// extractClientInfo constructs a [status.ClientInfo] for one of the two peer
-// relay clients involved in this session.
-func extractClientInfo(idx int, ep *serverEndpoint) status.ClientInfo {
-	if idx != 0 && idx != 1 {
-		panic(fmt.Sprintf("idx passed to extractClientInfo() must be 0 or 1; got %d", idx))
-	}
-
-	return status.ClientInfo{
-		Endpoint:   ep.boundAddrPorts[idx],
-		ShortDisco: ep.discoPubKeys.Get()[idx].ShortString(),
-		PacketsTx:  ep.packetsRx[idx],
-		BytesTx:    ep.bytesRx[idx],
+// extractClientInfo constructs a [status.ClientInfo] for both relay clients
+// involved in this session.
+func (e *serverEndpoint) extractClientInfo() [2]status.ClientInfo {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	ret := [2]status.ClientInfo{}
+	for i := range e.boundAddrPorts {
+		ret[i].Endpoint = e.boundAddrPorts[i]
+		ret[i].ShortDisco = e.discoPubKeys.Get()[i].ShortString()
+		ret[i].PacketsTx = e.packetsRx[i]
+		ret[i].BytesTx = e.bytesRx[i]
 	}
+	return ret
 }
 
 // GetSessions returns a slice of peer relay session statuses, with each
@@ -955,14 +968,13 @@ func (s *Server) GetSessions() []status.ServerSession {
 	if s.closed {
 		return nil
 	}
-	var sessions = make([]status.ServerSession, 0, len(s.byDisco))
-	for _, se := range s.byDisco {
-		c1 := extractClientInfo(0, se)
-		c2 := extractClientInfo(1, se)
+	var sessions = make([]status.ServerSession, 0, len(s.serverEndpointByDisco))
+	for _, se := range s.serverEndpointByDisco {
+		clientInfos := se.extractClientInfo()
 		sessions = append(sessions, status.ServerSession{
 			VNI:     se.vni,
-			Client1: c1,
-			Client2: c2,
+			Client1: clientInfos[0],
+			Client2: clientInfos[1],
 		})
 	}
 	return sessions

+ 12 - 13
net/udprelay/server_test.go

@@ -339,19 +339,18 @@ func TestServer_getNextVNILocked(t *testing.T) {
 	c := qt.New(t)
 	s := &Server{
 		nextVNI: minVNI,
-		byVNI:   make(map[uint32]*serverEndpoint),
 	}
 	for i := uint64(0); i < uint64(totalPossibleVNI); i++ {
 		vni, err := s.getNextVNILocked()
 		if err != nil { // using quicktest here triples test time
 			t.Fatal(err)
 		}
-		s.byVNI[vni] = nil
+		s.serverEndpointByVNI.Store(vni, nil)
 	}
 	c.Assert(s.nextVNI, qt.Equals, minVNI)
 	_, err := s.getNextVNILocked()
 	c.Assert(err, qt.IsNotNil)
-	delete(s.byVNI, minVNI)
+	s.serverEndpointByVNI.Delete(minVNI)
 	_, err = s.getNextVNILocked()
 	c.Assert(err, qt.IsNil)
 }
@@ -455,17 +454,17 @@ func TestServer_maybeRotateMACSecretLocked(t *testing.T) {
 	s := &Server{}
 	start := mono.Now()
 	s.maybeRotateMACSecretLocked(start)
-	qt.Assert(t, len(s.macSecrets), qt.Equals, 1)
-	macSecret := s.macSecrets[0]
+	qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
+	macSecret := s.macSecrets.At(0)
 	s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond))
-	qt.Assert(t, len(s.macSecrets), qt.Equals, 1)
-	qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret)
+	qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
+	qt.Assert(t, s.macSecrets.At(0), qt.Equals, macSecret)
 	s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval))
-	qt.Assert(t, len(s.macSecrets), qt.Equals, 2)
-	qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret)
-	qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1])
+	qt.Assert(t, s.macSecrets.Len(), qt.Equals, 2)
+	qt.Assert(t, s.macSecrets.At(1), qt.Equals, macSecret)
+	qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
 	s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval))
-	qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0])
-	qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1])
-	qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1])
+	qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(0))
+	qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(1))
+	qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
 }