client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build !js
  5. // Package controlhttp implements the Tailscale 2021 control protocol
  6. // base transport over HTTP.
  7. //
  8. // This tunnels the protocol in control/controlbase over HTTP with a
  9. // variety of compatibility fallbacks for handling picky or deep
  10. // inspecting proxies.
  11. //
  12. // In the happy path, a client makes a single cleartext HTTP request
  13. // to the server, the server responds with 101 Switching Protocols,
  14. // and the control base protocol takes place over plain TCP.
  15. //
  16. // In the compatibility path, the client does the above over HTTPS,
  17. // resulting in double encryption (once for the control transport, and
  18. // once for the outer TLS layer).
  19. package controlhttp
  20. import (
  21. "context"
  22. "crypto/tls"
  23. "encoding/base64"
  24. "errors"
  25. "fmt"
  26. "io"
  27. "math"
  28. "net"
  29. "net/http"
  30. "net/http/httptrace"
  31. "net/netip"
  32. "net/url"
  33. "sort"
  34. "sync/atomic"
  35. "time"
  36. "tailscale.com/control/controlbase"
  37. "tailscale.com/envknob"
  38. "tailscale.com/net/dnscache"
  39. "tailscale.com/net/dnsfallback"
  40. "tailscale.com/net/netutil"
  41. "tailscale.com/net/tlsdial"
  42. "tailscale.com/net/tshttpproxy"
  43. "tailscale.com/tailcfg"
  44. "tailscale.com/util/multierr"
  45. )
  46. var stdDialer net.Dialer
  47. // Dial connects to the HTTP server at this Dialer's Host:HTTPPort, requests to
  48. // switch to the Tailscale control protocol, and returns an established control
  49. // protocol connection.
  50. //
  51. // If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the
  52. // Dialer's Host:HTTPSPort as a compatibility fallback.
  53. //
  54. // The provided ctx is only used for the initial connection, until
  55. // Dial returns. It does not affect the connection once established.
  56. func (a *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
  57. if a.Hostname == "" {
  58. return nil, errors.New("required Dialer.Hostname empty")
  59. }
  60. return a.dial(ctx)
  61. }
  62. func (a *Dialer) logf(format string, args ...any) {
  63. if a.Logf != nil {
  64. a.Logf(format, args...)
  65. }
  66. }
  67. func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) {
  68. if a.proxyFunc != nil {
  69. return a.proxyFunc
  70. }
  71. return tshttpproxy.ProxyFromEnvironment
  72. }
  73. // httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before
  74. // starting to try a.HTTPSPort.
  75. func (a *Dialer) httpsFallbackDelay() time.Duration {
  76. if v := a.testFallbackDelay; v != 0 {
  77. return v
  78. }
  79. return 500 * time.Millisecond
  80. }
