http_test.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlhttp
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httputil"
  13. "net/netip"
  14. "net/url"
  15. "runtime"
  16. "strconv"
  17. "sync"
  18. "testing"
  19. "time"
  20. "tailscale.com/control/controlbase"
  21. "tailscale.com/net/dnscache"
  22. "tailscale.com/net/socks5"
  23. "tailscale.com/net/tsdial"
  24. "tailscale.com/tailcfg"
  25. "tailscale.com/tstest"
  26. "tailscale.com/tstime"
  27. "tailscale.com/types/key"
  28. "tailscale.com/types/logger"
  29. )
// httpTestParam describes one TestControlHTTP scenario.
type httpTestParam struct {
	// name is the subtest name passed to t.Run.
	name string
	// proxy, if non-nil, is started by testControlHTTP and configured as the
	// dialer's proxy; the test then checks the server saw the proxy's address.
	proxy proxy

	// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
	// 101 switching protocols.
	makeHTTPHangAfterUpgrade bool
	// doEarlyWrite makes the server handler pass an early-write callback to
	// AcceptHTTP; the client then reads and checks that message.
	doEarlyWrite bool
}
  38. func TestControlHTTP(t *testing.T) {
  39. tests := []httpTestParam{
  40. // direct connection
  41. {
  42. name: "no_proxy",
  43. proxy: nil,
  44. },
  45. // direct connection but port 80 is MITM'ed and broken
  46. {
  47. name: "port80_broken_mitm",
  48. proxy: nil,
  49. makeHTTPHangAfterUpgrade: true,
  50. },
  51. // SOCKS5
  52. {
  53. name: "socks5",
  54. proxy: &socksProxy{},
  55. },
  56. // HTTP->HTTP
  57. {
  58. name: "http_to_http",
  59. proxy: &httpProxy{
  60. useTLS: false,
  61. allowConnect: false,
  62. allowHTTP: true,
  63. },
  64. },
  65. // HTTP->HTTPS
  66. {
  67. name: "http_to_https",
  68. proxy: &httpProxy{
  69. useTLS: false,
  70. allowConnect: true,
  71. allowHTTP: false,
  72. },
  73. },
  74. // HTTP->any (will pick HTTP)
  75. {
  76. name: "http_to_any",
  77. proxy: &httpProxy{
  78. useTLS: false,
  79. allowConnect: true,
  80. allowHTTP: true,
  81. },
  82. },
  83. // HTTPS->HTTP
  84. {
  85. name: "https_to_http",
  86. proxy: &httpProxy{
  87. useTLS: true,
  88. allowConnect: false,
  89. allowHTTP: true,
  90. },
  91. },
  92. // HTTPS->HTTPS
  93. {
  94. name: "https_to_https",
  95. proxy: &httpProxy{
  96. useTLS: true,
  97. allowConnect: true,
  98. allowHTTP: false,
  99. },
  100. },
  101. // HTTPS->any (will pick HTTP)
  102. {
  103. name: "https_to_any",
  104. proxy: &httpProxy{
  105. useTLS: true,
  106. allowConnect: true,
  107. allowHTTP: true,
  108. },
  109. },
  110. // Early write
  111. {
  112. name: "early_write",
  113. doEarlyWrite: true,
  114. },
  115. }
  116. for _, test := range tests {
  117. t.Run(test.name, func(t *testing.T) {
  118. testControlHTTP(t, test)
  119. })
  120. }
  121. }
  122. func testControlHTTP(t *testing.T, param httpTestParam) {
  123. proxy := param.proxy
  124. client, server := key.NewMachine(), key.NewMachine()
  125. const testProtocolVersion = 1
  126. const earlyWriteMsg = "Hello, world!"
  127. sch := make(chan serverResult, 1)
  128. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  129. var earlyWriteFn func(protocolVersion int, w io.Writer) error
  130. if param.doEarlyWrite {
  131. earlyWriteFn = func(protocolVersion int, w io.Writer) error {
  132. if protocolVersion != testProtocolVersion {
  133. t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
  134. return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
  135. }
  136. _, err := io.WriteString(w, earlyWriteMsg)
  137. return err
  138. }
  139. }
  140. conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
  141. if err != nil {
  142. log.Print(err)
  143. }
  144. res := serverResult{
  145. err: err,
  146. }
  147. if conn != nil {
  148. res.clientAddr = conn.RemoteAddr().String()
  149. res.version = conn.ProtocolVersion()
  150. res.peer = conn.Peer()
  151. res.conn = conn
  152. }
  153. sch <- res
  154. })
  155. httpLn, err := net.Listen("tcp", "127.0.0.1:0")
  156. if err != nil {
  157. t.Fatalf("HTTP listen: %v", err)
  158. }
  159. httpsLn, err := net.Listen("tcp", "127.0.0.1:0")
  160. if err != nil {
  161. t.Fatalf("HTTPS listen: %v", err)
  162. }
  163. var httpHandler http.Handler = handler
  164. const fallbackDelay = 50 * time.Millisecond
  165. clock := tstest.NewClock(tstest.ClockOpts{Step: 2 * fallbackDelay})
  166. // Advance once to init the clock.
  167. clock.Now()
  168. if param.makeHTTPHangAfterUpgrade {
  169. httpHandler = brokenMITMHandler(clock)
  170. }
  171. httpServer := &http.Server{Handler: httpHandler}
  172. go httpServer.Serve(httpLn)
  173. defer httpServer.Close()
  174. httpsServer := &http.Server{
  175. Handler: handler,
  176. TLSConfig: tlsConfig(t),
  177. }
  178. go httpsServer.ServeTLS(httpsLn, "", "")
  179. defer httpsServer.Close()
  180. ctx := context.Background()
  181. const debugTimeout = false
  182. if debugTimeout {
  183. var cancel context.CancelFunc
  184. ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
  185. defer cancel()
  186. }
  187. a := &Dialer{
  188. Hostname: "localhost",
  189. HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
  190. HTTPSPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
  191. MachineKey: client,
  192. ControlKey: server.Public(),
  193. ProtocolVersion: testProtocolVersion,
  194. Dialer: new(tsdial.Dialer).SystemDial,
  195. Logf: t.Logf,
  196. omitCertErrorLogging: true,
  197. testFallbackDelay: fallbackDelay,
  198. Clock: clock,
  199. }
  200. if proxy != nil {
  201. proxyEnv := proxy.Start(t)
  202. defer proxy.Close()
  203. proxyURL, err := url.Parse(proxyEnv)
  204. if err != nil {
  205. t.Fatal(err)
  206. }
  207. a.proxyFunc = func(*http.Request) (*url.URL, error) {
  208. return proxyURL, nil
  209. }
  210. } else {
  211. a.proxyFunc = func(*http.Request) (*url.URL, error) {
  212. return nil, nil
  213. }
  214. }
  215. conn, err := a.dial(ctx)
  216. if err != nil {
  217. t.Fatalf("dialing controlhttp: %v", err)
  218. }
  219. defer conn.Close()
  220. si := <-sch
  221. if si.conn != nil {
  222. defer si.conn.Close()
  223. }
  224. if si.err != nil {
  225. t.Fatalf("controlhttp server got error: %v", err)
  226. }
  227. if clientVersion := conn.ProtocolVersion(); si.version != clientVersion {
  228. t.Fatalf("client and server don't agree on protocol version: %d vs %d", clientVersion, si.version)
  229. }
  230. if si.peer != client.Public() {
  231. t.Fatalf("server got peer pubkey %s, want %s", si.peer, client.Public())
  232. }
  233. if spub := conn.Peer(); spub != server.Public() {
  234. t.Fatalf("client got peer pubkey %s, want %s", spub, server.Public())
  235. }
  236. if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
  237. t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
  238. }
  239. if param.doEarlyWrite {
  240. buf := make([]byte, len(earlyWriteMsg))
  241. if _, err := io.ReadFull(conn, buf); err != nil {
  242. t.Fatalf("reading early write: %v", err)
  243. }
  244. if string(buf) != earlyWriteMsg {
  245. t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
  246. }
  247. }
  248. }
