Browse Source

wgengine/magicsock: avoid RebindingUDPConn mutex in common read/write case

Change-Id: I209fac567326f2e926bace2582dbc67a8bc94c78
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 years ago
parent
commit
fb82299f5a
1 changed files with 26 additions and 23 deletions
  1. 26 23
      wgengine/magicsock/magicsock.go

+ 26 - 23
wgengine/magicsock/magicsock.go

@@ -2826,7 +2826,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 
 	if debugAlwaysDERP {
 		c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network)
-		ruc.pconn = newBlockForeverConn()
+		ruc.setConnLocked(newBlockForeverConn())
 		return nil
 	}
 
@@ -2860,7 +2860,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 			continue
 		}
 		// Success.
-		ruc.pconn = pconn
+		ruc.setConnLocked(pconn)
 		if network == "udp4" {
 			health.SetUDP4Unbound(false)
 		}
@@ -2871,7 +2871,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 	// Set pconn to a dummy conn whose reads block until closed.
 	// This keeps the receive funcs alive for a future in which
 	// we get a link change and we can try binding again.
-	ruc.pconn = newBlockForeverConn()
+	ruc.setConnLocked(newBlockForeverConn())
 	if network == "udp4" {
 		health.SetUDP4Unbound(true)
 	}
@@ -2974,11 +2974,26 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
 // RebindingUDPConn is a UDP socket that can be re-bound.
 // Unix has no notion of re-binding a socket, so we swap it out for a new one.
 type RebindingUDPConn struct {
-	mu    sync.Mutex
+	// pconnAtomic is the same as pconn, but doesn't require acquiring mu. It's
+	// used for reads/writes and only upon failure do the reads/writes then
+	// check pconn (after acquiring mu) to see if there's been a rebind
+	// meanwhile.
+	// pconn isn't really needed, but makes some of the code simpler
+	// to keep it in a type safe form. TODO(bradfitz): really we should make a generic
+	// atomic.Value. Unfortunately Go 1.19's atomic.Pointer[T] is only for pointers,
+	// not interfaces.
+	pconnAtomic atomic.Value // of nettype.PacketConn
+
+	mu    sync.Mutex // held while changing pconn (and pconnAtomic)
 	pconn nettype.PacketConn
 }
 
-// currentConn returns c's current pconn.
+func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) {
+	c.pconn = p
+	c.pconnAtomic.Store(p)
+}
+
+// currentConn returns c's current pconn, acquiring c.mu in the process.
 func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -2989,7 +3004,7 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
 // It returns the number of bytes copied and the source address.
 func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
 	for {
-		pconn := c.currentConn()
+		pconn := c.pconnAtomic.Load().(nettype.PacketConn)
 		n, addr, err := pconn.ReadFrom(b)
 		if err != nil && pconn != c.currentConn() {
 			continue
@@ -3007,7 +3022,7 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
 // when c's underlying connection is a net.UDPConn.
 func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) {
 	for {
-		pconn := c.currentConn()
+		pconn := c.pconnAtomic.Load().(nettype.PacketConn)
 
 		// Optimization: Treat *net.UDPConn specially.
 		// This lets us avoid allocations by calling ReadFromUDPAddrPort.
@@ -3066,17 +3081,11 @@ func (c *RebindingUDPConn) closeLocked() error {
 
 func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
 	for {
-		c.mu.Lock()
-		pconn := c.pconn
-		c.mu.Unlock()
+		pconn := c.pconnAtomic.Load().(nettype.PacketConn)
 
 		n, err := pconn.WriteTo(b, addr)
 		if err != nil {
-			c.mu.Lock()
-			pconn2 := c.pconn
-			c.mu.Unlock()
-
-			if pconn != pconn2 {
+			if pconn != c.currentConn() {
 				continue
 			}
 		}
@@ -3086,17 +3095,11 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
 
 func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
 	for {
-		c.mu.Lock()
-		pconn := c.pconn
-		c.mu.Unlock()
+		pconn := c.pconnAtomic.Load().(nettype.PacketConn)
 
 		n, err := pconn.WriteToUDPAddrPort(b, addr)
 		if err != nil {
-			c.mu.Lock()
-			pconn2 := c.pconn
-			c.mu.Unlock()
-
-			if pconn != pconn2 {
+			if pconn != c.currentConn() {
 				continue
 			}
 		}