// Register the TS_USE_CONTROL_DIAL_PLAN knob during package initialization.
// The value itself is read by (*Dialer).dial through envknob.BoolDefaultTrue.
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
// dial connects to the control server, consulting a.DialPlan when possible.
//
// When the TS_USE_CONTROL_DIAL_PLAN envknob is false, or a.DialPlan is nil
// or has no candidates, it simply calls dialHost with an invalid address, so
// a.Hostname is resolved via DNS.
//
// Otherwise every candidate is dialed concurrently. Each one waits its own
// DialStartDelaySec before starting and is bounded by its own DialTimeoutSec.
// A successful connection to a candidate whose Priority equals the highest
// Priority in the plan is returned immediately, and the remaining connections
// are then closed in the background. If that never happens, dial waits for
// every candidate and returns the successful connection with the highest
// Priority (ties are broken arbitrarily because sort.Slice is not stable).
// If all candidates fail, it logs the combined errors and falls back to
// dialHost with DNS.
//
// If a.drainFinished is non-nil, it is closed once the leftover dial-plan
// connections have been closed (presumably a hook for tests).
func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
	// If we don't have a dial plan, just fall back to dialing the single
	// host we know about.
	useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
	if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
		return a.dialHost(ctx, netip.Addr{})
	}
	candidates := a.DialPlan.Candidates

	// Otherwise, we try dialing per the plan. Store the highest priority
	// in the list, so that if we get a connection to one of those
	// candidates we can return quickly.
	var highestPriority int = math.MinInt
	for _, c := range candidates {
		if c.Priority > highestPriority {
			highestPriority = c.Priority
		}
	}

	// This context allows us to cancel in-flight connections if we get a
	// highest-priority connection before we're all done.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Now, for each candidate, kick off a dial in parallel.
	type dialResult struct {
		conn     *ClientConn
		err      error
		addr     netip.Addr
		priority int
	}
	// Buffered with one slot per candidate, so no dialing goroutine ever
	// blocks on its send, even after dial has already returned.
	resultsCh := make(chan dialResult, len(candidates))
	// pending counts goroutines that have not yet delivered a result. The
	// last one to finish closes resultsCh, which ends the range loops below.
	var pending atomic.Int32
	pending.Store(int32(len(candidates)))
	for _, c := range candidates {
		go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
			var (
				conn *ClientConn
				err  error
			)

			// Always send results back to our channel.
			defer func() {
				resultsCh <- dialResult{conn, err, c.IP, c.Priority}
				if pending.Add(-1) == 0 {
					close(resultsCh)
				}
			}()

			// If non-zero, wait the configured start timeout
			// before we do anything.
			if c.DialStartDelaySec > 0 {
				a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
				tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
				defer tmr.Stop()
				select {
				case <-ctx.Done():
					err = ctx.Err()
					return
				case <-tmr.C:
				}
			}

			// Now, create a sub-context with the given timeout and
			// try dialing the provided host.
			//
			// NOTE(review): a DialTimeoutSec of 0 produces an already
			// expired context, so such a candidate fails immediately.
			// Confirm the dial plan always sets a positive timeout.
			ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
			defer cancel()

			// This will dial, and the defer above sends it back to our parent.
			a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
			conn, err = a.dialHost(ctx, c.IP)
		}(ctx, c)
	}

	var results []dialResult
	for res := range resultsCh {
		// If we get a response that has the highest priority, we don't
		// need to wait for any of the other connections to finish; we
		// can just return this connection.
		//
		// TODO(andrew): we could make this better by keeping track of
		// the highest remaining priority dynamically, instead of just
		// checking for the highest total
		if res.priority == highestPriority && res.conn != nil {
			a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)

			// Drain the channel and any existing connections in
			// the background. The deferred cancel above aborts the
			// dials still in flight, so the channel closes promptly.
			go func() {
				for _, res := range results {
					if res.conn != nil {
						res.conn.Close()
					}
				}
				for res := range resultsCh {
					if res.conn != nil {
						res.conn.Close()
					}
				}
				if a.drainFinished != nil {
					close(a.drainFinished)
				}
			}()
			return res.conn, nil
		}

		// This isn't a highest-priority result, so just store it until
		// we're done.
		results = append(results, res)
	}

	// After we finish this function, close any remaining open connections.
	defer func() {
		for _, result := range results {
			// Note: below, we nil out the returned connection (if
			// any) in the slice so we don't close it.
			if result.conn != nil {
				result.conn.Close()
			}
		}

		// We don't drain asynchronously after this point, so notify our
		// channel when we return.
		if a.drainFinished != nil {
			close(a.drainFinished)
		}
	}()

	// Sort by priority, then take the first non-error response.
	sort.Slice(results, func(i, j int) bool {
		// NOTE: intentionally inverted so that the highest priority
		// item comes first
		return results[i].priority > results[j].priority
	})

	var (
		conn *ClientConn
		errs []error
	)
	for i, result := range results {
		if result.err != nil {
			errs = append(errs, result.err)
			continue
		}

		a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
		conn = result.conn
		results[i].conn = nil // so we don't close it in the defer
		return conn, nil
	}
	// Every candidate reported an error here, so errs is non-empty.
	merr := multierr.New(errs...)

	// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
	a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
	return a.dialHost(ctx, netip.Addr{})
}
  222. // dialHost connects to the configured Dialer.Hostname and upgrades the
  223. // connection into a controlbase.Conn. If addr is valid, then no DNS is used
  224. // and the connection will be made to the provided address.
  225. func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, error) {
  226. // Create one shared context used by both port 80 and port 443 dials.
  227. // If port 80 is still in flight when 443 returns, this deferred cancel
  228. // will stop the port 80 dial.
  229. ctx, cancel := context.WithCancel(ctx)
  230. defer cancel()
  231. // u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
  232. // respectively, in order to do the HTTP upgrade to a net.Conn over which
  233. // we'll speak Noise.
  234. u80 := &url.URL{
  235. Scheme: "http",
  236. Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")),
  237. Path: serverUpgradePath,
  238. }
  239. u443 := &url.URL{
  240. Scheme: "https",
  241. Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
  242. Path: serverUpgradePath,
  243. }
  244. type tryURLRes struct {
  245. u *url.URL // input (the URL conn+err are for/from)
  246. conn *ClientConn // result (mutually exclusive with err)
  247. err error
  248. }
  249. ch := make(chan tryURLRes) // must be unbuffered
  250. try := func(u *url.URL) {
  251. cbConn, err := a.dialURL(ctx, u, addr)
  252. select {
  253. case ch <- tryURLRes{u, cbConn, err}:
  254. case <-ctx.Done():
  255. if cbConn != nil {
  256. cbConn.Close()
  257. }
  258. }
  259. }
  260. // Start the plaintext HTTP attempt first.
  261. go try(u80)
  262. // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
  263. // to dial port 443 if port 80 doesn't either succeed or fail quickly.
  264. try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) })
  265. defer try443Timer.Stop()
  266. var err80, err443 error
  267. for {
  268. select {
  269. case <-ctx.Done():
  270. return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
  271. case res := <-ch:
  272. if res.err == nil {
  273. return res.conn, nil
  274. }
  275. switch res.u {
  276. case u80:
  277. // Connecting over plain HTTP failed; assume it's an HTTP proxy
  278. // being difficult and see if we can get through over HTTPS.
  279. err80 = res.err
  280. // Stop the fallback timer and run it immediately. We don't use
  281. // Timer.Reset(0) here because on AfterFuncs, that can run it
  282. // again.
  283. if try443Timer.Stop() {
  284. go try(u443)
  285. } // else we lost the race and it started already which is what we want
  286. case u443:
  287. err443 = res.err
  288. default:
  289. panic("invalid")
  290. }
  291. if err80 != nil && err443 != nil {
  292. return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", err80, err443)
  293. }
  294. }
  295. }
  296. }
  297. // dialURL attempts to connect to the given URL.
  298. func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*ClientConn, error) {
  299. init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
  300. if err != nil {
  301. return nil, err
  302. }
  303. netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
  304. if err != nil {
  305. return nil, err
  306. }
  307. cbConn, err := cont(ctx, netConn)
  308. if err != nil {
  309. netConn.Close()
  310. return nil, err
  311. }
  312. return &ClientConn{
  313. Conn: cbConn,
  314. }, nil
  315. }
  316. // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
  317. // is valid, then no DNS is used and the connection will be made to the
  318. // provided address.
  319. //
  320. // Only the provided ctx is used, not a.ctx.
  321. func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
  322. var dns *dnscache.Resolver
  323. // If we were provided an address to dial, then create a resolver that just
  324. // returns that value; otherwise, fall back to DNS.
  325. if addr.IsValid() {
  326. dns = &dnscache.Resolver{
  327. SingleHostStaticResult: []netip.Addr{addr},
  328. SingleHost: u.Hostname(),
  329. }
  330. } else {
  331. dns = &dnscache.Resolver{
  332. Forward: dnscache.Get().Forward,
  333. LookupIPFallback: dnsfallback.Lookup,
  334. UseLastGood: true,
  335. }
  336. }
  337. var dialer dnscache.DialContextFunc
  338. if a.Dialer != nil {
  339. dialer = a.Dialer
  340. } else {
  341. dialer = stdDialer.DialContext
  342. }
  343. tr := http.DefaultTransport.(*http.Transport).Clone()
  344. defer tr.CloseIdleConnections()
  345. tr.Proxy = a.getProxyFunc()
  346. tshttpproxy.SetTransportGetProxyConnectHeader(tr)
  347. tr.DialContext = dnscache.Dialer(dialer, dns)
  348. // Disable HTTP2, since h2 can't do protocol switching.
  349. tr.TLSClientConfig.NextProtos = []string{}
  350. tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
  351. tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig)
  352. if a.insecureTLS {
  353. tr.TLSClientConfig.InsecureSkipVerify = true
  354. tr.TLSClientConfig.VerifyConnection = nil
  355. }
  356. tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig)
  357. tr.DisableCompression = true
  358. // (mis)use httptrace to extract the underlying net.Conn from the
  359. // transport. We make exactly 1 request using this transport, so
  360. // there will be exactly 1 GotConn call. Additionally, the
  361. // transport handles 101 Switching Protocols correctly, such that
  362. // the Conn will not be reused or kept alive by the transport once
  363. // the response has been handed back from RoundTrip.
  364. //
  365. // In theory, the machinery of net/http should make it such that
  366. // the trace callback happens-before we get the response, but
  367. // there's no promise of that. So, to make sure, we use a buffered
  368. // channel as a synchronization step to avoid data races.
  369. //
  370. // Note that even though we're able to extract a net.Conn via this
  371. // mechanism, we must still keep using the eventual resp.Body to
  372. // read from, because it includes a buffer we can't get rid of. If
  373. // the server never sends any data after sending the HTTP
  374. // response, we could get away with it, but violating this
  375. // assumption leads to very mysterious transport errors (lockups,
  376. // unexpected EOFs...), and we're bound to forget someday and
  377. // introduce a protocol optimization at a higher level that starts
  378. // eagerly transmitting from the server.
  379. connCh := make(chan net.Conn, 1)
  380. trace := httptrace.ClientTrace{
  381. GotConn: func(info httptrace.GotConnInfo) {
  382. connCh <- info.Conn
  383. },
  384. }
  385. ctx = httptrace.WithClientTrace(ctx, &trace)
  386. req := &http.Request{
  387. Method: "POST",
  388. URL: u,
  389. Header: http.Header{
  390. "Upgrade": []string{upgradeHeaderValue},
  391. "Connection": []string{"upgrade"},
  392. handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)},
  393. },
  394. }
  395. req = req.WithContext(ctx)
  396. resp, err := tr.RoundTrip(req)
  397. if err != nil {
  398. return nil, err
  399. }
  400. if resp.StatusCode != http.StatusSwitchingProtocols {
  401. return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status)
  402. }
  403. // From here on, the underlying net.Conn is ours to use, but there
  404. // is still a read buffer attached to it within resp.Body. So, we
  405. // must direct I/O through resp.Body, but we can still use the
  406. // underlying net.Conn for stuff like deadlines.
  407. var switchedConn net.Conn
  408. select {
  409. case switchedConn = <-connCh:
  410. default:
  411. }
  412. if switchedConn == nil {
  413. resp.Body.Close()
  414. return nil, fmt.Errorf("httptrace didn't provide a connection")
  415. }
  416. if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue {
  417. resp.Body.Close()
  418. return nil, fmt.Errorf("server switched to unexpected protocol %q", next)
  419. }
  420. rwc, ok := resp.Body.(io.ReadWriteCloser)
  421. if !ok {
  422. resp.Body.Close()
  423. return nil, errors.New("http Transport did not provide a writable body")
  424. }
  425. return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil
  426. }