http_test.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlhttp
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httputil"
  13. "net/netip"
  14. "net/url"
  15. "runtime"
  16. "strconv"
  17. "sync"
  18. "testing"
  19. "time"
  20. "tailscale.com/control/controlbase"
  21. "tailscale.com/net/dnscache"
  22. "tailscale.com/net/netmon"
  23. "tailscale.com/net/socks5"
  24. "tailscale.com/net/tsdial"
  25. "tailscale.com/tailcfg"
  26. "tailscale.com/tstest"
  27. "tailscale.com/tstime"
  28. "tailscale.com/types/key"
  29. "tailscale.com/types/logger"
  30. )
  31. type httpTestParam struct {
  32. name string
  33. proxy proxy
  34. // makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
  35. // 101 switching protocols.
  36. makeHTTPHangAfterUpgrade bool
  37. doEarlyWrite bool
  38. }
  39. func TestControlHTTP(t *testing.T) {
  40. tests := []httpTestParam{
  41. // direct connection
  42. {
  43. name: "no_proxy",
  44. proxy: nil,
  45. },
  46. // direct connection but port 80 is MITM'ed and broken
  47. {
  48. name: "port80_broken_mitm",
  49. proxy: nil,
  50. makeHTTPHangAfterUpgrade: true,
  51. },
  52. // SOCKS5
  53. {
  54. name: "socks5",
  55. proxy: &socksProxy{},
  56. },
  57. // HTTP->HTTP
  58. {
  59. name: "http_to_http",
  60. proxy: &httpProxy{
  61. useTLS: false,
  62. allowConnect: false,
  63. allowHTTP: true,
  64. },
  65. },
  66. // HTTP->HTTPS
  67. {
  68. name: "http_to_https",
  69. proxy: &httpProxy{
  70. useTLS: false,
  71. allowConnect: true,
  72. allowHTTP: false,
  73. },
  74. },
  75. // HTTP->any (will pick HTTP)
  76. {
  77. name: "http_to_any",
  78. proxy: &httpProxy{
  79. useTLS: false,
  80. allowConnect: true,
  81. allowHTTP: true,
  82. },
  83. },
  84. // HTTPS->HTTP
  85. {
  86. name: "https_to_http",
  87. proxy: &httpProxy{
  88. useTLS: true,
  89. allowConnect: false,
  90. allowHTTP: true,
  91. },
  92. },
  93. // HTTPS->HTTPS
  94. {
  95. name: "https_to_https",
  96. proxy: &httpProxy{
  97. useTLS: true,
  98. allowConnect: true,
  99. allowHTTP: false,
  100. },
  101. },
  102. // HTTPS->any (will pick HTTP)
  103. {
  104. name: "https_to_any",
  105. proxy: &httpProxy{
  106. useTLS: true,
  107. allowConnect: true,
  108. allowHTTP: true,
  109. },
  110. },
  111. // Early write
  112. {
  113. name: "early_write",
  114. doEarlyWrite: true,
  115. },
  116. }
  117. for _, test := range tests {
  118. t.Run(test.name, func(t *testing.T) {
  119. testControlHTTP(t, test)
  120. })
  121. }
  122. }
  123. func testControlHTTP(t *testing.T, param httpTestParam) {
  124. proxy := param.proxy
  125. client, server := key.NewMachine(), key.NewMachine()
  126. const testProtocolVersion = 1
  127. const earlyWriteMsg = "Hello, world!"
  128. sch := make(chan serverResult, 1)
  129. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  130. var earlyWriteFn func(protocolVersion int, w io.Writer) error
  131. if param.doEarlyWrite {
  132. earlyWriteFn = func(protocolVersion int, w io.Writer) error {
  133. if protocolVersion != testProtocolVersion {
  134. t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
  135. return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
  136. }
  137. _, err := io.WriteString(w, earlyWriteMsg)
  138. return err
  139. }
  140. }
  141. conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
  142. if err != nil {
  143. log.Print(err)
  144. }
  145. res := serverResult{
  146. err: err,
  147. }
  148. if conn != nil {
  149. res.clientAddr = conn.RemoteAddr().String()
  150. res.version = conn.ProtocolVersion()
  151. res.peer = conn.Peer()
  152. res.conn = conn
  153. }
  154. sch <- res
  155. })
  156. httpLn, err := net.Listen("tcp", "127.0.0.1:0")
  157. if err != nil {
  158. t.Fatalf("HTTP listen: %v", err)
  159. }
  160. httpsLn, err := net.Listen("tcp", "127.0.0.1:0")
  161. if err != nil {
  162. t.Fatalf("HTTPS listen: %v", err)
  163. }
  164. var httpHandler http.Handler = handler
  165. const fallbackDelay = 50 * time.Millisecond
  166. clock := tstest.NewClock(tstest.ClockOpts{Step: 2 * fallbackDelay})
  167. // Advance once to init the clock.
  168. clock.Now()
  169. if param.makeHTTPHangAfterUpgrade {
  170. httpHandler = brokenMITMHandler(clock)
  171. }
  172. httpServer := &http.Server{Handler: httpHandler}
  173. go httpServer.Serve(httpLn)
  174. defer httpServer.Close()
  175. httpsServer := &http.Server{
  176. Handler: handler,
  177. TLSConfig: tlsConfig(t),
  178. }
  179. go httpsServer.ServeTLS(httpsLn, "", "")
  180. defer httpsServer.Close()
  181. ctx := context.Background()
  182. const debugTimeout = false
  183. if debugTimeout {
  184. var cancel context.CancelFunc
  185. ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
  186. defer cancel()
  187. }
  188. netMon := netmon.NewStatic()
  189. dialer := tsdial.NewDialer(netMon)
  190. a := &Dialer{
  191. Hostname: "localhost",
  192. HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
  193. HTTPSPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
  194. MachineKey: client,
  195. ControlKey: server.Public(),
  196. NetMon: netMon,
  197. ProtocolVersion: testProtocolVersion,
  198. Dialer: dialer.SystemDial,
  199. Logf: t.Logf,
  200. omitCertErrorLogging: true,
  201. testFallbackDelay: fallbackDelay,
  202. Clock: clock,
  203. }
  204. if proxy != nil {
  205. proxyEnv := proxy.Start(t)
  206. defer proxy.Close()
  207. proxyURL, err := url.Parse(proxyEnv)
  208. if err != nil {
  209. t.Fatal(err)
  210. }
  211. a.proxyFunc = func(*http.Request) (*url.URL, error) {
  212. return proxyURL, nil
  213. }
  214. } else {
  215. a.proxyFunc = func(*http.Request) (*url.URL, error) {
  216. return nil, nil
  217. }
  218. }
  219. conn, err := a.dial(ctx)
  220. if err != nil {
  221. t.Fatalf("dialing controlhttp: %v", err)
  222. }
  223. defer conn.Close()
  224. si := <-sch
  225. if si.conn != nil {
  226. defer si.conn.Close()
  227. }
  228. if si.err != nil {
  229. t.Fatalf("controlhttp server got error: %v", err)
  230. }
  231. if clientVersion := conn.ProtocolVersion(); si.version != clientVersion {
  232. t.Fatalf("client and server don't agree on protocol version: %d vs %d", clientVersion, si.version)
  233. }
  234. if si.peer != client.Public() {
  235. t.Fatalf("server got peer pubkey %s, want %s", si.peer, client.Public())
  236. }
  237. if spub := conn.Peer(); spub != server.Public() {
  238. t.Fatalf("client got peer pubkey %s, want %s", spub, server.Public())
  239. }
  240. if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
  241. t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
  242. }
  243. if param.doEarlyWrite {
  244. buf := make([]byte, len(earlyWriteMsg))
  245. if _, err := io.ReadFull(conn, buf); err != nil {
  246. t.Fatalf("reading early write: %v", err)
  247. }
  248. if string(buf) != earlyWriteMsg {
  249. t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
  250. }
  251. }
  252. }
  253. type serverResult struct {
  254. err error
  255. clientAddr string
  256. version int
  257. peer key.MachinePublic
  258. conn *controlbase.Conn
  259. }
  260. type proxy interface {
  261. Start(*testing.T) string
  262. Close()
  263. ConnIsFromProxy(string) bool
  264. }
  265. type socksProxy struct {
  266. sync.Mutex
  267. closed bool
  268. proxy socks5.Server
  269. ln net.Listener
  270. clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
  271. }
  272. func (s *socksProxy) Start(t *testing.T) (url string) {
  273. t.Helper()
  274. s.Lock()
  275. defer s.Unlock()
  276. ln, err := net.Listen("tcp", "127.0.0.1:0")
  277. if err != nil {
  278. t.Fatalf("listening for SOCKS server: %v", err)
  279. }
  280. s.ln = ln
  281. s.clientConnAddrs = map[string]bool{}
  282. s.proxy.Logf = func(format string, a ...any) {
  283. s.Lock()
  284. defer s.Unlock()
  285. if s.closed {
  286. return
  287. }
  288. t.Logf(format, a...)
  289. }
  290. s.proxy.Dialer = s.dialAndRecord
  291. go s.proxy.Serve(ln)
  292. return fmt.Sprintf("socks5://%s", ln.Addr().String())
  293. }
  294. func (s *socksProxy) Close() {
  295. s.Lock()
  296. defer s.Unlock()
  297. if s.closed {
  298. return
  299. }
  300. s.closed = true
  301. s.ln.Close()
  302. }
  303. func (s *socksProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
  304. var d net.Dialer
  305. conn, err := d.DialContext(ctx, network, addr)
  306. if err != nil {
  307. return nil, err
  308. }
  309. s.Lock()
  310. defer s.Unlock()
  311. s.clientConnAddrs[conn.LocalAddr().String()] = true
  312. return conn, nil
  313. }
  314. func (s *socksProxy) ConnIsFromProxy(addr string) bool {
  315. s.Lock()
  316. defer s.Unlock()
  317. return s.clientConnAddrs[addr]
  318. }
  319. type httpProxy struct {
  320. useTLS bool // take incoming connections over TLS
  321. allowConnect bool // allow CONNECT for TLS
  322. allowHTTP bool // allow plain HTTP proxying
  323. sync.Mutex
  324. ln net.Listener
  325. rp httputil.ReverseProxy
  326. s http.Server
  327. clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
  328. }
  329. func (h *httpProxy) Start(t *testing.T) (url string) {
  330. t.Helper()
  331. h.Lock()
  332. defer h.Unlock()
  333. ln, err := net.Listen("tcp", "127.0.0.1:0")
  334. if err != nil {
  335. t.Fatalf("listening for HTTP proxy: %v", err)
  336. }
  337. h.ln = ln
  338. h.rp = httputil.ReverseProxy{
  339. Director: func(*http.Request) {},
  340. Transport: &http.Transport{
  341. DialContext: h.dialAndRecord,
  342. TLSClientConfig: &tls.Config{
  343. InsecureSkipVerify: true,
  344. },
  345. TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
  346. },
  347. }
  348. h.clientConnAddrs = map[string]bool{}
  349. h.s.Handler = h
  350. if h.useTLS {
  351. h.s.TLSConfig = tlsConfig(t)
  352. go h.s.ServeTLS(h.ln, "", "")
  353. return fmt.Sprintf("https://%s", ln.Addr().String())
  354. } else {
  355. go h.s.Serve(h.ln)
  356. return fmt.Sprintf("http://%s", ln.Addr().String())
  357. }
  358. }
  359. func (h *httpProxy) Close() {
  360. h.Lock()
  361. defer h.Unlock()
  362. h.s.Close()
  363. }
  364. func (h *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  365. if r.Method != "CONNECT" {
  366. if !h.allowHTTP {
  367. http.Error(w, "http proxy not allowed", 500)
  368. return
  369. }
  370. h.rp.ServeHTTP(w, r)
  371. return
  372. }
  373. if !h.allowConnect {
  374. http.Error(w, "connect not allowed", 500)
  375. return
  376. }
  377. dst := r.RequestURI
  378. c, err := h.dialAndRecord(context.Background(), "tcp", dst)
  379. if err != nil {
  380. http.Error(w, err.Error(), 500)
  381. return
  382. }
  383. defer c.Close()
  384. cc, ccbuf, err := w.(http.Hijacker).Hijack()
  385. if err != nil {
  386. http.Error(w, err.Error(), 500)
  387. return
  388. }
  389. defer cc.Close()
  390. io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n")
  391. errc := make(chan error, 1)
  392. go func() {
  393. _, err := io.Copy(cc, c)
  394. errc <- err
  395. }()
  396. go func() {
  397. _, err := io.Copy(c, ccbuf)
  398. errc <- err
  399. }()
  400. <-errc
  401. }
  402. func (h *httpProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
  403. var d net.Dialer
  404. conn, err := d.DialContext(ctx, network, addr)
  405. if err != nil {
  406. return nil, err
  407. }
  408. h.Lock()
  409. defer h.Unlock()
  410. h.clientConnAddrs[conn.LocalAddr().String()] = true
  411. return conn, nil
  412. }
  413. func (h *httpProxy) ConnIsFromProxy(addr string) bool {
  414. h.Lock()
  415. defer h.Unlock()
  416. return h.clientConnAddrs[addr]
  417. }
  418. func tlsConfig(t *testing.T) *tls.Config {
  419. // Cert and key taken from the example code in the crypto/tls
  420. // package.
  421. certPem := []byte(`-----BEGIN CERTIFICATE-----
  422. MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
  423. DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
  424. EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
  425. 7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
  426. 5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
  427. BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
  428. NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
  429. Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
  430. 6MF9+Yw1Yy0t
  431. -----END CERTIFICATE-----`)
  432. keyPem := []byte(`-----BEGIN EC PRIVATE KEY-----
  433. MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
  434. AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
  435. EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
  436. -----END EC PRIVATE KEY-----`)
  437. cert, err := tls.X509KeyPair(certPem, keyPem)
  438. if err != nil {
  439. t.Fatal(err)
  440. }
  441. return &tls.Config{
  442. Certificates: []tls.Certificate{cert},
  443. }
  444. }
  445. func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
  446. return func(w http.ResponseWriter, r *http.Request) {
  447. w.Header().Set("Upgrade", upgradeHeaderValue)
  448. w.Header().Set("Connection", "upgrade")
  449. w.WriteHeader(http.StatusSwitchingProtocols)
  450. w.(http.Flusher).Flush()
  451. // Advance the clock to trigger HTTPs fallback.
  452. clock.Now()
  453. <-r.Context().Done()
  454. }
  455. }
  456. func TestDialPlan(t *testing.T) {
  457. if runtime.GOOS != "linux" {
  458. t.Skip("only works on Linux due to multiple localhost addresses")
  459. }
  460. client, server := key.NewMachine(), key.NewMachine()
  461. const (
  462. testProtocolVersion = 1
  463. )
  464. getRandomPort := func() string {
  465. ln, err := net.Listen("tcp", ":0")
  466. if err != nil {
  467. t.Fatalf("net.Listen: %v", err)
  468. }
  469. defer ln.Close()
  470. _, port, err := net.SplitHostPort(ln.Addr().String())
  471. if err != nil {
  472. t.Fatal(err)
  473. }
  474. return port
  475. }
  476. // We need consistent ports for each address; these are chosen
  477. // randomly and we hope that they won't conflict during this test.
  478. httpPort := getRandomPort()
  479. httpsPort := getRandomPort()
  480. makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
  481. done := make(chan struct{})
  482. t.Cleanup(func() {
  483. close(done)
  484. })
  485. var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  486. conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
  487. if err != nil {
  488. log.Print(err)
  489. } else {
  490. defer conn.Close()
  491. }
  492. w.Header().Set("X-Handler-Name", name)
  493. <-done
  494. })
  495. if wrap != nil {
  496. handler = wrap(handler)
  497. }
  498. httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
  499. if err != nil {
  500. t.Fatalf("HTTP listen: %v", err)
  501. }
  502. httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
  503. if err != nil {
  504. t.Fatalf("HTTPS listen: %v", err)
  505. }
  506. httpServer := &http.Server{Handler: handler}
  507. go httpServer.Serve(httpLn)
  508. t.Cleanup(func() {
  509. httpServer.Close()
  510. })
  511. httpsServer := &http.Server{
  512. Handler: handler,
  513. TLSConfig: tlsConfig(t),
  514. ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
  515. }
  516. go httpsServer.ServeTLS(httpsLn, "", "")
  517. t.Cleanup(func() {
  518. httpsServer.Close()
  519. })
  520. return
  521. }
  522. fallbackAddr := netip.MustParseAddr("127.0.0.1")
  523. goodAddr := netip.MustParseAddr("127.0.0.2")
  524. otherAddr := netip.MustParseAddr("127.0.0.3")
  525. other2Addr := netip.MustParseAddr("127.0.0.4")
  526. brokenAddr := netip.MustParseAddr("127.0.0.10")
  527. testCases := []struct {
  528. name string
  529. plan *tailcfg.ControlDialPlan
  530. wrap func(http.Handler) http.Handler
  531. want netip.Addr
  532. allowFallback bool
  533. }{
  534. {
  535. name: "single",
  536. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  537. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  538. }},
  539. want: goodAddr,
  540. },
  541. {
  542. name: "broken-then-good",
  543. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  544. // Dials the broken one, which fails, and then
  545. // eventually dials the good one and succeeds
  546. {IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
  547. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
  548. }},
  549. want: goodAddr,
  550. },
  551. // TODO(#8442): fix this test
  552. // {
  553. // name: "multiple-priority-fast-path",
  554. // plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  555. // // Dials some good IPs and our bad one (which
  556. // // hangs forever), which then hits the fast
  557. // // path where we bail without waiting.
  558. // {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
  559. // {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  560. // {IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
  561. // {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
  562. // }},
  563. // want: otherAddr,
  564. // },
  565. {
  566. name: "multiple-priority-slow-path",
  567. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  568. // Our broken address is the highest priority,
  569. // so we don't hit our fast path.
  570. {IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
  571. {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
  572. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  573. }},
  574. want: otherAddr,
  575. },
  576. {
  577. name: "fallback",
  578. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  579. {IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
  580. }},
  581. want: fallbackAddr,
  582. allowFallback: true,
  583. },
  584. }
  585. for _, tt := range testCases {
  586. t.Run(tt.name, func(t *testing.T) {
  587. // TODO(awly): replace this with tstest.NewClock and update the
  588. // test to advance the clock correctly.
  589. clock := tstime.StdClock{}
  590. makeHandler(t, "fallback", fallbackAddr, nil)
  591. makeHandler(t, "good", goodAddr, nil)
  592. makeHandler(t, "other", otherAddr, nil)
  593. makeHandler(t, "other2", other2Addr, nil)
  594. makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
  595. return brokenMITMHandler(clock)
  596. })
  597. dialer := closeTrackDialer{
  598. t: t,
  599. inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial,
  600. conns: make(map[*closeTrackConn]bool),
  601. }
  602. defer dialer.Done()
  603. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  604. defer cancel()
  605. // By default, we intentionally point to something that
  606. // we know won't connect, since we want a fallback to
  607. // DNS to be an error.
  608. host := "example.com"
  609. if tt.allowFallback {
  610. host = "localhost"
  611. }
  612. drained := make(chan struct{})
  613. a := &Dialer{
  614. Hostname: host,
  615. HTTPPort: httpPort,
  616. HTTPSPort: httpsPort,
  617. MachineKey: client,
  618. ControlKey: server.Public(),
  619. ProtocolVersion: testProtocolVersion,
  620. Dialer: dialer.Dial,
  621. Logf: t.Logf,
  622. DialPlan: tt.plan,
  623. proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
  624. drainFinished: drained,
  625. omitCertErrorLogging: true,
  626. testFallbackDelay: 50 * time.Millisecond,
  627. Clock: clock,
  628. }
  629. conn, err := a.dial(ctx)
  630. if err != nil {
  631. t.Fatalf("dialing controlhttp: %v", err)
  632. }
  633. defer conn.Close()
  634. raddr := conn.RemoteAddr().(*net.TCPAddr)
  635. got, ok := netip.AddrFromSlice(raddr.IP)
  636. if !ok {
  637. t.Errorf("invalid remote IP: %v", raddr.IP)
  638. } else if got != tt.want {
  639. t.Errorf("got connection from %q; want %q", got, tt.want)
  640. } else {
  641. t.Logf("successfully connected to %q", raddr.String())
  642. }
  643. // Wait until our dialer drains so we can verify that
  644. // all connections are closed.
  645. <-drained
  646. })
  647. }
  648. }
  649. type closeTrackDialer struct {
  650. t testing.TB
  651. inner dnscache.DialContextFunc
  652. mu sync.Mutex
  653. conns map[*closeTrackConn]bool
  654. }
  655. func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
  656. c, err := d.inner(ctx, network, addr)
  657. if err != nil {
  658. return nil, err
  659. }
  660. ct := &closeTrackConn{Conn: c, d: d}
  661. d.mu.Lock()
  662. d.conns[ct] = true
  663. d.mu.Unlock()
  664. return ct, nil
  665. }
  666. func (d *closeTrackDialer) Done() {
  667. // Unfortunately, tsdial.Dialer.SystemDial closes connections
  668. // asynchronously in a goroutine, so we can't assume that everything is
  669. // closed by the time we get here.
  670. //
  671. // Sleep/wait a few times on the assumption that things will close
  672. // "eventually".
  673. const iters = 100
  674. for i := range iters {
  675. d.mu.Lock()
  676. if len(d.conns) == 0 {
  677. d.mu.Unlock()
  678. return
  679. }
  680. // Only error on last iteration
  681. if i != iters-1 {
  682. d.mu.Unlock()
  683. time.Sleep(100 * time.Millisecond)
  684. continue
  685. }
  686. for conn := range d.conns {
  687. d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
  688. }
  689. d.mu.Unlock()
  690. }
  691. }
  692. func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
  693. d.mu.Lock()
  694. delete(d.conns, c) // safe if already deleted
  695. d.mu.Unlock()
  696. }
  697. type closeTrackConn struct {
  698. net.Conn
  699. d *closeTrackDialer
  700. }
  701. func (c *closeTrackConn) Close() error {
  702. c.d.noteClose(c)
  703. return c.Conn.Close()
  704. }