Browse Source

control/controlhttp: don't assume port 80 upgrade response will work

Just because we get an HTTP upgrade response over port 80, don't
assume we'll be able to do bi-di Noise over it. There might be a MITM
corp proxy or anti-virus/firewall interfering. Do a bit more work to
validate the connection before proceeding to give up on the TLS port
443 dial.

Updates #4557 (probably fixes)

Change-Id: I0e1bcc195af21ad3d360ffe79daead730dfd86f1
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 years ago
parent
commit
1237000efe
2 changed files with 106 additions and 48 deletions
  1. 43 28
      control/controlhttp/client.go
  2. 63 20
      control/controlhttp/http_test.go

+ 43 - 28
control/controlhttp/client.go

@@ -70,7 +70,6 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
 		return nil, err
 	}
 	a := &dialParams{
-		ctx:        ctx,
 		host:       host,
 		httpPort:   port,
 		httpsPort:  "443",
@@ -80,11 +79,10 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
 		proxyFunc:  tshttpproxy.ProxyFromEnvironment,
 		dialer:     dialer,
 	}
-	return a.dial()
+	return a.dial(ctx)
 }
 
 type dialParams struct {
-	ctx        context.Context
 	host       string
 	httpPort   string
 	httpsPort  string
@@ -95,14 +93,24 @@ type dialParams struct {
 	dialer     dnscache.DialContextFunc
 
 	// For tests only
-	insecureTLS bool
+	insecureTLS       bool
+	testFallbackDelay time.Duration
 }
 
-func (a *dialParams) dial() (*controlbase.Conn, error) {
+// httpsFallbackDelay is how long we'll wait for a.httpPort to work before
+// starting to try a.httpsPort.
+func (a *dialParams) httpsFallbackDelay() time.Duration {
+	if v := a.testFallbackDelay; v != 0 {
+		return v
+	}
+	return 500 * time.Millisecond
+}
+
+func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
 	// Create one shared context used by both port 80 and port 443 dials.
 	// If port 80 is still in flight when 443 returns, this deferred cancel
 	// will stop the port 80 dial.
-	ctx, cancel := context.WithCancel(a.ctx)
+	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	// u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
@@ -118,26 +126,20 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
 		Host:   net.JoinHostPort(a.host, a.httpsPort),
 		Path:   serverUpgradePath,
 	}
+
 	type tryURLRes struct {
-		u    *url.URL
-		conn net.Conn
-		cont controlbase.HandshakeContinuation
+		u    *url.URL          // input (the URL conn+err are for/from)
+		conn *controlbase.Conn // result (mutually exclusive with err)
 		err  error
 	}
 	ch := make(chan tryURLRes) // must be unbuffered
-
 	try := func(u *url.URL) {
-		res := tryURLRes{u: u}
-		var init []byte
-		init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
-		if res.err == nil {
-			res.conn, res.err = a.tryURL(ctx, u, init)
-		}
+		cbConn, err := a.dialURL(ctx, u)
 		select {
-		case ch <- res:
+		case ch <- tryURLRes{u, cbConn, err}:
 		case <-ctx.Done():
-			if res.conn != nil {
-				res.conn.Close()
+			if cbConn != nil {
+				cbConn.Close()
 			}
 		}
 	}
@@ -147,7 +149,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
 
 	// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
 	// to dial port 443 if port 80 doesn't either succeed or fail quickly.
-	try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) })
+	try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) })
 	defer try443Timer.Stop()
 
 	var err80, err443 error
@@ -157,12 +159,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
 			return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
 		case res := <-ch:
 			if res.err == nil {
-				ret, err := res.cont(ctx, res.conn)
-				if err != nil {
-					res.conn.Close()
-					return nil, err
-				}
-				return ret, nil
+				return res.conn, nil
 			}
 			switch res.u {
 			case u80:
@@ -187,10 +184,28 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
 	}
 }
 
-// tryURL connects to u, and tries to upgrade it to a net.Conn.
+// dialURL attempts to connect to the given URL.
+func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
+	init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
+	if err != nil {
+		return nil, err
+	}
+	netConn, err := a.tryURLUpgrade(ctx, u, init)
+	if err != nil {
+		return nil, err
+	}
+	cbConn, err := cont(ctx, netConn)
+	if err != nil {
+		netConn.Close()
+		return nil, err
+	}
+	return cbConn, nil
+}
+
+// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
 //
 // Only the provided ctx is used, not a.ctx.
