Browse Source

wgengine/magicsock: refactor maybeRebindOnError

Remove the platform specificity, it is unnecessary complexity.
Deduplicate repeated code as a result of reduced complexity.
Split out error identification code.
Update call-sites and tests.

Updates #14551
Updates tailscale/corp#25648

Signed-off-by: James Tucker <[email protected]>
James Tucker 1 year ago
parent
commit
2c07f5dfcd

+ 22 - 5
wgengine/magicsock/magicsock.go

@@ -364,9 +364,9 @@ type Conn struct {
 	// wireguard state by its public key. If nil, it's not used.
 	getPeerByKey func(key.NodePublic) (_ wgint.Peer, ok bool)
 
-	// lastEPERMRebind tracks the last time a rebind was performed
-	// after experiencing a syscall.EPERM.
-	lastEPERMRebind syncs.AtomicValue[time.Time]
+	// lastErrRebind tracks the last time a rebind was performed after
+	// experiencing a write error, and is used to throttle the rate of rebinds.
+	lastErrRebind syncs.AtomicValue[time.Time]
 
 	// staticEndpoints are user set endpoints that this node should
 	// advertise amongst its wireguard endpoints. It is user's
@@ -1258,7 +1258,7 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err
 			c.logf("magicsock: %s", errGSO.Error())
 			err = errGSO.RetryErr
 		} else {
-			_ = c.maybeRebindOnError(runtime.GOOS, err)
+			c.maybeRebindOnError(err)
 		}
 	}
 	return err == nil, err
