Просмотр исходного кода

tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field (#5648)

* tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field

This field allows the control server to provide explicit information
about how to connect to it; useful if the client's link status can
change after the initial connection, or if the DNS settings pushed by
the control server break future connections.

Change-Id: I720afe6289ec27d40a41b3dcb310ec45bd7e5f3e
Signed-off-by: Andrew Dunham <[email protected]>
Andrew Dunham 3 лет назад
Родитель
Сommit
e1bdbfe710

+ 1 - 0
cmd/tailscale/depaware.txt

@@ -100,6 +100,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/groupmember                               from tailscale.com/cmd/tailscale/cli
         tailscale.com/util/lineread                                  from tailscale.com/net/interfaces+
         tailscale.com/util/mak                                       from tailscale.com/net/netcheck
+        tailscale.com/util/multierr                                  from tailscale.com/control/controlhttp
         tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
    L    tailscale.com/util/strs                                      from tailscale.com/hostinfo
    W 💣 tailscale.com/util/winutil                                   from tailscale.com/hostinfo+

+ 45 - 1
control/controlclient/direct.go

@@ -76,6 +76,8 @@ type Direct struct {
 	popBrowser             func(url string) // or nil
 	c2nHandler             http.Handler     // or nil
 
+	dialPlan ControlDialPlanner // can be nil
+
 	mu             sync.Mutex        // mutex guards the following fields
 	serverKey      key.MachinePublic // original ("legacy") nacl crypto_box-based public key
 	serverNoiseKey key.MachinePublic
@@ -133,6 +135,34 @@ type Options struct {
 	// MapResponse.PingRequest queries from the control plane.
 	// If nil, PingRequest queries are not answered.
 	Pinger Pinger
+
+	// DialPlan contains and stores a previous dial plan that we received
+	// from the control server; if nil, we fall back to using DNS.
+	//
+	// If we receive a new DialPlan from the server, this value will be
+	// updated.
+	DialPlan ControlDialPlanner
+}
+
+// ControlDialPlanner is the interface optionally supplied when creating a
+// control client to control exactly how TCP connections to the control plane
+// are dialed.
+//
+// It is usually implemented by an atomic.Pointer.
+type ControlDialPlanner interface {
+	// Load returns the current plan for how to connect to control.
+	//
+	// The returned plan can be nil. If so, connections should be made by
+	// resolving the control URL using DNS.
+	Load() *tailcfg.ControlDialPlan
+
+	// Store updates the dial plan with new directions from the control
+	// server.
+	//
+	// The dial plan can span multiple connections to the control server.
+	// That is, a dial plan received when connected over Wi-Fi is still
+	// valid for a subsequent connection over LTE after a network switch.
+	Store(*tailcfg.ControlDialPlan)
 }
 
 // Pinger is the LocalBackend.Ping method.
@@ -216,6 +246,7 @@ func NewDirect(opts Options) (*Direct, error) {
 		popBrowser:             opts.PopBrowserURL,
 		c2nHandler:             opts.C2NHandler,
 		dialer:                 opts.Dialer,
+		dialPlan:               opts.DialPlan,
 	}
 	if opts.Hostinfo == nil {
 		c.SetHostinfo(hostinfo.New())
@@ -915,6 +946,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
 		} else {
 			vlogf("netmap: got new map")
 		}
+		if resp.ControlDialPlan != nil {
+			if c.dialPlan != nil {
+				c.logf("netmap: got new dial plan from control")
+				c.dialPlan.Store(resp.ControlDialPlan)
+			} else {
+				c.logf("netmap: [unexpected] new dial plan; nowhere to store it")
+			}
+		}
 
 		select {
 		case timeoutReset <- struct{}{}:
@@ -1365,12 +1404,17 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
 	if nc != nil {
 		return nc, nil
 	}
+	var dp func() *tailcfg.ControlDialPlan
+	if c.dialPlan != nil {
+		dp = c.dialPlan.Load
+	}
 	nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) {
 		k, err := c.getMachinePrivKey()
 		if err != nil {
 			return nil, err
 		}
-		nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer)
+		c.logf("creating new noise client")
+		nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
 		if err != nil {
 			return nil, err
 		}

+ 50 - 6
control/controlclient/noise.go

@@ -53,6 +53,11 @@ type noiseClient struct {
 	httpPort     string // the default port to call
 	httpsPort    string // the fallback Noise-over-https port
 
+	// dialPlan optionally returns a ControlDialPlan previously received
+	// from the control server; either the function or the return value can
+	// be nil.
+	dialPlan func() *tailcfg.ControlDialPlan
+
 	// mu only protects the following variables.
 	mu       sync.Mutex
 	nextID   int
@@ -61,7 +66,9 @@ type noiseClient struct {
 
 // newNoiseClient returns a new noiseClient for the provided server and machine key.
 // serverURL is of the form https://<host>:<port> (no trailing slash).
-func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) {
+//
+// dialPlan may be nil
+func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) {
 	u, err := url.Parse(serverURL)
 	if err != nil {
 		return nil, err
@@ -89,6 +96,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
 		httpPort:     httpPort,
 		httpsPort:    httpsPort,
 		dialer:       dialer,
+		dialPlan:     dialPlan,
 	}
 
 	// Create the HTTP/2 Transport using a net/http.Transport
@@ -155,16 +163,51 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
 	nc.nextID++
 	nc.mu.Unlock()
 
-	// Timeout is a little arbitrary, but plenty long enough for even the
-	// highest latency links.
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
 	if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
 		// Panic, because a test should have started failing several
 		// thousand version numbers before getting to this point.
 		panic("capability version is too high to fit in the wire protocol")
 	}
+
+	var dialPlan *tailcfg.ControlDialPlan
+	if nc.dialPlan != nil {
+		dialPlan = nc.dialPlan()
+	}
+
+	// If we have a dial plan, then set our timeout as slightly longer than
+	// the maximum amount of time contained therein; we assume that
+	// explicit instructions on timeouts are more useful than a single
+	// hard-coded timeout.
+	//
+	// The default value of 5 is chosen so that, when there's no dial plan,
+	// we retain the previous behaviour of 10 seconds end-to-end timeout.
+	timeoutSec := 5.0
+	if dialPlan != nil {
+		for _, c := range dialPlan.Candidates {
+			if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
+				timeoutSec = v
+			}
+		}
+	}
+
+	// After we establish a connection, we need some time to actually
+	// upgrade it into a Noise connection. With a ballpark worst-case RTT
+	// of 1000ms, give ourselves an extra 5 seconds to complete the
+	// handshake.
+	timeoutSec += 5
+
+	// Be extremely defensive and ensure that the timeout is in the range
+	// [5, 60] seconds (e.g. if we accidentally get a negative number).
+	if timeoutSec > 60 {
+		timeoutSec = 60
+	} else if timeoutSec < 5 {
+		timeoutSec = 5
+	}
+
+	timeout := time.Duration(timeoutSec * float64(time.Second))
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
 	conn, err := (&controlhttp.Dialer{
 		Hostname:        nc.host,
 		HTTPPort:        nc.httpPort,
@@ -173,6 +216,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
 		ControlKey:      nc.serverPubKey,
 		ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
 		Dialer:          nc.dialer.SystemDial,
+		DialPlan:        dialPlan,
 	}).Dial(ctx)
 	if err != nil {
 		return nil, err

+ 192 - 9
control/controlhttp/client.go

@@ -28,18 +28,25 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net"
 	"net/http"
 	"net/http/httptrace"
+	"net/netip"
 	"net/url"
+	"sort"
+	"sync/atomic"
 	"time"
 
 	"tailscale.com/control/controlbase"
+	"tailscale.com/envknob"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/dnsfallback"
 	"tailscale.com/net/netutil"
 	"tailscale.com/net/tlsdial"
 	"tailscale.com/net/tshttpproxy"
+	"tailscale.com/tailcfg"
+	"tailscale.com/util/multierr"
 )
 
 var stdDialer net.Dialer
@@ -82,7 +89,170 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
 	return 500 * time.Millisecond
 }
 
+var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
+
 func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
+	// If we don't have a dial plan, just fall back to dialing the single
+	// host we know about.
+	useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
+	if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
+		return a.dialHost(ctx, netip.Addr{})
+	}
+	candidates := a.DialPlan.Candidates
+
+	// Otherwise, we try dialing per the plan. Store the highest priority
+	// in the list, so that if we get a connection to one of those
+	// candidates we can return quickly.
+	var highestPriority int = math.MinInt
+	for _, c := range candidates {
+		if c.Priority > highestPriority {
+			highestPriority = c.Priority
+		}
+	}
+
+	// This context allows us to cancel in-flight connections if we get a
+	// highest-priority connection before we're all done.
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// Now, for each candidate, kick off a dial in parallel.
+	type dialResult struct {
+		conn     *controlbase.Conn
+		err      error
+		addr     netip.Addr
+		priority int
+	}
+	resultsCh := make(chan dialResult, len(candidates))
+
+	var pending atomic.Int32
+	pending.Store(int32(len(candidates)))
+	for _, c := range candidates {
+		go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
+			var (
+				conn *controlbase.Conn
+				err  error
+			)
+
+			// Always send results back to our channel.
+			defer func() {
+				resultsCh <- dialResult{conn, err, c.IP, c.Priority}
+				if pending.Add(-1) == 0 {
+					close(resultsCh)
+				}
+			}()
+
+			// If non-zero, wait the configured start timeout
+			// before we do anything.
+			if c.DialStartDelaySec > 0 {
+				a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
+				tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
+				defer tmr.Stop()
+				select {
+				case <-ctx.Done():
+					err = ctx.Err()
+					return
+				case <-tmr.C:
+				}
+			}
+
+			// Now, create a sub-context with the given timeout and
+			// try dialing the provided host.
+			ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
+			defer cancel()
+
+			// This will dial, and the defer above sends it back to our parent.
+			a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
+			conn, err = a.dialHost(ctx, c.IP)
+		}(ctx, c)
+	}
+
+	var results []dialResult
+	for res := range resultsCh {
+		// If we get a response that has the highest priority, we don't
+		// need to wait for any of the other connections to finish; we
+		// can just return this connection.
+		//
+		// TODO(andrew): we could make this better by keeping track of
+		// the highest remaining priority dynamically, instead of just
+		// checking for the highest total
+		if res.priority == highestPriority && res.conn != nil {
+			a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
+
+			// Drain the channel and any existing connections in
+			// the background.
+			go func() {
+				for _, res := range results {
+					if res.conn != nil {
+						res.conn.Close()
+					}
+				}
+				for res := range resultsCh {
+					if res.conn != nil {
+						res.conn.Close()
+					}
+				}
+				if a.drainFinished != nil {
+					close(a.drainFinished)
+				}
+			}()
+			return res.conn, nil
+		}
+
+		// This isn't a highest-priority result, so just store it until
+		// we're done.
+		results = append(results, res)
+	}
+
+	// After we finish this function, close any remaining open connections.
+	defer func() {
+		for _, result := range results {
+			// Note: below, we nil out the returned connection (if
+			// any) in the slice so we don't close it.
+			if result.conn != nil {
+				result.conn.Close()
+			}
+		}
+
+		// We don't drain asynchronously after this point, so notify our
+		// channel when we return.
+		if a.drainFinished != nil {
+			close(a.drainFinished)
+		}
+	}()
+
+	// Sort by priority, then take the first non-error response.
+	sort.Slice(results, func(i, j int) bool {
+		// NOTE: intentionally inverted so that the highest priority
+		// item comes first
+		return results[i].priority > results[j].priority
+	})
+
+	var (
+		conn *controlbase.Conn
+		errs []error
+	)
+	for i, result := range results {
+		if result.err != nil {
+			errs = append(errs, result.err)
+			continue
+		}
+
+		a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
+		conn = result.conn
+		results[i].conn = nil // so we don't close it in the defer
+		return conn, nil
+	}
+	merr := multierr.New(errs...)
+
+	// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
+	a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
+	return a.dialHost(ctx, netip.Addr{})
+}
+
+// dialHost connects to the configured Dialer.Hostname and upgrades the
+// connection into a controlbase.Conn. If addr is valid, then no DNS is used
+// and the connection will be made to the provided address.
+func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
 	// Create one shared context used by both port 80 and port 443 dials.
 	// If port 80 is still in flight when 443 returns, this deferred cancel
 	// will stop the port 80 dial.