// serverResult is what the testControlHTTP handler sends back to the test
// after AcceptHTTP returns.
type serverResult struct {
	err        error             // error returned by AcceptHTTP, if any
	clientAddr string            // remote address of the accepted conn
	version    int               // protocol version reported by the accepted conn
	peer       key.MachinePublic // peer (client) key reported by the accepted conn
	conn       *controlbase.Conn // accepted conn, if any; the test closes it
}
  256. type proxy interface {
  257. Start(*testing.T) string
  258. Close()
  259. ConnIsFromProxy(string) bool
  260. }
// socksProxy is a SOCKS5 proxy for tests. It records the local address of
// every outgoing connection it makes so the test can tell whether the server
// was reached through it.
type socksProxy struct {
	sync.Mutex                      // guards the fields below after Start
	closed          bool            // set by Close; suppresses further logging
	proxy           socks5.Server
	ln              net.Listener
	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
}
  268. func (s *socksProxy) Start(t *testing.T) (url string) {
  269. t.Helper()
  270. s.Lock()
  271. defer s.Unlock()
  272. ln, err := net.Listen("tcp", "127.0.0.1:0")
  273. if err != nil {
  274. t.Fatalf("listening for SOCKS server: %v", err)
  275. }
  276. s.ln = ln
  277. s.clientConnAddrs = map[string]bool{}
  278. s.proxy.Logf = func(format string, a ...any) {
  279. s.Lock()
  280. defer s.Unlock()
  281. if s.closed {
  282. return
  283. }
  284. t.Logf(format, a...)
  285. }
  286. s.proxy.Dialer = s.dialAndRecord
  287. go s.proxy.Serve(ln)
  288. return fmt.Sprintf("socks5://%s", ln.Addr().String())
  289. }
  290. func (s *socksProxy) Close() {
  291. s.Lock()
  292. defer s.Unlock()
  293. if s.closed {
  294. return
  295. }
  296. s.closed = true
  297. s.ln.Close()
  298. }
  299. func (s *socksProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
  300. var d net.Dialer
  301. conn, err := d.DialContext(ctx, network, addr)
  302. if err != nil {
  303. return nil, err
  304. }
  305. s.Lock()
  306. defer s.Unlock()
  307. s.clientConnAddrs[conn.LocalAddr().String()] = true
  308. return conn, nil
  309. }
  310. func (s *socksProxy) ConnIsFromProxy(addr string) bool {
  311. s.Lock()
  312. defer s.Unlock()
  313. return s.clientConnAddrs[addr]
  314. }
  315. type httpProxy struct {
  316. useTLS bool // take incoming connections over TLS
  317. allowConnect bool // allow CONNECT for TLS
  318. allowHTTP bool // allow plain HTTP proxying
  319. sync.Mutex
  320. ln net.Listener
  321. rp httputil.ReverseProxy
  322. s http.Server
  323. clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
  324. }
  325. func (h *httpProxy) Start(t *testing.T) (url string) {
  326. t.Helper()
  327. h.Lock()
  328. defer h.Unlock()
  329. ln, err := net.Listen("tcp", "127.0.0.1:0")
  330. if err != nil {
  331. t.Fatalf("listening for HTTP proxy: %v", err)
  332. }
  333. h.ln = ln
  334. h.rp = httputil.ReverseProxy{
  335. Director: func(*http.Request) {},
  336. Transport: &http.Transport{
  337. DialContext: h.dialAndRecord,
  338. TLSClientConfig: &tls.Config{
  339. InsecureSkipVerify: true,
  340. },
  341. TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
  342. },
  343. }
  344. h.clientConnAddrs = map[string]bool{}
  345. h.s.Handler = h
  346. if h.useTLS {
  347. h.s.TLSConfig = tlsConfig(t)
  348. go h.s.ServeTLS(h.ln, "", "")
  349. return fmt.Sprintf("https://%s", ln.Addr().String())
  350. } else {
  351. go h.s.Serve(h.ln)
  352. return fmt.Sprintf("http://%s", ln.Addr().String())
  353. }
  354. }
  355. func (h *httpProxy) Close() {
  356. h.Lock()
  357. defer h.Unlock()
  358. h.s.Close()
  359. }
  360. func (h *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  361. if r.Method != "CONNECT" {
  362. if !h.allowHTTP {
  363. http.Error(w, "http proxy not allowed", 500)
  364. return
  365. }
  366. h.rp.ServeHTTP(w, r)
  367. return
  368. }
  369. if !h.allowConnect {
  370. http.Error(w, "connect not allowed", 500)
  371. return
  372. }
  373. dst := r.RequestURI
  374. c, err := h.dialAndRecord(context.Background(), "tcp", dst)
  375. if err != nil {
  376. http.Error(w, err.Error(), 500)
  377. return
  378. }
  379. defer c.Close()
  380. cc, ccbuf, err := w.(http.Hijacker).Hijack()
  381. if err != nil {
  382. http.Error(w, err.Error(), 500)
  383. return
  384. }
  385. defer cc.Close()
  386. io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n")
  387. errc := make(chan error, 1)
  388. go func() {
  389. _, err := io.Copy(cc, c)
  390. errc <- err
  391. }()
  392. go func() {
  393. _, err := io.Copy(c, ccbuf)
  394. errc <- err
  395. }()
  396. <-errc
  397. }
  398. func (h *httpProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
  399. var d net.Dialer
  400. conn, err := d.DialContext(ctx, network, addr)
  401. if err != nil {
  402. return nil, err
  403. }
  404. h.Lock()
  405. defer h.Unlock()
  406. h.clientConnAddrs[conn.LocalAddr().String()] = true
  407. return conn, nil
  408. }
  409. func (h *httpProxy) ConnIsFromProxy(addr string) bool {
  410. h.Lock()
  411. defer h.Unlock()
  412. return h.clientConnAddrs[addr]
  413. }
  414. func tlsConfig(t *testing.T) *tls.Config {
  415. // Cert and key taken from the example code in the crypto/tls
  416. // package.
  417. certPem := []byte(`-----BEGIN CERTIFICATE-----
  418. MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
  419. DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
  420. EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
  421. 7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
  422. 5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
  423. BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
  424. NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
  425. Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
  426. 6MF9+Yw1Yy0t
  427. -----END CERTIFICATE-----`)
  428. keyPem := []byte(`-----BEGIN EC PRIVATE KEY-----
  429. MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
  430. AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
  431. EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
  432. -----END EC PRIVATE KEY-----`)
  433. cert, err := tls.X509KeyPair(certPem, keyPem)
  434. if err != nil {
  435. t.Fatal(err)
  436. }
  437. return &tls.Config{
  438. Certificates: []tls.Certificate{cert},
  439. }
  440. }
  441. func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
  442. return func(w http.ResponseWriter, r *http.Request) {
  443. w.Header().Set("Upgrade", upgradeHeaderValue)
  444. w.Header().Set("Connection", "upgrade")
  445. w.WriteHeader(http.StatusSwitchingProtocols)
  446. w.(http.Flusher).Flush()
  447. // Advance the clock to trigger HTTPs fallback.
  448. clock.Now()
  449. <-r.Context().Done()
  450. }
  451. }
  452. func TestDialPlan(t *testing.T) {
  453. if runtime.GOOS != "linux" {
  454. t.Skip("only works on Linux due to multiple localhost addresses")
  455. }
  456. client, server := key.NewMachine(), key.NewMachine()
  457. const (
  458. testProtocolVersion = 1
  459. )
  460. getRandomPort := func() string {
  461. ln, err := net.Listen("tcp", ":0")
  462. if err != nil {
  463. t.Fatalf("net.Listen: %v", err)
  464. }
  465. defer ln.Close()
  466. _, port, err := net.SplitHostPort(ln.Addr().String())
  467. if err != nil {
  468. t.Fatal(err)
  469. }
  470. return port
  471. }
  472. // We need consistent ports for each address; these are chosen
  473. // randomly and we hope that they won't conflict during this test.
  474. httpPort := getRandomPort()
  475. httpsPort := getRandomPort()
  476. makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
  477. done := make(chan struct{})
  478. t.Cleanup(func() {
  479. close(done)
  480. })
  481. var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  482. conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
  483. if err != nil {
  484. log.Print(err)
  485. } else {
  486. defer conn.Close()
  487. }
  488. w.Header().Set("X-Handler-Name", name)
  489. <-done
  490. })
  491. if wrap != nil {
  492. handler = wrap(handler)
  493. }
  494. httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
  495. if err != nil {
  496. t.Fatalf("HTTP listen: %v", err)
  497. }
  498. httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
  499. if err != nil {
  500. t.Fatalf("HTTPS listen: %v", err)
  501. }
  502. httpServer := &http.Server{Handler: handler}
  503. go httpServer.Serve(httpLn)
  504. t.Cleanup(func() {
  505. httpServer.Close()
  506. })
  507. httpsServer := &http.Server{
  508. Handler: handler,
  509. TLSConfig: tlsConfig(t),
  510. ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
  511. }
  512. go httpsServer.ServeTLS(httpsLn, "", "")
  513. t.Cleanup(func() {
  514. httpsServer.Close()
  515. })
  516. return
  517. }
  518. fallbackAddr := netip.MustParseAddr("127.0.0.1")
  519. goodAddr := netip.MustParseAddr("127.0.0.2")
  520. otherAddr := netip.MustParseAddr("127.0.0.3")
  521. other2Addr := netip.MustParseAddr("127.0.0.4")
  522. brokenAddr := netip.MustParseAddr("127.0.0.10")
  523. testCases := []struct {
  524. name string
  525. plan *tailcfg.ControlDialPlan
  526. wrap func(http.Handler) http.Handler
  527. want netip.Addr
  528. allowFallback bool
  529. }{
  530. {
  531. name: "single",
  532. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  533. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  534. }},
  535. want: goodAddr,
  536. },
  537. {
  538. name: "broken-then-good",
  539. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  540. // Dials the broken one, which fails, and then
  541. // eventually dials the good one and succeeds
  542. {IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
  543. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
  544. }},
  545. want: goodAddr,
  546. },
  547. // TODO(#8442): fix this test
  548. // {
  549. // name: "multiple-priority-fast-path",
  550. // plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  551. // // Dials some good IPs and our bad one (which
  552. // // hangs forever), which then hits the fast
  553. // // path where we bail without waiting.
  554. // {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
  555. // {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  556. // {IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
  557. // {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
  558. // }},
  559. // want: otherAddr,
  560. // },
  561. {
  562. name: "multiple-priority-slow-path",
  563. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  564. // Our broken address is the highest priority,
  565. // so we don't hit our fast path.
  566. {IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
  567. {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
  568. {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
  569. }},
  570. want: otherAddr,
  571. },
  572. {
  573. name: "fallback",
  574. plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
  575. {IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
  576. }},
  577. want: fallbackAddr,
  578. allowFallback: true,
  579. },
  580. }
  581. for _, tt := range testCases {
  582. t.Run(tt.name, func(t *testing.T) {
  583. // TODO(awly): replace this with tstest.NewClock and update the
  584. // test to advance the clock correctly.
  585. clock := tstime.StdClock{}
  586. makeHandler(t, "fallback", fallbackAddr, nil)
  587. makeHandler(t, "good", goodAddr, nil)
  588. makeHandler(t, "other", otherAddr, nil)
  589. makeHandler(t, "other2", other2Addr, nil)
  590. makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
  591. return brokenMITMHandler(clock)
  592. })
  593. dialer := closeTrackDialer{
  594. t: t,
  595. inner: new(tsdial.Dialer).SystemDial,
  596. conns: make(map[*closeTrackConn]bool),
  597. }
  598. defer dialer.Done()
  599. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  600. defer cancel()
  601. // By default, we intentionally point to something that
  602. // we know won't connect, since we want a fallback to
  603. // DNS to be an error.
  604. host := "example.com"
  605. if tt.allowFallback {
  606. host = "localhost"
  607. }
  608. drained := make(chan struct{})
  609. a := &Dialer{
  610. Hostname: host,
  611. HTTPPort: httpPort,
  612. HTTPSPort: httpsPort,
  613. MachineKey: client,
  614. ControlKey: server.Public(),
  615. ProtocolVersion: testProtocolVersion,
  616. Dialer: dialer.Dial,
  617. Logf: t.Logf,
  618. DialPlan: tt.plan,
  619. proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
  620. drainFinished: drained,
  621. omitCertErrorLogging: true,
  622. testFallbackDelay: 50 * time.Millisecond,
  623. Clock: clock,
  624. }
  625. conn, err := a.dial(ctx)
  626. if err != nil {
  627. t.Fatalf("dialing controlhttp: %v", err)
  628. }
  629. defer conn.Close()
  630. raddr := conn.RemoteAddr().(*net.TCPAddr)
  631. got, ok := netip.AddrFromSlice(raddr.IP)
  632. if !ok {
  633. t.Errorf("invalid remote IP: %v", raddr.IP)
  634. } else if got != tt.want {
  635. t.Errorf("got connection from %q; want %q", got, tt.want)
  636. } else {
  637. t.Logf("successfully connected to %q", raddr.String())
  638. }
  639. // Wait until our dialer drains so we can verify that
  640. // all connections are closed.
  641. <-drained
  642. })
  643. }
  644. }
