Просмотр исходного кода

net/ace, control/controlhttp: start adding ACE dialing support

Updates tailscale/corp#32227

Change-Id: I38afc668f99eb1d6f7632e82554b82922f3ebb9f
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 месяцев назад
Родитель
Сommit
ecfdd86fc9

+ 1 - 0
cmd/k8s-operator/depaware.txt

@@ -842,6 +842,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
         tailscale.com/logtail/backoff                                from tailscale.com/control/controlclient+
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/net/ace                                        from tailscale.com/control/controlhttp
         tailscale.com/net/bakedroots                                 from tailscale.com/net/tlsdial+
      💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+

+ 21 - 1
cmd/tailscale/cli/debug.go

@@ -35,6 +35,7 @@ import (
 	"tailscale.com/hostinfo"
 	"tailscale.com/internal/noiseconn"
 	"tailscale.com/ipn"
+	"tailscale.com/net/ace"
 	"tailscale.com/net/netmon"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/net/tshttpproxy"
@@ -287,6 +288,7 @@ func debugCmd() *ffcli.Command {
 					fs.StringVar(&ts2021Args.host, "host", "controlplane.tailscale.com", "hostname of control plane")
 					fs.IntVar(&ts2021Args.version, "version", int(tailcfg.CurrentCapabilityVersion), "protocol version")
 					fs.BoolVar(&ts2021Args.verbose, "verbose", false, "be extra verbose")
+					fs.StringVar(&ts2021Args.aceHost, "ace", "", "if non-empty, use this ACE server IP/hostname as a candidate path")
 					return fs
 				})(),
 			},