@@ -110,7 +280,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
 	}
 	ch := make(chan tryURLRes) // must be unbuffered
 	try := func(u *url.URL) {
-		cbConn, err := a.dialURL(ctx, u)
+		cbConn, err := a.dialURL(ctx, u, addr)
 		select {
 		case ch <- tryURLRes{u, cbConn, err}:
 		case <-ctx.Done():
@@ -161,12 +331,12 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
 }
 
 // dialURL attempts to connect to the given URL.
-func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
+func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
 	init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
 	if err != nil {
 		return nil, err
 	}
-	netConn, err := a.tryURLUpgrade(ctx, u, init)
+	netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
 	if err != nil {
 		return nil, err
 	}
@@ -178,14 +348,27 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, er
 	return cbConn, nil
 }
 
-// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
+// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
+// is valid, then no DNS is used and the connection will be made to the
+// provided address.
 //
 // Only the provided ctx is used, not a.ctx.
-func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
-	dns := &dnscache.Resolver{
-		Forward:          dnscache.Get().Forward,
-		LookupIPFallback: dnsfallback.Lookup,
-		UseLastGood:      true,
+func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
+	var dns *dnscache.Resolver
+
+	// If we were provided an address to dial, then create a resolver that just
+	// returns that value; otherwise, fall back to DNS.
+	if addr.IsValid() {
+		dns = &dnscache.Resolver{
+			SingleHostStaticResult: []netip.Addr{addr},
+			SingleHost:             u.Hostname(),
+		}
+	} else {
+		dns = &dnscache.Resolver{
+			Forward:          dnscache.Get().Forward,
+			LookupIPFallback: dnsfallback.Lookup,
+			UseLastGood:      true,
+		}
 	}
 
 	var dialer dnscache.DialContextFunc

+ 7 - 0
control/controlhttp/constants.go

@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"tailscale.com/net/dnscache"
+	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 )
@@ -70,9 +71,15 @@ type Dialer struct {
 	// dropped.
 	Logf logger.Logf
 
+	// DialPlan, if set, contains instructions from the control server on
+	// how to connect to it. If present, we will try the methods in this
+	// plan before falling back to DNS.
+	DialPlan *tailcfg.ControlDialPlan
+
 	proxyFunc func(*http.Request) (*url.URL, error) // or nil
 
 	// For tests only
+	drainFinished     chan struct{}
 	insecureTLS       bool
 	testFallbackDelay time.Duration
 }

