|
|
@@ -28,18 +28,25 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "math"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/http/httptrace"
|
|
|
+ "net/netip"
|
|
|
"net/url"
|
|
|
+ "sort"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"tailscale.com/control/controlbase"
|
|
|
+ "tailscale.com/envknob"
|
|
|
"tailscale.com/net/dnscache"
|
|
|
"tailscale.com/net/dnsfallback"
|
|
|
"tailscale.com/net/netutil"
|
|
|
"tailscale.com/net/tlsdial"
|
|
|
"tailscale.com/net/tshttpproxy"
|
|
|
+ "tailscale.com/tailcfg"
|
|
|
+ "tailscale.com/util/multierr"
|
|
|
)
|
|
|
|
|
|
var stdDialer net.Dialer
|
|
|
@@ -82,7 +89,170 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
|
|
|
return 500 * time.Millisecond
|
|
|
}
|
|
|
|
|
|
+var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
|
|
|
+
|
|
|
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|
|
+ // If we don't have a dial plan, just fall back to dialing the single
|
|
|
+ // host we know about.
|
|
|
+ useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
|
|
|
+ if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
|
|
|
+ return a.dialHost(ctx, netip.Addr{})
|
|
|
+ }
|
|
|
+ candidates := a.DialPlan.Candidates
|
|
|
+
|
|
|
+ // Otherwise, we try dialing per the plan. Store the highest priority
|
|
|
+ // in the list, so that if we get a connection to one of those
|
|
|
+ // candidates we can return quickly.
|
|
|
+ var highestPriority int = math.MinInt
|
|
|
+ for _, c := range candidates {
|
|
|
+ if c.Priority > highestPriority {
|
|
|
+ highestPriority = c.Priority
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // This context allows us to cancel in-flight connections if we get a
|
|
|
+ // highest-priority connection before we're all done.
|
|
|
+ ctx, cancel := context.WithCancel(ctx)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // Now, for each candidate, kick off a dial in parallel.
|
|
|
+ type dialResult struct {
|
|
|
+ conn *controlbase.Conn
|
|
|
+ err error
|
|
|
+ addr netip.Addr
|
|
|
+ priority int
|
|
|
+ }
|
|
|
+ resultsCh := make(chan dialResult, len(candidates))
|
|
|
+
|
|
|
+ var pending atomic.Int32
|
|
|
+ pending.Store(int32(len(candidates)))
|
|
|
+ for _, c := range candidates {
|
|
|
+ go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
|
|
+ var (
|
|
|
+ conn *controlbase.Conn
|
|
|
+ err error
|
|
|
+ )
|
|
|
+
|
|
|
+ // Always send results back to our channel.
|
|
|
+ defer func() {
|
|
|
+ resultsCh <- dialResult{conn, err, c.IP, c.Priority}
|
|
|
+ if pending.Add(-1) == 0 {
|
|
|
+ close(resultsCh)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // If non-zero, wait the configured start timeout
|
|
|
+ // before we do anything.
|
|
|
+ if c.DialStartDelaySec > 0 {
|
|
|
+ a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
|
|
|
+ tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
|
|
|
+ defer tmr.Stop()
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ err = ctx.Err()
|
|
|
+ return
|
|
|
+ case <-tmr.C:
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Now, create a sub-context with the given timeout and
|
|
|
+ // try dialing the provided host.
|
|
|
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // This will dial, and the defer above sends it back to our parent.
|
|
|
+ a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
|
|
|
+ conn, err = a.dialHost(ctx, c.IP)
|
|
|
+ }(ctx, c)
|
|
|
+ }
|
|
|
+
|
|
|
+ var results []dialResult
|
|
|
+ for res := range resultsCh {
|
|
|
+ // If we get a response that has the highest priority, we don't
|
|
|
+ // need to wait for any of the other connections to finish; we
|
|
|
+ // can just return this connection.
|
|
|
+ //
|
|
|
+ // TODO(andrew): we could make this better by keeping track of
|
|
|
+ // the highest remaining priority dynamically, instead of just
|
|
|
+ // checking for the highest total
|
|
|
+ if res.priority == highestPriority && res.conn != nil {
|
|
|
+ a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
|
|
|
+
|
|
|
+ // Drain the channel and any existing connections in
|
|
|
+ // the background.
|
|
|
+ go func() {
|
|
|
+ for _, res := range results {
|
|
|
+ if res.conn != nil {
|
|
|
+ res.conn.Close()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for res := range resultsCh {
|
|
|
+ if res.conn != nil {
|
|
|
+ res.conn.Close()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if a.drainFinished != nil {
|
|
|
+ close(a.drainFinished)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ return res.conn, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // This isn't a highest-priority result, so just store it until
|
|
|
+ // we're done.
|
|
|
+ results = append(results, res)
|
|
|
+ }
|
|
|
+
|
|
|
+ // After we finish this function, close any remaining open connections.
|
|
|
+ defer func() {
|
|
|
+ for _, result := range results {
|
|
|
+ // Note: below, we nil out the returned connection (if
|
|
|
+ // any) in the slice so we don't close it.
|
|
|
+ if result.conn != nil {
|
|
|
+ result.conn.Close()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // We don't drain asynchronously after this point, so notify our
|
|
|
+ // channel when we return.
|
|
|
+ if a.drainFinished != nil {
|
|
|
+ close(a.drainFinished)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Sort by priority, then take the first non-error response.
|
|
|
+ sort.Slice(results, func(i, j int) bool {
|
|
|
+ // NOTE: intentionally inverted so that the highest priority
|
|
|
+ // item comes first
|
|
|
+ return results[i].priority > results[j].priority
|
|
|
+ })
|
|
|
+
|
|
|
+ var (
|
|
|
+ conn *controlbase.Conn
|
|
|
+ errs []error
|
|
|
+ )
|
|
|
+ for i, result := range results {
|
|
|
+ if result.err != nil {
|
|
|
+ errs = append(errs, result.err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
|
|
|
+ conn = result.conn
|
|
|
+ results[i].conn = nil // so we don't close it in the defer
|
|
|
+ return conn, nil
|
|
|
+ }
|
|
|
+ merr := multierr.New(errs...)
|
|
|
+
|
|
|
+ // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
|
|
|
+ a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
|
|
|
+ return a.dialHost(ctx, netip.Addr{})
|
|
|
+}
|
|
|
+
|
|
|
+// dialHost connects to the configured Dialer.Hostname and upgrades the
|
|
|
+// connection into a controlbase.Conn. If addr is valid, then no DNS is used
|
|
|
+// and the connection will be made to the provided address.
|
|
|
+func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
|
|
|
// Create one shared context used by both port 80 and port 443 dials.
|
|
|
// If port 80 is still in flight when 443 returns, this deferred cancel
|
|
|
// will stop the port 80 dial.
|
|
|
@@ -110,7 +280,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|
|
}
|
|
|
ch := make(chan tryURLRes) // must be unbuffered
|
|
|
try := func(u *url.URL) {
|
|
|
- cbConn, err := a.dialURL(ctx, u)
|
|
|
+ cbConn, err := a.dialURL(ctx, u, addr)
|
|
|
select {
|
|
|
case ch <- tryURLRes{u, cbConn, err}:
|
|
|
case <-ctx.Done():
|
|
|
@@ -161,12 +331,12 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|
|
}
|
|
|
|
|
|
// dialURL attempts to connect to the given URL.
|
|
|
-func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
|
|
|
+func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
|
|
|
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- netConn, err := a.tryURLUpgrade(ctx, u, init)
|
|
|
+ netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -178,14 +348,27 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, er
|
|
|
return cbConn, nil
|
|
|
}
|
|
|
|
|
|
-// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
|
|
|
+// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
|
|
|
+// is valid, then no DNS is used and the connection will be made to the
|
|
|
+// provided address.
|
|
|
//
|
|
|
// Only the provided ctx is used, not a.ctx.
|
|
|
-func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
|
|
|
- dns := &dnscache.Resolver{
|
|
|
- Forward: dnscache.Get().Forward,
|
|
|
- LookupIPFallback: dnsfallback.Lookup,
|
|
|
- UseLastGood: true,
|
|
|
+func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
|
|
|
+ var dns *dnscache.Resolver
|
|
|
+
|
|
|
+ // If we were provided an address to dial, then create a resolver that just
|
|
|
+ // returns that value; otherwise, fall back to DNS.
|
|
|
+ if addr.IsValid() {
|
|
|
+ dns = &dnscache.Resolver{
|
|
|
+ SingleHostStaticResult: []netip.Addr{addr},
|
|
|
+ SingleHost: u.Hostname(),
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ dns = &dnscache.Resolver{
|
|
|
+ Forward: dnscache.Get().Forward,
|
|
|
+ LookupIPFallback: dnsfallback.Lookup,
|
|
|
+ UseLastGood: true,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
var dialer dnscache.DialContextFunc
|