@@ -964,6 +966,7 @@ var ts2021Args struct {
 	host    string // "controlplane.tailscale.com"
 	version int    // 27 or whatever
 	verbose bool
+	aceHost string // if non-empty, FQDN of https ACE server to use ("ace.example.com")
 }
 
 func runTS2021(ctx context.Context, args []string) error {
@@ -972,6 +975,13 @@ func runTS2021(ctx context.Context, args []string) error {
 
 	keysURL := "https://" + ts2021Args.host + "/key?v=" + strconv.Itoa(ts2021Args.version)
 
+	keyTransport := http.DefaultTransport.(*http.Transport).Clone()
+	if ts2021Args.aceHost != "" {
+		log.Printf("using ACE server %q", ts2021Args.aceHost)
+		keyTransport.Proxy = nil
+		keyTransport.DialContext = (&ace.Dialer{ACEHost: ts2021Args.aceHost}).Dial
+	}
+
 	if ts2021Args.verbose {
 		u, err := url.Parse(keysURL)
 		if err != nil {
@@ -997,7 +1007,7 @@ func runTS2021(ctx context.Context, args []string) error {
 	if err != nil {
 		return err
 	}
-	res, err := http.DefaultClient.Do(req)
+	res, err := keyTransport.RoundTrip(req)
 	if err != nil {
 		log.Printf("Do: %v", err)
 		return err
@@ -1052,6 +1062,16 @@ func runTS2021(ctx context.Context, args []string) error {
 		Logf:            logf,
 		NetMon:          netMon,
 	}
+	if ts2021Args.aceHost != "" {
+		noiseDialer.DialPlan = &tailcfg.ControlDialPlan{
+			Candidates: []tailcfg.ControlIPCandidate{
+				{
+					ACEHost:        ts2021Args.aceHost,
+					DialTimeoutSec: 10,
+				},
+			},
+		}
+	}
 	const tries = 2
 	for i := range tries {
 		err := tryConnect(ctx, keys.PublicKey, noiseDialer)

+ 1 - 0
cmd/tailscale/depaware.txt

@@ -120,6 +120,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/kube/kubetypes                                 from tailscale.com/envknob
         tailscale.com/licenses                                       from tailscale.com/client/web+
         tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/net/ace                                        from tailscale.com/cmd/tailscale/cli+
         tailscale.com/net/bakedroots                                 from tailscale.com/net/tlsdial
         tailscale.com/net/captivedetection                           from tailscale.com/net/netcheck
         tailscale.com/net/dnscache                                   from tailscale.com/control/controlhttp+

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -314,6 +314,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/logtail/backoff                                from tailscale.com/cmd/tailscaled+
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/net/ace                                        from tailscale.com/control/controlhttp
         tailscale.com/net/bakedroots                                 from tailscale.com/net/tlsdial+
      💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock+
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+

+ 1 - 0
cmd/tsidp/depaware.txt

@@ -273,6 +273,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar
         tailscale.com/logtail/backoff                                from tailscale.com/control/controlclient+
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/net/ace                                        from tailscale.com/control/controlhttp
         tailscale.com/net/bakedroots                                 from tailscale.com/ipn/ipnlocal+
      💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+

+ 54 - 24
control/controlhttp/client.go

@@ -20,6 +20,7 @@
 package controlhttp
 
 import (
+	"cmp"
 	"context"
 	"crypto/tls"
 	"encoding/base64"
@@ -41,6 +42,7 @@ import (
 	"tailscale.com/control/controlhttp/controlhttpcommon"
 	"tailscale.com/envknob"
 	"tailscale.com/health"
+	"tailscale.com/net/ace"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/dnsfallback"
 	"tailscale.com/net/netutil"
@@ -104,7 +106,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 	// host we know about.
 	useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
 	if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
-		return a.dialHost(ctx, netip.Addr{})
+		return a.dialHost(ctx)
 	}
 	candidates := a.DialPlan.Candidates
 
@@ -125,10 +127,9 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 
 	// Now, for each candidate, kick off a dial in parallel.
 	type dialResult struct {
-		conn     *ClientConn
-		err      error
-		addr     netip.Addr
-		priority int
+		conn *ClientConn
+		err  error
+		cand tailcfg.ControlIPCandidate
 	}
 	resultsCh := make(chan dialResult, len(candidates))
 
@@ -143,7 +144,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 
 			// Always send results back to our channel.
 			defer func() {
-				resultsCh <- dialResult{conn, err, c.IP, c.Priority}
+				resultsCh <- dialResult{conn, err, c}
 				if pending.Add(-1) == 0 {
 					close(resultsCh)
 				}
@@ -168,9 +169,13 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 			ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
 			defer cancel()
 
+			if c.IP.IsValid() {
+				a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
+			} else if c.ACEHost != "" {
+				a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost)
+			}
 			// This will dial, and the defer above sends it back to our parent.
-			a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
-			conn, err = a.dialHost(ctx, c.IP)
+			conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost)
 		}(ctx, c)
 	}
 
@@ -183,8 +188,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 		// TODO(andrew): we could make this better by keeping track of
 		// the highest remaining priority dynamically, instead of just
 		// checking for the highest total
-		if res.priority == highestPriority && res.conn != nil {
-			a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
+		if res.cand.Priority == highestPriority && res.conn != nil {
+			a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String()))
 
 			// Drain the channel and any existing connections in
 			// the background.
@@ -232,7 +237,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 	sort.Slice(results, func(i, j int) bool {
 		// NOTE: intentionally inverted so that the highest priority
 		// item comes first
-		return results[i].priority > results[j].priority
+		return results[i].cand.Priority > results[j].cand.Priority
 	})
 
 	var (
@@ -245,7 +250,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 			continue
 		}
 
-		a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
+		a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String()))
 		conn = result.conn
 		results[i].conn = nil // so we don't close it in the defer
 		return conn, nil
@@ -259,7 +264,7 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
 
 	// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
 	a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
-	return a.dialHost(ctx, netip.Addr{})
+	return a.dialHost(ctx)
 }
 
 // The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
@@ -316,10 +321,19 @@ var debugNoiseDial = envknob.RegisterBool("TS_DEBUG_NOISE_DIAL")
 
 // dialHost connects to the configured Dialer.Hostname and upgrades the
 // connection into a controlbase.Conn.
+func (a *Dialer) dialHost(ctx context.Context) (*ClientConn, error) {
+	return a.dialHostOpt(ctx,
+		netip.Addr{}, // no pre-resolved IP
+		"",           // don't use ACE
+	)
+}
+
+// dialHostOpt connects to the configured Dialer.Hostname and upgrades the
+// connection into a controlbase.Conn.
 //
 // If optAddr is valid, then no DNS is used and the connection will be made to the
 // provided address.
-func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, error) {
+func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost string) (*ClientConn, error) {
 	// Create one shared context used by both port 80 and port 443 dials.
 	// If port 80 is still in flight when 443 returns, this deferred cancel
 	// will stop the port 80 dial.
@@ -341,7 +355,7 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
 		Host:   net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
 		Path:   serverUpgradePath,
 	}
-	if a.HTTPSPort == NoPort {
+	if a.HTTPSPort == NoPort || optACEHost != "" {
 		u443 = nil
 	}
 
@@ -353,11 +367,11 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
 	ch := make(chan tryURLRes) // must be unbuffered
 	try := func(u *url.URL) {
 		if debugNoiseDial() {
-			a.logf("trying noise dial (%v, %v) ...", u, optAddr)
+			a.logf("trying noise dial (%v, %v) ...", u, cmp.Or(optACEHost, optAddr.String()))
 		}
-		cbConn, err := a.dialURL(ctx, u, optAddr)
+		cbConn, err := a.dialURL(ctx, u, optAddr, optACEHost)
 		if debugNoiseDial() {
-			a.logf("noise dial (%v, %v) = (%v, %v)", u, optAddr, cbConn, err)
+			a.logf("noise dial (%v, %v) = (%v, %v)", u, cmp.Or(optACEHost, optAddr.String()), cbConn, err)
 		}
 		select {
 		case ch <- tryURLRes{u, cbConn, err}:
@@ -423,12 +437,12 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn,
 //
 // If optAddr is valid, then no DNS is used and the connection will be made to the
 // provided address.
-func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr) (*ClientConn, error) {
+func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string) (*ClientConn, error) {
 	init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
 	if err != nil {
 		return nil, err
 	}
-	netConn, err := a.tryURLUpgrade(ctx, u, optAddr, init)
+	netConn, err := a.tryURLUpgrade(ctx, u, optAddr, optACEHost, init)
 	if err != nil {
 		return nil, err
 	}
@@ -480,7 +494,7 @@ var macOSScreenTime = health.Register(&health.Warnable{
 // the provided address.
 //
 // Only the provided ctx is used, not a.ctx.
-func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, init []byte) (_ net.Conn, retErr error) {
+func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string, init []byte) (_ net.Conn, retErr error) {
 	var dns *dnscache.Resolver
 
 	// If we were provided an address to dial, then create a resolver that just
@@ -502,6 +516,14 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
 		dialer = stdDialer.DialContext
 	}
 
+	if optACEHost != "" {
+		dialer = (&ace.Dialer{
+			ACEHost:   optACEHost,
+			ACEHostIP: optAddr, // may be zero
+			NetDialer: dialer,
+		}).Dial
+	}
+
 	// On macOS, see if Screen Time is blocking things.
 	if runtime.GOOS == "darwin" {
 		var proxydIntercepted atomic.Bool // intercepted by macOS webfilterproxyd
@@ -528,9 +550,17 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
 
 	tr := http.DefaultTransport.(*http.Transport).Clone()
 	defer tr.CloseIdleConnections()
-	tr.Proxy = a.getProxyFunc()
-	tshttpproxy.SetTransportGetProxyConnectHeader(tr)
-	tr.DialContext = dnscache.Dialer(dialer, dns)
+	if optACEHost != "" {
+		// If using ACE, we don't want to use any HTTP proxy.
+		// ACE is already a tunnel+proxy.
+		// TODO(tailscale/corp#32483): use system proxy too?
+		tr.Proxy = nil
+		tr.DialContext = dialer
+	} else {
+		tr.Proxy = a.getProxyFunc()
+		tshttpproxy.SetTransportGetProxyConnectHeader(tr)
+		tr.DialContext = dnscache.Dialer(dialer, dns)
+	}
 	// Disable HTTP2, since h2 can't do protocol switching.
 	tr.TLSClientConfig.NextProtos = []string{}
 	tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}

+ 123 - 0
net/ace/ace.go

@@ -0,0 +1,123 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package ace implements a Dialer that dials via a Tailscale ACE (CONNECT)
+// proxy.
+//
+// TODO: document this more, when it's more done. As of 2025-09-17, it's in
+// development.
+package ace
+
+import (
+	"bufio"
+	"cmp"
+	"context"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"net/netip"
+	"sync/atomic"
+)
+
+// Dialer is an HTTP CONNECT proxy dialer to dial the control plane via an ACE
+// proxy.
+type Dialer struct {
+	ACEHost   string
+	ACEHostIP netip.Addr // optional; if non-zero, use this IP instead of DNS
+	ACEPort   int        // zero means 443
+
+	NetDialer func(ctx context.Context, network, address string) (net.Conn, error)
+}
+
+func (d *Dialer) netDialer() func(ctx context.Context, network, address string) (net.Conn, error) {
+	if d.NetDialer != nil {
+		return d.NetDialer
+	}
+	var std net.Dialer
+	return std.DialContext
+}
+
+func (d *Dialer) acePort() int { return cmp.Or(d.ACEPort, 443) }
+
+func (d *Dialer) Dial(ctx context.Context, network, address string) (_ net.Conn, err error) {
+	if network != "tcp" {
+		return nil, errors.New("only TCP is supported")
+	}
+
+	var targetHost string
+	if d.ACEHostIP.IsValid() {
+		targetHost = d.ACEHostIP.String()
+	} else {
+		targetHost = d.ACEHost
+	}
+
+	cc, err := d.netDialer()(ctx, "tcp", net.JoinHostPort(targetHost, fmt.Sprint(d.acePort())))
+	if err != nil {
+		return nil, err
+	}
+
+	// Now that we've dialed, we're about to do three potentially blocking
+	// operations: the TLS handshake, the CONNECT write, and the HTTP response
+	// read. To make our context work over all that, we use a context.AfterFunc
+	// to start a goroutine that'll tear down the underlying connection if the
+	// context expires.
+	//
+	// To prevent races, we use an atomic.Bool to guard access to the underlying
+	// connection being either good or bad. Only one goroutine (the success path
+	// in this goroutine after the ReadResponse or the AfterFunc's failure
+	// goroutine) will compare-and-swap it from false to true.
+	var done atomic.Bool
+	stop := context.AfterFunc(ctx, func() {
+		if done.CompareAndSwap(false, true) {
+			cc.Close()
+		}
+	})
+	defer func() {
+		if err != nil {
+			if ctx.Err() != nil {
+				// Prefer the context error. The other error is likely a side
+				// effect of the context expiring and our tearing down of the
+				// underlying connection, and is thus probably something like
+				// "use of closed network connection", which isn't useful (and
+				// actually misleading) for the caller.
+				err = ctx.Err()
+			}
+			stop()
+			cc.Close()
+		}
+	}()
+
+	tc := tls.Client(cc, &tls.Config{ServerName: d.ACEHost})
+	if err := tc.Handshake(); err != nil {
+		return nil, err
+	}
+
+	// TODO(tailscale/corp#32484): send proxy-auth header
+	if _, err := fmt.Fprintf(tc, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", address, d.ACEHost); err != nil {
+		return nil, err
+	}
+
+	br := bufio.NewReader(tc)
+	connRes, err := http.ReadResponse(br, &http.Request{Method: "CONNECT"})
+	if err != nil {
+		return nil, fmt.Errorf("reading CONNECT response: %w", err)
+	}
+
+	// Now that we're done with blocking operations, mark the connection
+	// as good, to prevent the context's AfterFunc from closing it.
+	if !stop() || !done.CompareAndSwap(false, true) {
+		// We lost a race and the context expired.
+		return nil, ctx.Err()
+	}
+
+	if connRes.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("ACE CONNECT response: %s", connRes.Status)
+	}
+
+	if br.Buffered() > 0 {
+		return nil, fmt.Errorf("unexpected %d bytes of buffered data after ACE CONNECT", br.Buffered())
+	}
+	return tc, nil
+}

+ 8 - 1
tailcfg/tailcfg.go

@@ -2264,7 +2264,14 @@ type ControlDialPlan struct {
 // connecting to the control server.
 type ControlIPCandidate struct {
 	// IP is the address to attempt connecting to.
-	IP netip.Addr
+	IP netip.Addr `json:",omitzero"`
+
+	// ACEHost, if non-empty, means that the client should connect to the
+	// control plane using an HTTPS CONNECT request to the provided hostname. If
+	// the IP field is also set, then the IP is the IP address of the ACEHost
+	// (and not the control plane) and DNS should not be used. The target (the
+	// argument to CONNECT) is always the control plane's hostname, not an IP.
+	ACEHost string `json:",omitempty"`
 
 	// DialStartSec is the number of seconds after the beginning of the
 	// connection process to wait before trying this candidate.

+ 1 - 0
tsnet/depaware.txt

@@ -269,6 +269,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware)
         tailscale.com/logtail/backoff                                from tailscale.com/control/controlclient+
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
+        tailscale.com/net/ace                                        from tailscale.com/control/controlhttp
         tailscale.com/net/bakedroots                                 from tailscale.com/ipn/ipnlocal+
      💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+