Преглед изворни кода

derp/derphttp: add error notify for RunWatchConnectionLoop (#16261)

The caller of client.RunWatchConnectionLoop may need to be
aware of errors that occur within loop. Add a channel
that notifies of errors to the caller to allow for
decisions to be make as to the state of the client.

Updates tailscale/corp#25756

Signed-off-by: Mike O'Driscoll <[email protected]>
Mike O'Driscoll пре 8 месеци
родитељ
комит
e7f5c9a015
3 измењених фајлова са 84 додато и 7 уклоњено
  1. 2 1
      cmd/derper/mesh.go
  2. 69 4
      derp/derphttp/derphttp_test.go
  3. 13 2
      derp/derphttp/mesh_client.go

+ 2 - 1
cmd/derper/mesh.go

@@ -72,6 +72,7 @@ func startMeshWithHost(s *derp.Server, hostTuple string) error {
 
 	add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) }
 	remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) }
-	go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove)
+	notifyError := func(err error) {}
+	go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove, notifyError)
 	return nil
 }

+ 69 - 4
derp/derphttp/derphttp_test.go

@@ -11,12 +11,14 @@ import (
 	"net"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"sync"
 	"testing"
 	"time"
 
 	"tailscale.com/derp"
 	"tailscale.com/net/netmon"
+	"tailscale.com/net/netx"
 	"tailscale.com/types/key"
 )
 
@@ -298,6 +300,7 @@ func TestBreakWatcherConnRecv(t *testing.T) {
 	defer cancel()
 
 	watcherChan := make(chan int, 1)
+	errChan := make(chan error, 1)
 
 	// Start the watcher thread (which connects to the watched server)
 	wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343
@@ -311,8 +314,11 @@ func TestBreakWatcherConnRecv(t *testing.T) {
 			watcherChan <- peers
 		}
 		remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- }
+		notifyErr := func(err error) {
+			errChan <- err
+		}
 
-		watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove)
+		watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyErr)
 	}()
 
 	timer := time.NewTimer(5 * time.Second)
@@ -326,6 +332,10 @@ func TestBreakWatcherConnRecv(t *testing.T) {
 			if peers != 1 {
 				t.Fatal("wrong number of peers added during watcher connection")
 			}
+		case err := <-errChan:
+			if !strings.Contains(err.Error(), "use of closed network connection") {
+				t.Fatalf("expected notifyError connection error to contain 'use of closed network connection', got %v", err)
+			}
 		case <-timer.C:
 			t.Fatalf("watcher did not process the peer update")
 		}
@@ -369,6 +379,7 @@ func TestBreakWatcherConn(t *testing.T) {
 
 	watcherChan := make(chan int, 1)
 	breakerChan := make(chan bool, 1)
+	errorChan := make(chan error, 1)
 
 	// Start the watcher thread (which connects to the watched server)
 	wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343
@@ -384,8 +395,11 @@ func TestBreakWatcherConn(t *testing.T) {
 			<-breakerChan
 		}
 		remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- }
+		notifyError := func(err error) {
+			errorChan <- err
+		}
 
-		watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove)
+		watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyError)
 	}()
 
 	timer := time.NewTimer(5 * time.Second)
@@ -399,6 +413,10 @@ func TestBreakWatcherConn(t *testing.T) {
 			if peers != 1 {
 				t.Fatal("wrong number of peers added during watcher connection")
 			}
+		case err := <-errorChan:
+			if !strings.Contains(err.Error(), "use of closed network connection") {
+				t.Fatalf("expected notifyError connection error to contain 'use of closed network connection', got %v", err)
+			}
 		case <-timer.C:
 			t.Fatalf("watcher did not process the peer update")
 		}
@@ -414,6 +432,7 @@ func TestBreakWatcherConn(t *testing.T) {
 
 func noopAdd(derp.PeerPresentMessage) {}
 func noopRemove(derp.PeerGoneMessage) {}
+func noopNotifyError(error)           {}
 
 func TestRunWatchConnectionLoopServeConnect(t *testing.T) {
 	defer func() { testHookWatchLookConnectResult = nil }()
@@ -441,7 +460,7 @@ func TestRunWatchConnectionLoopServeConnect(t *testing.T) {
 		}
 		return false
 	}
-	watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove)
+	watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove, noopNotifyError)
 
 	// Test connecting to the server with a zero value for ignoreServerKey,
 	// so we should always connect.