+ 265 - 0
control/controlhttp/http_test.go

@@ -13,16 +13,21 @@ import (
 	"net"
 	"net/http"
 	"net/http/httputil"
+	"net/netip"
 	"net/url"
+	"runtime"
 	"strconv"
 	"sync"
 	"testing"
 	"time"
 
 	"tailscale.com/control/controlbase"
+	"tailscale.com/net/dnscache"
 	"tailscale.com/net/socks5"
 	"tailscale.com/net/tsdial"
+	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
 )
 
 type httpTestParam struct {
@@ -444,3 +449,263 @@ func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
 	w.(http.Flusher).Flush()
 	<-r.Context().Done()
 }
+
+func TestDialPlan(t *testing.T) {
+	if runtime.GOOS != "linux" {
+		t.Skip("only works on Linux due to multiple localhost addresses")
+	}
+
+	client, server := key.NewMachine(), key.NewMachine()
+
+	const (
+		testProtocolVersion = 1
+
+		// We need consistent ports for each address; these are chosen
+		// randomly and we hope that they won't conflict during this test.
+		httpPort  = "40080"
+		httpsPort = "40443"
+	)
+
+	makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
+		done := make(chan struct{})
+		t.Cleanup(func() {
+			close(done)
+		})
+		var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			conn, err := AcceptHTTP(context.Background(), w, r, server)
+			if err != nil {
+				log.Print(err)
+			} else {
+				defer conn.Close()
+			}
+			w.Header().Set("X-Handler-Name", name)
+			<-done
+		})
+		if wrap != nil {
+			handler = wrap(handler)
+		}
+
+		httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
+		if err != nil {
+			t.Fatalf("HTTP listen: %v", err)
+		}
+		httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
+		if err != nil {
+			t.Fatalf("HTTPS listen: %v", err)
+		}
+
+		httpServer := &http.Server{Handler: handler}
+		go httpServer.Serve(httpLn)
+		t.Cleanup(func() {
+			httpServer.Close()
+		})
+
+		httpsServer := &http.Server{
+			Handler:   handler,
+			TLSConfig: tlsConfig(t),
+			ErrorLog:  logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
+		}
+		go httpsServer.ServeTLS(httpsLn, "", "")
+		t.Cleanup(func() {
+			httpsServer.Close()
+		})
+		return
+	}
+
+	fallbackAddr := netip.MustParseAddr("127.0.0.1")
+	goodAddr := netip.MustParseAddr("127.0.0.2")
+	otherAddr := netip.MustParseAddr("127.0.0.3")
+	other2Addr := netip.MustParseAddr("127.0.0.4")
+	brokenAddr := netip.MustParseAddr("127.0.0.10")
+
+	testCases := []struct {
+		name string
+		plan *tailcfg.ControlDialPlan
+		wrap func(http.Handler) http.Handler
+		want netip.Addr
+
+		allowFallback bool
+	}{
+		{
+			name: "single",
+			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
+				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
+			}},
+			want: goodAddr,
+		},
+		{
+			name: "broken-then-good",
+			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
+				// Dials the broken one, which fails, and then
+				// eventually dials the good one and succeeds
+				{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
+				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
+			}},
+			want: goodAddr,
+		},
+		{
+			name: "multiple-priority-fast-path",
+			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
+				// Dials some good IPs and our bad one (which
+				// hangs forever), which then hits the fast
+				// path where we bail without waiting.
+				{IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
+				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
+				{IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
+				{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
+			}},
+			want: otherAddr,
+		},
+		{
+			name: "multiple-priority-slow-path",
+			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
+				// Our broken address is the highest priority,
+				// so we don't hit our fast path.
+				{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
+				{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
+				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
+			}},
+			want: otherAddr,
+		},
+		{
+			name: "fallback",
+			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
+				{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
+			}},
+			want:          fallbackAddr,
+			allowFallback: true,
+		},
+	}
+	for _, tt := range testCases {
+		t.Run(tt.name, func(t *testing.T) {
+			makeHandler(t, "fallback", fallbackAddr, nil)
+			makeHandler(t, "good", goodAddr, nil)
+			makeHandler(t, "other", otherAddr, nil)
+			makeHandler(t, "other2", other2Addr, nil)
+			makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
+				return http.HandlerFunc(brokenMITMHandler)
+			})
+
+			dialer := closeTrackDialer{
+				t:     t,
+				inner: new(tsdial.Dialer).SystemDial,
+				conns: make(map[*closeTrackConn]bool),
+			}
+			defer dialer.Done()
+
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
+
+			// By default, we intentionally point to something that
+			// we know won't connect, since we want a fallback to
+			// DNS to be an error.
+			host := "example.com"
+			if tt.allowFallback {
+				host = "localhost"
+			}
+
+			drained := make(chan struct{})
+			a := &Dialer{
+				Hostname:          host,
+				HTTPPort:          httpPort,
+				HTTPSPort:         httpsPort,
+				MachineKey:        client,
+				ControlKey:        server.Public(),
+				ProtocolVersion:   testProtocolVersion,
+				Dialer:            dialer.Dial,
+				Logf:              t.Logf,
+				DialPlan:          tt.plan,
+				proxyFunc:         func(*http.Request) (*url.URL, error) { return nil, nil },
+				drainFinished:     drained,
+				insecureTLS:       true,
+				testFallbackDelay: 50 * time.Millisecond,
+			}
+
+			conn, err := a.dial(ctx)
+			if err != nil {
+				t.Fatalf("dialing controlhttp: %v", err)
+			}
+			defer conn.Close()
+
+			raddr := conn.RemoteAddr().(*net.TCPAddr)
+
+			got, ok := netip.AddrFromSlice(raddr.IP)
+			if !ok {
+				t.Errorf("invalid remote IP: %v", raddr.IP)
+			} else if got != tt.want {
+				t.Errorf("got connection from %q; want %q", got, tt.want)
+			} else {
+				t.Logf("successfully connected to %q", raddr.String())
+			}
+
+			// Wait until our dialer drains so we can verify that
+			// all connections are closed.
+			<-drained
+		})
+	}
+}
+
+type closeTrackDialer struct {
+	t     testing.TB
+	inner dnscache.DialContextFunc
+	mu    sync.Mutex
+	conns map[*closeTrackConn]bool
+}
+
+func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
+	c, err := d.inner(ctx, network, addr)
+	if err != nil {
+		return nil, err
+	}
+	ct := &closeTrackConn{Conn: c, d: d}
+
+	d.mu.Lock()
+	d.conns[ct] = true
+	d.mu.Unlock()
+	return ct, nil
+}
+
+func (d *closeTrackDialer) Done() {
+	// Unfortunately, tsdial.Dialer.SystemDial closes connections
+	// asynchronously in a goroutine, so we can't assume that everything is
+	// closed by the time we get here.
+	//
+	// Sleep/wait a few times on the assumption that things will close
+	// "eventually".
+	const iters = 100
+	for i := 0; i < iters; i++ {
+		d.mu.Lock()
+		if len(d.conns) == 0 {
+			d.mu.Unlock()
+			return
+		}
+
+		// Only error on last iteration
+		if i != iters-1 {
+			d.mu.Unlock()
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		for conn := range d.conns {
+			d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
+		}
+		d.mu.Unlock()
+	}
+}
+
+func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
+	d.mu.Lock()
+	delete(d.conns, c) // safe if already deleted
+	d.mu.Unlock()
+}
+
+type closeTrackConn struct {
+	net.Conn
+	d *closeTrackDialer
+}
+
+func (c *closeTrackConn) Close() error {
+	c.d.noteClose(c)
+	return c.Conn.Close()
+}