-func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
+func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
 	dns := &dnscache.Resolver{
 		Forward:          dnscache.Get().Forward,
 		LookupIPFallback: dnsfallback.Lookup,

+ 63 - 20
control/controlhttp/http_test.go

@@ -17,6 +17,7 @@ import (
 	"strconv"
 	"sync"
 	"testing"
+	"time"
 
 	"tailscale.com/control/controlbase"
 	"tailscale.com/net/socks5"
@@ -24,16 +25,28 @@ import (
 	"tailscale.com/types/key"
 )
 
+type httpTestParam struct {
+	name  string
+	proxy proxy
+
+	// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
+	// 101 switching protocols.
+	makeHTTPHangAfterUpgrade bool
+}
+
 func TestControlHTTP(t *testing.T) {
-	tests := []struct {
-		name  string
-		proxy proxy
-	}{
+	tests := []httpTestParam{
 		// direct connection
 		{
 			name:  "no_proxy",
 			proxy: nil,
 		},
+		// direct connection but port 80 is MITM'ed and broken
+		{
+			name:                     "port80_broken_mitm",
+			proxy:                    nil,
+			makeHTTPHangAfterUpgrade: true,
+		},
 		// SOCKS5
 		{
 			name:  "socks5",
@@ -97,12 +110,13 @@ func TestControlHTTP(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			testControlHTTP(t, test.proxy)
+			testControlHTTP(t, test)
 		})
 	}
 }
 
-func testControlHTTP(t *testing.T, proxy proxy) {
+func testControlHTTP(t *testing.T, param httpTestParam) {
+	proxy := param.proxy
 	client, server := key.NewMachine(), key.NewMachine()
 
 	const testProtocolVersion = 1
@@ -133,7 +147,11 @@ func testControlHTTP(t *testing.T, proxy proxy) {
 		t.Fatalf("HTTPS listen: %v", err)
 	}
 
-	httpServer := &http.Server{Handler: handler}
+	var httpHandler http.Handler = handler
+	if param.makeHTTPHangAfterUpgrade {
+		httpHandler = http.HandlerFunc(brokenMITMHandler)
+	}
+	httpServer := &http.Server{Handler: httpHandler}
 	go httpServer.Serve(httpLn)
 	defer httpServer.Close()
 
@@ -144,19 +162,24 @@ func testControlHTTP(t *testing.T, proxy proxy) {
 	go httpsServer.ServeTLS(httpsLn, "", "")
 	defer httpsServer.Close()
 
-	//ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-	//defer cancel()
+	ctx := context.Background()
+	const debugTimeout = false
+	if debugTimeout {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+	}
 
 	a := dialParams{
-		ctx:         context.Background(), //ctx,
-		host:        "localhost",
-		httpPort:    strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
-		httpsPort:   strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
-		machineKey:  client,
-		controlKey:  server.Public(),
-		version:     testProtocolVersion,
-		insecureTLS: true,
-		dialer:      new(tsdial.Dialer).SystemDial,
+		host:              "localhost",
+		httpPort:          strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
+		httpsPort:         strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
+		machineKey:        client,
+		controlKey:        server.Public(),
+		version:           testProtocolVersion,
+		insecureTLS:       true,
+		dialer:            new(tsdial.Dialer).SystemDial,
+		testFallbackDelay: 50 * time.Millisecond,
 	}
 
 	if proxy != nil {
@@ -175,7 +198,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
 		}
 	}
 
-	conn, err := a.dial()
+	conn, err := a.dial(ctx)
 	if err != nil {
 		t.Fatalf("dialing controlhttp: %v", err)
 	}
@@ -217,6 +240,7 @@ type proxy interface {
 
 type socksProxy struct {
 	sync.Mutex
+	closed          bool
 	proxy           socks5.Server
 	ln              net.Listener
 	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
@@ -232,7 +256,14 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
 	}
 	s.ln = ln
 	s.clientConnAddrs = map[string]bool{}
-	s.proxy.Logf = t.Logf
+	s.proxy.Logf = func(format string, a ...any) {
+		s.Lock()
+		defer s.Unlock()
+		if s.closed {
+			return
+		}
+		t.Logf(format, a...)
+	}
 	s.proxy.Dialer = s.dialAndRecord
 	go s.proxy.Serve(ln)
 	return fmt.Sprintf("socks5://%s", ln.Addr().String())
@@ -241,6 +272,10 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
 func (s *socksProxy) Close() {
 	s.Lock()
 	defer s.Unlock()
+	if s.closed {
+		return
+	}
+	s.closed = true
 	s.ln.Close()
 }
 
@@ -400,3 +435,11 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
 		Certificates: []tls.Certificate{cert},
 	}
 }
+
+func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Upgrade", upgradeHeaderValue)
+	w.Header().Set("Connection", "upgrade")
+	w.WriteHeader(http.StatusSwitchingProtocols)
+	w.(http.Flusher).Flush()
+	<-r.Context().Done()
+}