controlclient_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. // Copyright (c) Tailscale Inc & contributors
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "net/url"
  15. "reflect"
  16. "sync/atomic"
  17. "testing"
  18. "time"
  19. "tailscale.com/control/controlknobs"
  20. "tailscale.com/health"
  21. "tailscale.com/net/bakedroots"
  22. "tailscale.com/net/connectproxy"
  23. "tailscale.com/net/netmon"
  24. "tailscale.com/net/tsdial"
  25. "tailscale.com/tailcfg"
  26. "tailscale.com/tstest"
  27. "tailscale.com/tstest/integration/testcontrol"
  28. "tailscale.com/tstest/tlstest"
  29. "tailscale.com/tstime"
  30. "tailscale.com/types/key"
  31. "tailscale.com/types/logger"
  32. "tailscale.com/types/netmap"
  33. "tailscale.com/types/persist"
  34. "tailscale.com/util/eventbus/eventbustest"
  35. )
  // fieldsOf returns the names of the fields of struct type t in declaration
  // order, leaving out blank ("_") fields. The result is nil when no field
  // names are collected.
  func fieldsOf(t reflect.Type) (fields []string) {
  	count := t.NumField()
  	for idx := 0; idx < count; idx++ {
  		fieldName := t.Field(idx).Name
  		if fieldName == "_" {
  			// Blank fields carry no name worth reporting.
  			continue
  		}
  		fields = append(fields, fieldName)
  	}
  	return fields
  }
  44. func TestStatusEqual(t *testing.T) {
  45. // Verify that the Equal method stays in sync with reality
  46. equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
  47. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
  48. t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
  49. have, equalHandles)
  50. }
  51. tests := []struct {
  52. a, b *Status
  53. want bool
  54. }{
  55. {
  56. &Status{},
  57. nil,
  58. false,
  59. },
  60. {
  61. nil,
  62. &Status{},
  63. false,
  64. },
  65. {
  66. nil,
  67. nil,
  68. true,
  69. },
  70. {
  71. &Status{},
  72. &Status{},
  73. true,
  74. },
  75. {
  76. &Status{},
  77. &Status{LoggedIn: true, Persist: new(persist.Persist).View()},
  78. false,
  79. },
  80. }
  81. for i, tt := range tests {
  82. got := tt.a.Equal(tt.b)
  83. if got != tt.want {
  84. t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
  85. }
  86. }
  87. }
  88. // tests [canSkipStatus].
  89. func TestCanSkipStatus(t *testing.T) {
  90. st := new(Status)
  91. nm1 := &netmap.NetworkMap{}
  92. nm2 := &netmap.NetworkMap{}
  93. commonPersist := new(persist.Persist).View()
  94. tests := []struct {
  95. name string
  96. s1, s2 *Status
  97. want bool
  98. }{
  99. {
  100. name: "nil-s2",
  101. s1: st,
  102. s2: nil,
  103. want: false,
  104. },
  105. {
  106. name: "equal",
  107. s1: st,
  108. s2: st,
  109. want: false,
  110. },
  111. {
  112. name: "s1-error",
  113. s1: &Status{Err: io.EOF, NetMap: nm1},
  114. s2: &Status{NetMap: nm2},
  115. want: false,
  116. },
  117. {
  118. name: "s1-url",
  119. s1: &Status{URL: "foo", NetMap: nm1},
  120. s2: &Status{NetMap: nm2},
  121. want: false,
  122. },
  123. {
  124. name: "s1-persist-diff",
  125. s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1},
  126. s2: &Status{NetMap: nm2},
  127. want: false,
  128. },
  129. {
  130. name: "s1-login-finished-diff",
  131. s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  132. s2: &Status{NetMap: nm2},
  133. want: false,
  134. },
  135. {
  136. name: "s1-login-finished",
  137. s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  138. s2: &Status{NetMap: nm2},
  139. want: false,
  140. },
  141. {
  142. name: "s1-synced-diff",
  143. s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  144. s2: &Status{NetMap: nm2},
  145. want: false,
  146. },
  147. {
  148. name: "s1-no-netmap1",
  149. s1: &Status{NetMap: nil},
  150. s2: &Status{NetMap: nm2},
  151. want: false,
  152. },
  153. {
  154. name: "s1-no-netmap2",
  155. s1: &Status{NetMap: nm1},
  156. s2: &Status{NetMap: nil},
  157. want: false,
  158. },
  159. {
  160. name: "skip",
  161. s1: &Status{NetMap: nm1, LoggedIn: true, InMapPoll: true, Persist: commonPersist},
  162. s2: &Status{NetMap: nm2, LoggedIn: true, InMapPoll: true, Persist: commonPersist},
  163. want: true,
  164. },
  165. }
  166. for _, tt := range tests {
  167. t.Run(tt.name, func(t *testing.T) {
  168. if got := canSkipStatus(tt.s1, tt.s2); got != tt.want {
  169. t.Errorf("canSkipStatus = %v, want %v", got, tt.want)
  170. }
  171. })
  172. }
  173. coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
  174. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) {
  175. t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields)
  176. }
  177. }
  178. func TestRetryableErrors(t *testing.T) {
  179. errorTests := []struct {
  180. err error
  181. want bool
  182. }{
  183. {errNoNoiseClient, true},
  184. {errNoNodeKey, true},
  185. {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
  186. {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
  187. {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
  188. {errBadHTTPResponse(429, "too may requests"), true},
  189. {errBadHTTPResponse(500, "internal server error"), true},
  190. {errBadHTTPResponse(502, "bad gateway"), true},
  191. {errBadHTTPResponse(503, "service unavailable"), true},
  192. {errBadHTTPResponse(504, "gateway timeout"), true},
  193. {errBadHTTPResponse(1234, "random error"), false},
  194. }
  195. for _, tt := range errorTests {
  196. t.Run(tt.err.Error(), func(t *testing.T) {
  197. if isRetryableErrorForTest(tt.err) != tt.want {
  198. t.Fatalf("retriable: got %v, want %v", tt.err, tt.want)
  199. }
  200. })
  201. }
  202. }
  // retryableForTest is the method set an error must provide for
  // isRetryableErrorForTest to ask it whether it is retryable.
  type retryableForTest interface {
  	Retryable() bool
  }

  // isRetryableErrorForTest reports whether err, or an error in its chain,
  // implements retryableForTest and returns true from Retryable. Errors
  // that do not implement the interface (including nil) are not retryable.
  func isRetryableErrorForTest(err error) bool {
  	var target retryableForTest
  	if !errors.As(err, &target) {
  		return false
  	}
  	return target.Retryable()
  }
  213. var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
  214. func TestDirectProxyManual(t *testing.T) {
  215. if !*liveNetworkTest {
  216. t.Skip("skipping without --live-network-test")
  217. }
  218. bus := eventbustest.NewBus(t)
  219. dialer := &tsdial.Dialer{}
  220. dialer.SetNetMon(netmon.NewStatic())
  221. dialer.SetBus(bus)
  222. opts := Options{
  223. Persist: persist.Persist{},
  224. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  225. return key.NewMachine(), nil
  226. },
  227. ServerURL: "https://controlplane.tailscale.com",
  228. Clock: tstime.StdClock{},
  229. Hostinfo: &tailcfg.Hostinfo{
  230. BackendLogID: "test-backend-log-id",
  231. },
  232. DiscoPublicKey: key.NewDisco().Public(),
  233. Logf: t.Logf,
  234. HealthTracker: health.NewTracker(bus),
  235. PopBrowserURL: func(url string) {
  236. t.Logf("PopBrowserURL: %q", url)
  237. },
  238. Dialer: dialer,
  239. ControlKnobs: &controlknobs.Knobs{},
  240. Bus: bus,
  241. }
  242. d, err := NewDirect(opts)
  243. if err != nil {
  244. t.Fatalf("NewDirect: %v", err)
  245. }
  246. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  247. defer cancel()
  248. url, err := d.TryLogin(ctx, LoginEphemeral)
  249. if err != nil {
  250. t.Fatalf("TryLogin: %v", err)
  251. }
  252. t.Logf("URL: %q", url)
  253. }
  254. func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
  255. // TestTLSWithProxy verifies we can connect to the control plane via
  256. // an HTTPS proxy.
  257. func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
  258. func testHTTPS(t *testing.T, withProxy bool) {
  259. bakedroots.ResetForTest(t, tlstest.TestRootCA())
  260. bus := eventbustest.NewBus(t)
  261. controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
  262. if err != nil {
  263. t.Fatal(err)
  264. }
  265. defer controlLn.Close()
  266. proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. defer proxyLn.Close()
  271. const requiredAuthKey = "hunter2"
  272. const someUsername = "testuser"
  273. const somePassword = "testpass"
  274. testControl := &testcontrol.Server{
  275. Logf: tstest.WhileTestRunningLogger(t),
  276. RequireAuthKey: requiredAuthKey,
  277. }
  278. controlSrv := &http.Server{
  279. Handler: testControl,
  280. ErrorLog: logger.StdLogger(t.Logf),
  281. }
  282. go controlSrv.Serve(controlLn)
  283. const fakeControlIP = "1.2.3.4"
  284. const fakeProxyIP = "5.6.7.8"
  285. dialer := &tsdial.Dialer{}
  286. dialer.SetNetMon(netmon.NewStatic())
  287. dialer.SetBus(bus)
  288. dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
  289. host, _, err := net.SplitHostPort(addr)
  290. if err != nil {
  291. return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
  292. }
  293. var d net.Dialer
  294. if host == fakeControlIP {
  295. return d.DialContext(ctx, network, controlLn.Addr().String())
  296. }
  297. if host == fakeProxyIP {
  298. return d.DialContext(ctx, network, proxyLn.Addr().String())
  299. }
  300. return nil, fmt.Errorf("unexpected dial to %q", addr)
  301. })
  302. opts := Options{
  303. Persist: persist.Persist{},
  304. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  305. return key.NewMachine(), nil
  306. },
  307. AuthKey: requiredAuthKey,
  308. ServerURL: "https://controlplane.tstest",
  309. Clock: tstime.StdClock{},
  310. Hostinfo: &tailcfg.Hostinfo{
  311. BackendLogID: "test-backend-log-id",
  312. },
  313. DiscoPublicKey: key.NewDisco().Public(),
  314. Logf: t.Logf,
  315. HealthTracker: health.NewTracker(bus),
  316. PopBrowserURL: func(url string) {
  317. t.Logf("PopBrowserURL: %q", url)
  318. },
  319. Dialer: dialer,
  320. Bus: bus,
  321. }
  322. d, err := NewDirect(opts)
  323. if err != nil {
  324. t.Fatalf("NewDirect: %v", err)
  325. }
  326. d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
  327. switch host {
  328. case "controlplane.tstest":
  329. return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
  330. case "proxy.tstest":
  331. if !withProxy {
  332. t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
  333. return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
  334. }
  335. return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
  336. }
  337. t.Errorf("unexpected DNS query for %q", host)
  338. return []netip.Addr{}, nil
  339. }
  340. var proxyReqs atomic.Int64
  341. if withProxy {
  342. d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
  343. t.Logf("using proxy for %q", req.URL)
  344. u := &url.URL{
  345. Scheme: "https",
  346. Host: "proxy.tstest:443",
  347. User: url.UserPassword(someUsername, somePassword),
  348. }
  349. return u, nil
  350. }
  351. connectProxy := &http.Server{
  352. Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
  353. }
  354. go connectProxy.Serve(proxyLn)
  355. }
  356. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  357. defer cancel()
  358. url, err := d.TryLogin(ctx, LoginEphemeral)
  359. if err != nil {
  360. t.Fatalf("TryLogin: %v", err)
  361. }
  362. if url != "" {
  363. t.Errorf("got URL %q, want empty", url)
  364. }
  365. if withProxy {
  366. if got, want := proxyReqs.Load(), int64(1); got != want {
  367. t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
  368. }
  369. }
  370. }
  371. func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
  372. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  373. if r.RequestURI != target {
  374. t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
  375. http.Error(w, "bad target", http.StatusBadRequest)
  376. return
  377. }
  378. r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
  379. user, pass, ok := r.BasicAuth()
  380. if !ok || user != "testuser" || pass != "testpass" {
  381. t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
  382. http.Error(w, "bad auth", http.StatusUnauthorized)
  383. return
  384. }
  385. (&connectproxy.Handler{
  386. Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
  387. var d net.Dialer
  388. c, err := d.DialContext(ctx, network, backendAddrPort)
  389. if err == nil {
  390. reqs.Add(1)
  391. }
  392. return c, err
  393. },
  394. Logf: t.Logf,
  395. }).ServeHTTP(w, r)
  396. })
  397. }