// closeTrackDialer wraps a dial function and remembers every connection it
// hands out until that connection is closed. Done then reports any that were
// never closed.
type closeTrackDialer struct {
	t     testing.TB               // where Done reports unclosed conns
	inner dnscache.DialContextFunc // underlying dial function
	mu    sync.Mutex               // guards conns
	conns map[*closeTrackConn]bool // conns returned by Dial and not yet closed
}
  651. func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
  652. c, err := d.inner(ctx, network, addr)
  653. if err != nil {
  654. return nil, err
  655. }
  656. ct := &closeTrackConn{Conn: c, d: d}
  657. d.mu.Lock()
  658. d.conns[ct] = true
  659. d.mu.Unlock()
  660. return ct, nil
  661. }
  662. func (d *closeTrackDialer) Done() {
  663. // Unfortunately, tsdial.Dialer.SystemDial closes connections
  664. // asynchronously in a goroutine, so we can't assume that everything is
  665. // closed by the time we get here.
  666. //
  667. // Sleep/wait a few times on the assumption that things will close
  668. // "eventually".
  669. const iters = 100
  670. for i := 0; i < iters; i++ {
  671. d.mu.Lock()
  672. if len(d.conns) == 0 {
  673. d.mu.Unlock()
  674. return
  675. }
  676. // Only error on last iteration
  677. if i != iters-1 {
  678. d.mu.Unlock()
  679. time.Sleep(100 * time.Millisecond)
  680. continue
  681. }
  682. for conn := range d.conns {
  683. d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
  684. }
  685. d.mu.Unlock()
  686. }
  687. }
  688. func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
  689. d.mu.Lock()
  690. delete(d.conns, c) // safe if already deleted
  691. d.mu.Unlock()
  692. }
// closeTrackConn is a net.Conn that tells its closeTrackDialer when Close is
// called.
type closeTrackConn struct {
	net.Conn
	d *closeTrackDialer // dialer that created this conn and tracks it
}
  697. func (c *closeTrackConn) Close() error {
  698. c.d.noteClose(c)
  699. return c.Conn.Close()
  700. }