@@ -1273,7 +1273,7 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, e
 	sent, err = c.sendUDPStd(ipp, b)
 	if err != nil {
 		metricSendUDPError.Add(1)
-		_ = c.maybeRebindOnError(runtime.GOOS, err)
+		c.maybeRebindOnError(err)
 	} else {
 		if sent && !isDisco {
 			switch {
@@ -1289,6 +1289,23 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, e
 	return
 }
 
+// maybeRebindOnError performs a rebind and restun if the error is one that is
+// known to be healed by a rebind, and the rebind is not throttled.
+func (c *Conn) maybeRebindOnError(err error) {
+	ok, reason := shouldRebind(err)
+	if !ok {
+		return
+	}
+
+	if c.lastErrRebind.Load().Before(time.Now().Add(-5 * time.Second)) {
+		c.logf("magicsock: performing rebind due to %q", reason)
+		c.Rebind()
+		go c.ReSTUN(reason)
+	} else {
+		c.logf("magicsock: not performing %q rebind due to throttle", reason)
+	}
+}
+
 // sendUDPNetcheck sends b via UDP to addr. It is used exclusively by netcheck.
 // It returns the number of bytes sent along with any error encountered. It
 // returns errors.ErrUnsupported if the client is explicitly configured to only

+ 15 - 33
wgengine/magicsock/magicsock_notplan9.go

@@ -8,42 +8,24 @@ package magicsock
 import (
 	"errors"
 	"syscall"
-	"time"
 )
 
-// maybeRebindOnError performs a rebind and restun if the error is defined and
-// any conditionals are met.
-func (c *Conn) maybeRebindOnError(os string, err error) bool {
+// shouldRebind returns if the error is one that is known to be healed by a
+// rebind, and if so also returns a resason string for the rebind.
+func shouldRebind(err error) (ok bool, reason string) {
 	switch {
-	case errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN):
-		// EPIPE/ENOTCONN are common errors when a send fails due to a closed
-		// socket. There is some platform and version inconsistency in which
-		// error is returned, but the meaning is the same.
-		why := "broken-pipe-rebind"
-		c.logf("magicsock: performing %q", why)
-		c.Rebind()
-		go c.ReSTUN(why)
-		return true
+	// EPIPE/ENOTCONN are common errors when a send fails due to a closed
+	// socket. There is some platform and version inconsistency in which
+	// error is returned, but the meaning is the same.
+	case errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ENOTCONN):
+		return true, "broken-pipe"
+
+	// EPERM is typically caused by EDR software, and has been observed to be
+	// transient, it seems that some versions of some EDR lose track of sockets
+	// at times, and return EPERM, but reconnects will establish appropriate
+	// rights associated with a new socket.
 	case errors.Is(err, syscall.EPERM):
-		why := "operation-not-permitted-rebind"
-		switch os {
-		// We currently will only rebind and restun on a syscall.EPERM if it is experienced
-		// on a client running darwin.
-		// TODO(charlotte, raggi): expand os options if required.
-		case "darwin":
-			// TODO(charlotte): implement a backoff, so we don't end up in a rebind loop for persistent
-			// EPERMs.
-			if c.lastEPERMRebind.Load().Before(time.Now().Add(-5 * time.Second)) {
-				c.logf("magicsock: performing %q", why)
-				c.lastEPERMRebind.Store(time.Now())
-				c.Rebind()
-				go c.ReSTUN(why)
-				return true
-			}
-		default:
-			c.logf("magicsock: not performing %q", why)
-			return false
-		}
+		return true, "operation-not-permitted"
 	}
-	return false
+	return false, ""
 }

+ 4 - 4
wgengine/magicsock/magicsock_plan9.go

@@ -5,8 +5,8 @@
 
 package magicsock
 
-// maybeRebindOnError performs a rebind and restun if the error is defined and
-// any conditionals are met.
-func (c *Conn) maybeRebindOnError(os string, err error) bool {
-	return false
+// shouldRebind returns if the error is one that is known to be healed by a
+// rebind, and if so also returns a resason string for the rebind.
+func shouldRebind(err error) (ok bool, reason string) {
+	return false, ""
 }

+ 56 - 25
wgengine/magicsock/magicsock_test.go

@@ -3050,37 +3050,68 @@ func TestMaybeSetNearestDERP(t *testing.T) {
 	}
 }
 
+func TestShouldRebind(t *testing.T) {
+	tests := []struct {
+		err    error
+		ok     bool
+		reason string
+	}{
+		{nil, false, ""},
+		{io.EOF, false, ""},
+		{io.ErrUnexpectedEOF, false, ""},
+		{io.ErrShortBuffer, false, ""},
+		{&net.OpError{Err: syscall.EPERM}, true, "operation-not-permitted"},
+		{&net.OpError{Err: syscall.EPIPE}, true, "broken-pipe"},
+		{&net.OpError{Err: syscall.ENOTCONN}, true, "broken-pipe"},
+	}
+	for _, tt := range tests {
+		t.Run(fmt.Sprintf("%s-%v", tt.err, tt.ok), func(t *testing.T) {
+			if got, reason := shouldRebind(tt.err); got != tt.ok || reason != tt.reason {
+				t.Errorf("errShouldRebind(%v) = %v, %q; want %v, %q", tt.err, got, reason, tt.ok, tt.reason)
+			}
+		})
+	}
+}
+
 func TestMaybeRebindOnError(t *testing.T) {
 	tstest.PanicOnLog()
 	tstest.ResourceCheck(t)
 
-	err := fmt.Errorf("outer err: %w", syscall.EPERM)
-
-	t.Run("darwin-rebind", func(t *testing.T) {
-		conn := newTestConn(t)
-		defer conn.Close()
-		rebound := conn.maybeRebindOnError("darwin", err)
-		if !rebound {
-			t.Errorf("darwin should rebind on syscall.EPERM")
-		}
-	})
-
-	t.Run("linux-not-rebind", func(t *testing.T) {
-		conn := newTestConn(t)
-		defer conn.Close()
-		rebound := conn.maybeRebindOnError("linux", err)
-		if rebound {
-			t.Errorf("linux should not rebind on syscall.EPERM")
-		}
-	})
+	var rebindErrs []error
+	if runtime.GOOS != "plan9" {
+		rebindErrs = append(rebindErrs,
+			&net.OpError{Err: syscall.EPERM},
+			&net.OpError{Err: syscall.EPIPE},
+			&net.OpError{Err: syscall.ENOTCONN},
+		)
+	}
+
+	for _, rebindErr := range rebindErrs {
+		t.Run(fmt.Sprintf("rebind-%s", rebindErr), func(t *testing.T) {
+			conn := newTestConn(t)
+			defer conn.Close()
+
+			before := metricRebindCalls.Value()
+			conn.maybeRebindOnError(rebindErr)
+			after := metricRebindCalls.Value()
+			if before+1 != after {
+				t.Errorf("should rebind on %#v", rebindErr)
+			}
+		})
+	}
 
 	t.Run("no-frequent-rebind", func(t *testing.T) {
-		conn := newTestConn(t)
-		defer conn.Close()
-		conn.lastEPERMRebind.Store(time.Now().Add(-1 * time.Second))
-		rebound := conn.maybeRebindOnError("darwin", err)
-		if rebound {
-			t.Errorf("darwin should not rebind on syscall.EPERM within 5 seconds of last")
+		if runtime.GOOS != "plan9" {
+			err := fmt.Errorf("outer err: %w", syscall.EPERM)
+			conn := newTestConn(t)
+			defer conn.Close()
+			conn.lastErrRebind.Store(time.Now().Add(-1 * time.Second))
+			before := metricRebindCalls.Value()
+			conn.maybeRebindOnError(err)
+			after := metricRebindCalls.Value()
+			if before != after {
+				t.Errorf("should not rebind within 5 seconds of last")
+			}
 		}
 	})
 }