Browse Source

control/controlhttp: extract the last network connection

The same context we use for the HTTP request here might be re-used by
the dialer, which could result in `GotConn` being called multiple times.
We only care about the last one.

Fixes #13009

Signed-off-by: Anton Tolchanov <[email protected]>
Anton Tolchanov 1 year ago
parent
commit
7bac5dffcb
2 changed files with 60 additions and 12 deletions
  1. 13 12
      control/controlhttp/client.go
  2. 47 0
      control/controlhttp/http_test.go

+ 13 - 12
control/controlhttp/client.go

@@ -46,6 +46,7 @@ import (
 	"tailscale.com/net/sockstats"
 	"tailscale.com/net/sockstats"
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tshttpproxy"
 	"tailscale.com/net/tshttpproxy"
+	"tailscale.com/syncs"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tstime"
 	"tailscale.com/tstime"
 	"tailscale.com/util/multierr"
 	"tailscale.com/util/multierr"
@@ -497,11 +498,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
 	tr.DisableCompression = true
 	tr.DisableCompression = true
 
 
 	// (mis)use httptrace to extract the underlying net.Conn from the
 	// (mis)use httptrace to extract the underlying net.Conn from the
-	// transport. We make exactly 1 request using this transport, so
-	// there will be exactly 1 GotConn call. Additionally, the
-	// transport handles 101 Switching Protocols correctly, such that
-	// the Conn will not be reused or kept alive by the transport once
-	// the response has been handed back from RoundTrip.
+	// transport. The transport handles 101 Switching Protocols correctly,
+	// such that the Conn will not be reused or kept alive by the transport
+	// once the response has been handed back from RoundTrip.
 	//
 	//
 	// In theory, the machinery of net/http should make it such that
 	// In theory, the machinery of net/http should make it such that
 	// the trace callback happens-before we get the response, but
 	// the trace callback happens-before we get the response, but
@@ -517,10 +516,16 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
 	// unexpected EOFs...), and we're bound to forget someday and
 	// unexpected EOFs...), and we're bound to forget someday and
 	// introduce a protocol optimization at a higher level that starts
 	// introduce a protocol optimization at a higher level that starts
 	// eagerly transmitting from the server.
 	// eagerly transmitting from the server.
-	connCh := make(chan net.Conn, 1)
+	var lastConn syncs.AtomicValue[net.Conn]
 	trace := httptrace.ClientTrace{
 	trace := httptrace.ClientTrace{
+		// Even though we only make a single HTTP request which should
+		// require a single connection, the context (with the attached
+		// trace configuration) might be used by our custom dialer to
+		// make other HTTP requests (e.g. BootstrapDNS). We only care
+		// about the last connection made, which should be the one to
+		// the control server.
 		GotConn: func(info httptrace.GotConnInfo) {
 		GotConn: func(info httptrace.GotConnInfo) {
-			connCh <- info.Conn
+			lastConn.Store(info.Conn)
 		},
 		},
 	}
 	}
 	ctx = httptrace.WithClientTrace(ctx, &trace)
 	ctx = httptrace.WithClientTrace(ctx, &trace)
@@ -548,11 +553,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
 	// is still a read buffer attached to it within resp.Body. So, we
 	// is still a read buffer attached to it within resp.Body. So, we
 	// must direct I/O through resp.Body, but we can still use the
 	// must direct I/O through resp.Body, but we can still use the
 	// underlying net.Conn for stuff like deadlines.
 	// underlying net.Conn for stuff like deadlines.
-	var switchedConn net.Conn
-	select {
-	case switchedConn = <-connCh:
-	default:
-	}
+	switchedConn := lastConn.Load()
 	if switchedConn == nil {
 	if switchedConn == nil {
 		resp.Body.Close()
 		resp.Body.Close()
 		return nil, fmt.Errorf("httptrace didn't provide a connection")
 		return nil, fmt.Errorf("httptrace didn't provide a connection")

+ 47 - 0
control/controlhttp/http_test.go

@@ -11,10 +11,12 @@ import (
 	"log"
 	"log"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"net/http/httptest"
 	"net/http/httputil"
 	"net/http/httputil"
 	"net/netip"
 	"net/netip"
 	"net/url"
 	"net/url"
 	"runtime"
 	"runtime"
+	"slices"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
@@ -41,6 +43,8 @@ type httpTestParam struct {
 	makeHTTPHangAfterUpgrade bool
 	makeHTTPHangAfterUpgrade bool
 
 
 	doEarlyWrite bool
 	doEarlyWrite bool
+
+	httpInDial bool
 }
 }
 
 
 func TestControlHTTP(t *testing.T) {
 func TestControlHTTP(t *testing.T) {
@@ -120,6 +124,12 @@ func TestControlHTTP(t *testing.T) {
 			name:         "early_write",
 			name:         "early_write",
 			doEarlyWrite: true,
 			doEarlyWrite: true,
 		},
 		},
+		// Dialer needed to make another HTTP request along the way (e.g. to
+		// resolve the hostname via BootstrapDNS).
+		{
+			name:       "http_request_in_dial",
+			httpInDial: true,
+		},
 	}
 	}
 
 
 	for _, test := range tests {
 	for _, test := range tests {
@@ -217,6 +227,29 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
 		Clock:                clock,
 		Clock:                clock,
 	}
 	}
 
 
+	if param.httpInDial {
+		// Spin up a separate server to get a different port on localhost.
+		secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return }))
+		defer secondServer.Close()
+
+		prev := a.Dialer
+		a.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
+			ctx, cancel := context.WithTimeout(ctx, time.Second)
+			defer cancel()
+			req, err := http.NewRequestWithContext(ctx, "GET", secondServer.URL, nil)
+			if err != nil {
+				t.Errorf("http.NewRequest: %v", err)
+			}
+			r, err := http.DefaultClient.Do(req)
+			if err != nil {
+				t.Errorf("http.Get: %v", err)
+			}
+			r.Body.Close()
+
+			return prev(ctx, network, addr)
+		}
+	}
+
 	if proxy != nil {
 	if proxy != nil {
 		proxyEnv := proxy.Start(t)
 		proxyEnv := proxy.Start(t)
 		defer proxy.Close()
 		defer proxy.Close()
@@ -238,6 +271,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
 		t.Fatalf("dialing controlhttp: %v", err)
 		t.Fatalf("dialing controlhttp: %v", err)
 	}
 	}
 	defer conn.Close()
 	defer conn.Close()
+
 	si := <-sch
 	si := <-sch
 	if si.conn != nil {
 	if si.conn != nil {
 		defer si.conn.Close()
 		defer si.conn.Close()
@@ -266,6 +300,19 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
 			t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
 			t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
 		}
 		}
 	}
 	}
+
+	// When no proxy is used, the RemoteAddr of the returned connection should match
+	// one of the listeners of the test server.
+	if proxy == nil {
+		var expectedAddrs []string
+		for _, ln := range []net.Listener{httpLn, httpsLn} {
+			expectedAddrs = append(expectedAddrs, fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port))
+			expectedAddrs = append(expectedAddrs, fmt.Sprintf("[::1]:%d", ln.Addr().(*net.TCPAddr).Port))
+		}
+		if !slices.Contains(expectedAddrs, conn.RemoteAddr().String()) {
+			t.Errorf("unexpected remote addr: %s, want %s", conn.RemoteAddr(), expectedAddrs)
+		}
+	}
 }
 }
 
 
 type serverResult struct {
 type serverResult struct {