Przeglądaj źródła

control/controlhttp: start port 443 fallback sooner if 80's stuck

Fixes #4544

Change-Id: I39877e71915ad48c6668351c45cd8e33e2f5dbae
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 lat temu
rodzic
commit
e38d3dfc76
1 zmienionych plików z 82 dodań i 31 usunięć
  1. 82 31
      control/controlhttp/client.go

+ 82 - 31
control/controlhttp/client.go

@@ -30,6 +30,7 @@ import (
 	"net/http"
 	"net/http/httptrace"
 	"net/url"
+	"time"
 
 	"tailscale.com/control/controlbase"
 	"tailscale.com/net/dnscache"
@@ -98,48 +99,98 @@ type dialParams struct {
 }
 
 func (a *dialParams) dial() (*controlbase.Conn, error) {
-	init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
-	if err != nil {
-		return nil, err
-	}
+	// Create one shared context used by both port 80 and port 443 dials.
+	// If port 80 is still in flight when 443 returns, this deferred cancel
+	// will stop the port 80 dial.
+	ctx, cancel := context.WithCancel(a.ctx)
+	defer cancel()
 
-	u := &url.URL{
+	// u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
+	// respectively, in order to do the HTTP upgrade to a net.Conn over which
+	// we'll speak Noise.
+	u80 := &url.URL{
 		Scheme: "http",
 		Host:   net.JoinHostPort(a.host, a.httpPort),
 		Path:   serverUpgradePath,
 	}
-	conn, httpErr := a.tryURL(u, init)
-	if httpErr == nil {
-		ret, err := cont(a.ctx, conn)
-		if err != nil {
-			conn.Close()
-			return nil, err
-		}
-		return ret, nil
+	u443 := &url.URL{
+		Scheme: "https",
+		Host:   net.JoinHostPort(a.host, a.httpsPort),
+		Path:   serverUpgradePath,
 	}
-
-	// Connecting over plain HTTP failed, assume it's an HTTP proxy
-	// being difficult and see if we can get through over HTTPS.
-	u.Scheme = "https"
-	u.Host = net.JoinHostPort(a.host, a.httpsPort)
-	init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
-	if err != nil {
-		return nil, err
+	type tryURLRes struct {
+		u    *url.URL
+		conn net.Conn
+		cont controlbase.HandshakeContinuation
+		err  error
 	}
-	conn, tlsErr := a.tryURL(u, init)
-	if tlsErr == nil {
-		ret, err := cont(a.ctx, conn)
-		if err != nil {
-			conn.Close()
-			return nil, err
+	ch := make(chan tryURLRes) // must be unbuffered
+
+	try := func(u *url.URL) {
+		res := tryURLRes{u: u}
+		var init []byte
+		init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
+		if res.err == nil {
+			res.conn, res.err = a.tryURL(ctx, u, init)
+		}
+		select {
+		case ch <- res:
+		case <-ctx.Done():
+			if res.conn != nil {
+				res.conn.Close()
+			}
 		}
-		return ret, nil
 	}
 
-	return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", httpErr, tlsErr)
+	// Start the plaintext HTTP attempt first.
+	go try(u80)
+
+	// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
+	// to dial port 443 if port 80 doesn't either succeed or fail quickly.
+	try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) })
+	defer try443Timer.Stop()
+
+	var err80, err443 error
+	for {
+		select {
+		case <-ctx.Done():
+			return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
+		case res := <-ch:
+			if res.err == nil {
+				ret, err := res.cont(ctx, res.conn)
+				if err != nil {
+					res.conn.Close()
+					return nil, err
+				}
+				return ret, nil
+			}
+			switch res.u {
+			case u80:
+				// Connecting over plain HTTP failed; assume it's an HTTP proxy
+				// being difficult and see if we can get through over HTTPS.
+				err80 = res.err
+				// Stop the fallback timer and run it immediately. We don't use
+				// Timer.Reset(0) here because on AfterFuncs, that can run it
+				// again.
+				if try443Timer.Stop() {
+					go try(u443)
+				} // else we lost the race and it started already which is what we want
+			case u443:
+				err443 = res.err
+			default:
+				panic("invalid")
+			}
+			if err80 != nil && err443 != nil {
+				return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", err80, err443)
+			}
+		}
+	}
 }
 
-func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) {
+// tryURL connects to u, and tries to upgrade it to a net.Conn.
+//
+// Only the provided ctx is used, not a.ctx.
+func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
 	dns := &dnscache.Resolver{
 		Forward:          dnscache.Get().Forward,
 		LookupIPFallback: dnsfallback.Lookup,
@@ -189,7 +240,7 @@ func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) {
 			connCh <- info.Conn
 		},
 	}
-	ctx := httptrace.WithClientTrace(a.ctx, &trace)
+	ctx = httptrace.WithClientTrace(ctx, &trace)
 	req := &http.Request{
 		Method: "POST",
 		URL:    u,