+ 8 - 0
ipn/ipnlocal/local.go

@@ -189,6 +189,10 @@ type LocalBackend struct {
 	// statusChanged.Broadcast().
 	statusLock    sync.Mutex
 	statusChanged *sync.Cond
+
+	// dialPlan is any dial plan that we've received from the control
+	// server during a previous connection; it is cleared on logout.
+	dialPlan atomic.Pointer[tailcfg.ControlDialPlan]
 }
 
 // clientGen is a func that creates a control plane client.
@@ -1087,6 +1091,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
 		Dialer:               b.Dialer(),
 		Status:               b.setClientStatus,
 		C2NHandler:           http.HandlerFunc(b.handleC2N),
+		DialPlan:             &b.dialPlan, // pointer because it can't be copied
 
 		// Don't warn about broken Linux IP forwarding when
 		// netstack is being used.
@@ -3112,6 +3117,9 @@ func (b *LocalBackend) logout(ctx context.Context, sync bool) error {
 		Prefs:          ipn.Prefs{WantRunning: false, LoggedOut: true},
 	})
 
+	// Clear any previous dial plan(s), if set.
+	b.dialPlan.Store(nil)
+
 	if cc == nil {
 		// Double Logout can happen via repeated IPN
 		// connections to ipnserver making it repeatedly

+ 36 - 1
tailcfg/tailcfg.go

@@ -80,7 +80,8 @@ type CapabilityVersion int
 //   - 41: 2022-08-30: uses 100.100.100.100 for route-less ExtraRecords if global nameservers is set
 //   - 42: 2022-09-06: NextDNS DoH support; see https://github.com/tailscale/tailscale/pull/5556
 //   - 43: 2022-09-21: clients can return usernames for SSH
-const CurrentCapabilityVersion CapabilityVersion = 43
+//   - 44: 2022-09-22: MapResponse.ControlDialPlan
+const CurrentCapabilityVersion CapabilityVersion = 44
 
 type StableID string
 
@@ -1383,6 +1384,40 @@ type MapResponse struct {
 	// Debug is normally nil, except for when the control server
 	// is setting debug settings on a node.
 	Debug *Debug `json:",omitempty"`
+
+	// ControlDialPlan tells the client how to connect to the control
+	// server. An initial nil is equivalent to new(ControlDialPlan).
+	// A subsequent streamed nil means no change.
+	ControlDialPlan *ControlDialPlan `json:",omitempty"`
+}
+
+// ControlDialPlan is instructions from the control server to the client on how
+// to connect to the control server; this is useful for maintaining connection
+// if the client's network state changes after the initial connection, or due
+// to the configuration that the control server pushes.
+type ControlDialPlan struct {
+	// An empty list means the default: use DNS (unspecified which DNS).
+	Candidates []ControlIPCandidate
+}
+
+// ControlIPCandidate represents a single candidate address to use when
+// connecting to the control server.
+type ControlIPCandidate struct {
+	// IP is the address to attempt connecting to.
+	IP netip.Addr
+
+	// DialStartSec is the number of seconds after the beginning of the
+	// connection process to wait before trying this candidate.
+	DialStartDelaySec float64 `json:",omitempty"`
+
+	// DialTimeoutSec is the timeout for a connection to this candidate,
+	// starting after DialStartDelaySec.
+	DialTimeoutSec float64 `json:",omitempty"`
+
+	// Priority is the relative priority of this candidate; candidates with
+	// a higher priority are preferred over candidates with a lower
+	// priority.
+	Priority int `json:",omitempty"`
 }
 
 // Debug are instructions from the control server to the client