@@ -455,7 +474,7 @@ func TestRunWatchConnectionLoopServeConnect(t *testing.T) {
 		}
 		return false
 	}
-	watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove)
+	watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove, noopNotifyError)
 }
 
 // verify that the LocalAddr method doesn't acquire the mutex.
@@ -491,3 +510,49 @@ func TestProbe(t *testing.T) {
 		}
 	}
 }
+
+func TestNotifyError(t *testing.T) {
+	defer func() { testHookWatchLookConnectResult = nil }()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+
+	priv := key.NewNode()
+	serverURL, s := newTestServer(t, priv)
+	defer s.Close()
+
+	pub := priv.Public()
+
+	// Test early error notification when c.connect fails.
+	watcher := newWatcherClient(t, priv, serverURL)
+	watcher.SetURLDialer(netx.DialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
+		t.Helper()
+		return nil, fmt.Errorf("test error: %s", addr)
+	}))
+	defer watcher.Close()
+
+	testHookWatchLookConnectResult = func(err error, wasSelfConnect bool) bool {
+		t.Helper()
+		if err == nil {
+			t.Fatal("expected error connecting to server, got nil")
+		}
+		if wasSelfConnect {
+			t.Error("wanted normal connect; got self connect")
+		}
+		return false
+	}
+
+	errChan := make(chan error, 1)
+	notifyError := func(err error) {
+		errChan <- err
+	}
+	watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove, notifyError)
+
+	select {
+	case err := <-errChan:
+		if !strings.Contains(err.Error(), "test") {
+			t.Errorf("expected test error, got %v", err)
+		}
+	case <-ctx.Done():
+		t.Fatalf("context done before receiving error: %v", ctx.Err())
+	}
+}

+ 13 - 2
derp/derphttp/mesh_client.go

@@ -31,6 +31,9 @@ var testHookWatchLookConnectResult func(connectError error, wasSelfConnect bool)
 // This behavior will likely change. Callers should do their own accounting
 // and dup suppression as needed.
 //
+// If set the notifyError func is called with any error that occurs within the ctx
+// main loop connection setup, or the inner loop receiving messages via RecvDetail.
+//
 // infoLogf, if non-nil, is the logger to write periodic status updates about
 // how many peers are on the server. Error log output is set to the c's logger,
 // regardless of infoLogf's value.
@@ -42,10 +45,11 @@ var testHookWatchLookConnectResult func(connectError error, wasSelfConnect bool)
 // initialized Client.WatchConnectionChanges to true.
 //
 // If the DERP connection breaks and reconnects, remove will be called for all
-// previously seen peers, with Reason type PeerGoneReasonSynthetic. Those
+// previously seen peers, with Reason type PeerGoneReasonMeshConnBroke. Those
 // clients are likely still connected and their add message will appear after
 // reconnect.
-func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(derp.PeerPresentMessage), remove func(derp.PeerGoneMessage)) {
+func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf,
+	add func(derp.PeerPresentMessage), remove func(derp.PeerGoneMessage), notifyError func(error)) {
 	if !c.WatchConnectionChanges {
 		if c.isStarted() {
 			panic("invalid use of RunWatchConnectionLoop on already-started Client without setting Client.RunWatchConnectionLoop")
@@ -121,6 +125,10 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
 		// Make sure we're connected before calling s.ServerPublicKey.
 		_, _, err := c.connect(ctx, "RunWatchConnectionLoop")
 		if err != nil {
+			logf("mesh connect: %v", err)
+			if notifyError != nil {
+				notifyError(err)
+			}
 			if f := testHookWatchLookConnectResult; f != nil && !f(err, false) {
 				return
 			}
@@ -141,6 +149,9 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
 			if err != nil {
 				clear()
 				logf("Recv: %v", err)
+				if notifyError != nil {
+					notifyError(err)
+				}
 				sleep(retryInterval)
 				break
 			}