|
|
@@ -287,6 +287,25 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round
|
|
|
return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
|
|
|
}
|
|
|
|
|
|
+// contextErr is an error that wraps another error and is used to indicate that
|
|
|
+// the error was because a context expired.
|
|
|
+type contextErr struct {
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (e contextErr) Error() string {
|
|
|
+ return e.err.Error()
|
|
|
+}
|
|
|
+
|
|
|
+func (e contextErr) Unwrap() error {
|
|
|
+ return e.err
|
|
|
+}
|
|
|
+
|
|
|
+// getConn returns a noiseConn that can be used to make requests to the
|
|
|
+// coordination server. It may return a cached connection or create a new one.
|
|
|
+// Dials are singleflighted, so concurrent calls to getConn may only dial once.
|
|
|
+// As such, context values may not be respected as there are no guarantees that
|
|
|
+// the context passed to getConn is the same as the context passed to dial.
|
|
|
func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
|
|
|
nc.mu.Lock()
|
|
|
if last := nc.last; last != nil && last.canTakeNewRequest() {
|
|
|
@@ -295,11 +314,35 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
|
|
|
}
|
|
|
nc.mu.Unlock()
|
|
|
|
|
|
- conn, err, _ := nc.sfDial.Do(struct{}{}, nc.dial)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ for {
|
|
|
+ // We singeflight the dial to avoid making multiple connections, however
|
|
|
+ // that means that we can't simply cancel the dial if the context is
|
|
|
+ // canceled. Instead, we have to additionally check that the context
|
|
|
+ // which was canceled is our context and retry if our context is still
|
|
|
+ // valid.
|
|
|
+ conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) {
|
|
|
+ c, err := nc.dial(ctx)
|
|
|
+ if err != nil {
|
|
|
+ if ctx.Err() != nil {
|
|
|
+ return nil, contextErr{ctx.Err()}
|
|
|
+ }
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return c, nil
|
|
|
+ })
|
|
|
+ var ce contextErr
|
|
|
+ if err == nil || !errors.As(err, &ce) {
|
|
|
+ return conn, err
|
|
|
+ }
|
|
|
+ if ctx.Err() == nil {
|
|
|
+ // The dial failed because of a context error, but our context
|
|
|
+ // is still valid. Retry.
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ // The dial failed because our context was canceled. Return the
|
|
|
+ // underlying error.
|
|
|
+ return nil, ce.Unwrap()
|
|
|
}
|
|
|
- return conn, nil
|
|
|
}
|
|
|
|
|
|
func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
@@ -344,7 +387,7 @@ func (nc *NoiseClient) Close() error {
|
|
|
|
|
|
// dial opens a new connection to tailcontrol, fetching the server noise key
|
|
|
// if not cached.
|
|
|
-func (nc *NoiseClient) dial() (*noiseConn, error) {
|
|
|
+func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
|
|
|
nc.mu.Lock()
|
|
|
connID := nc.nextID
|
|
|
nc.nextID++
|
|
|
@@ -392,7 +435,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) {
|
|
|
}
|
|
|
|
|
|
timeout := time.Duration(timeoutSec * float64(time.Second))
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
|
+ ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
|
defer cancel()
|
|
|
|
|
|
clientConn, err := (&controlhttp.Dialer{
|