Browse Source

control/controlclient: use ctx passed down to NoiseClient.getConn

Without this, the client would just get stuck dialing even if the
context was canceled.

Updates tailscale/corp#12590

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 2 years ago
parent
commit
9d1a3a995c
1 changed files with 49 additions and 6 deletions
  1. 49 6
      control/controlclient/noise.go

+ 49 - 6
control/controlclient/noise.go

@@ -287,6 +287,25 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round
 	return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
 }
 
+// contextErr is an error that wraps another error and is used to indicate that
+// the error was because a context expired.
+type contextErr struct {
+	err error
+}
+
+func (e contextErr) Error() string {
+	return e.err.Error()
+}
+
+func (e contextErr) Unwrap() error {
+	return e.err
+}
+
+// getConn returns a noiseConn that can be used to make requests to the
+// coordination server. It may return a cached connection or create a new one.
+// Dials are singleflighted, so concurrent calls to getConn may only dial once.
+// As such, context values may not be respected as there are no guarantees that
+// the context passed to getConn is the same as the context passed to dial.
 func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
 	nc.mu.Lock()
 	if last := nc.last; last != nil && last.canTakeNewRequest() {
@@ -295,11 +314,35 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
 	}
 	nc.mu.Unlock()
 
-	conn, err, _ := nc.sfDial.Do(struct{}{}, nc.dial)
-	if err != nil {
-		return nil, err
+	for {
+		// We singeflight the dial to avoid making multiple connections, however
+		// that means that we can't simply cancel the dial if the context is
+		// canceled. Instead, we have to additionally check that the context
+		// which was canceled is our context and retry if our context is still
+		// valid.
+		conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) {
+			c, err := nc.dial(ctx)
+			if err != nil {
+				if ctx.Err() != nil {
+					return nil, contextErr{ctx.Err()}
+				}
+				return nil, err
+			}
+			return c, nil
+		})
+		var ce contextErr
+		if err == nil || !errors.As(err, &ce) {
+			return conn, err
+		}
+		if ctx.Err() == nil {
+			// The dial failed because of a context error, but our context
+			// is still valid. Retry.
+			continue
+		}
+		// The dial failed because our context was canceled. Return the
+		// underlying error.
+		return nil, ce.Unwrap()
 	}
-	return conn, nil
 }
 
 func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -344,7 +387,7 @@ func (nc *NoiseClient) Close() error {
 
 // dial opens a new connection to tailcontrol, fetching the server noise key
 // if not cached.
-func (nc *NoiseClient) dial() (*noiseConn, error) {
+func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
 	nc.mu.Lock()
 	connID := nc.nextID
 	nc.nextID++
@@ -392,7 +435,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) {
 	}
 
 	timeout := time.Duration(timeoutSec * float64(time.Second))
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	ctx, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
 
 	clientConn, err := (&controlhttp.Dialer{