controlclient_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "net/url"
  15. "reflect"
  16. "sync/atomic"
  17. "testing"
  18. "time"
  19. "tailscale.com/control/controlknobs"
  20. "tailscale.com/health"
  21. "tailscale.com/net/bakedroots"
  22. "tailscale.com/net/connectproxy"
  23. "tailscale.com/net/netmon"
  24. "tailscale.com/net/tsdial"
  25. "tailscale.com/tailcfg"
  26. "tailscale.com/tstest"
  27. "tailscale.com/tstest/integration/testcontrol"
  28. "tailscale.com/tstest/tlstest"
  29. "tailscale.com/tstime"
  30. "tailscale.com/types/key"
  31. "tailscale.com/types/logger"
  32. "tailscale.com/types/netmap"
  33. "tailscale.com/types/persist"
  34. "tailscale.com/util/eventbus/eventbustest"
  35. )
  36. func fieldsOf(t reflect.Type) (fields []string) {
  37. for i := range t.NumField() {
  38. if name := t.Field(i).Name; name != "_" {
  39. fields = append(fields, name)
  40. }
  41. }
  42. return
  43. }
  44. func TestStatusEqual(t *testing.T) {
  45. // Verify that the Equal method stays in sync with reality
  46. equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
  47. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
  48. t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
  49. have, equalHandles)
  50. }
  51. tests := []struct {
  52. a, b *Status
  53. want bool
  54. }{
  55. {
  56. &Status{},
  57. nil,
  58. false,
  59. },
  60. {
  61. nil,
  62. &Status{},
  63. false,
  64. },
  65. {
  66. nil,
  67. nil,
  68. true,
  69. },
  70. {
  71. &Status{},
  72. &Status{},
  73. true,
  74. },
  75. {
  76. &Status{},
  77. &Status{LoggedIn: true, Persist: new(persist.Persist).View()},
  78. false,
  79. },
  80. }
  81. for i, tt := range tests {
  82. got := tt.a.Equal(tt.b)
  83. if got != tt.want {
  84. t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
  85. }
  86. }
  87. }
  88. // tests [canSkipStatus].
  89. func TestCanSkipStatus(t *testing.T) {
  90. st := new(Status)
  91. nm1 := &netmap.NetworkMap{}
  92. nm2 := &netmap.NetworkMap{}
  93. tests := []struct {
  94. name string
  95. s1, s2 *Status
  96. want bool
  97. }{
  98. {
  99. name: "nil-s2",
  100. s1: st,
  101. s2: nil,
  102. want: false,
  103. },
  104. {
  105. name: "equal",
  106. s1: st,
  107. s2: st,
  108. want: false,
  109. },
  110. {
  111. name: "s1-error",
  112. s1: &Status{Err: io.EOF, NetMap: nm1},
  113. s2: &Status{NetMap: nm2},
  114. want: false,
  115. },
  116. {
  117. name: "s1-url",
  118. s1: &Status{URL: "foo", NetMap: nm1},
  119. s2: &Status{NetMap: nm2},
  120. want: false,
  121. },
  122. {
  123. name: "s1-persist-diff",
  124. s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1},
  125. s2: &Status{NetMap: nm2},
  126. want: false,
  127. },
  128. {
  129. name: "s1-login-finished-diff",
  130. s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  131. s2: &Status{NetMap: nm2},
  132. want: false,
  133. },
  134. {
  135. name: "s1-login-finished",
  136. s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  137. s2: &Status{NetMap: nm2},
  138. want: false,
  139. },
  140. {
  141. name: "s1-synced-diff",
  142. s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
  143. s2: &Status{NetMap: nm2},
  144. want: false,
  145. },
  146. {
  147. name: "s1-no-netmap1",
  148. s1: &Status{NetMap: nil},
  149. s2: &Status{NetMap: nm2},
  150. want: false,
  151. },
  152. {
  153. name: "s1-no-netmap2",
  154. s1: &Status{NetMap: nm1},
  155. s2: &Status{NetMap: nil},
  156. want: false,
  157. },
  158. {
  159. name: "skip",
  160. s1: &Status{NetMap: nm1},
  161. s2: &Status{NetMap: nm2},
  162. want: true,
  163. },
  164. }
  165. for _, tt := range tests {
  166. t.Run(tt.name, func(t *testing.T) {
  167. if got := canSkipStatus(tt.s1, tt.s2); got != tt.want {
  168. t.Errorf("canSkipStatus = %v, want %v", got, tt.want)
  169. }
  170. })
  171. }
  172. coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
  173. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) {
  174. t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields)
  175. }
  176. }
  177. func TestRetryableErrors(t *testing.T) {
  178. errorTests := []struct {
  179. err error
  180. want bool
  181. }{
  182. {errNoNoiseClient, true},
  183. {errNoNodeKey, true},
  184. {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
  185. {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
  186. {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
  187. {errBadHTTPResponse(429, "too may requests"), true},
  188. {errBadHTTPResponse(500, "internal server eror"), true},
  189. {errBadHTTPResponse(502, "bad gateway"), true},
  190. {errBadHTTPResponse(503, "service unavailable"), true},
  191. {errBadHTTPResponse(504, "gateway timeout"), true},
  192. {errBadHTTPResponse(1234, "random error"), false},
  193. }
  194. for _, tt := range errorTests {
  195. t.Run(tt.err.Error(), func(t *testing.T) {
  196. if isRetryableErrorForTest(tt.err) != tt.want {
  197. t.Fatalf("retriable: got %v, want %v", tt.err, tt.want)
  198. }
  199. })
  200. }
  201. }
  202. type retryableForTest interface {
  203. Retryable() bool
  204. }
  205. func isRetryableErrorForTest(err error) bool {
  206. var ae retryableForTest
  207. if errors.As(err, &ae) {
  208. return ae.Retryable()
  209. }
  210. return false
  211. }
  212. var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
  213. func TestDirectProxyManual(t *testing.T) {
  214. if !*liveNetworkTest {
  215. t.Skip("skipping without --live-network-test")
  216. }
  217. bus := eventbustest.NewBus(t)
  218. dialer := &tsdial.Dialer{}
  219. dialer.SetNetMon(netmon.NewStatic())
  220. dialer.SetBus(bus)
  221. opts := Options{
  222. Persist: persist.Persist{},
  223. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  224. return key.NewMachine(), nil
  225. },
  226. ServerURL: "https://controlplane.tailscale.com",
  227. Clock: tstime.StdClock{},
  228. Hostinfo: &tailcfg.Hostinfo{
  229. BackendLogID: "test-backend-log-id",
  230. },
  231. DiscoPublicKey: key.NewDisco().Public(),
  232. Logf: t.Logf,
  233. HealthTracker: health.NewTracker(bus),
  234. PopBrowserURL: func(url string) {
  235. t.Logf("PopBrowserURL: %q", url)
  236. },
  237. Dialer: dialer,
  238. ControlKnobs: &controlknobs.Knobs{},
  239. Bus: bus,
  240. }
  241. d, err := NewDirect(opts)
  242. if err != nil {
  243. t.Fatalf("NewDirect: %v", err)
  244. }
  245. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  246. defer cancel()
  247. url, err := d.TryLogin(ctx, LoginEphemeral)
  248. if err != nil {
  249. t.Fatalf("TryLogin: %v", err)
  250. }
  251. t.Logf("URL: %q", url)
  252. }
  253. func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
  254. // TestTLSWithProxy verifies we can connect to the control plane via
  255. // an HTTPS proxy.
  256. func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
  257. func testHTTPS(t *testing.T, withProxy bool) {
  258. bakedroots.ResetForTest(t, tlstest.TestRootCA())
  259. bus := eventbustest.NewBus(t)
  260. controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
  261. if err != nil {
  262. t.Fatal(err)
  263. }
  264. defer controlLn.Close()
  265. proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
  266. if err != nil {
  267. t.Fatal(err)
  268. }
  269. defer proxyLn.Close()
  270. const requiredAuthKey = "hunter2"
  271. const someUsername = "testuser"
  272. const somePassword = "testpass"
  273. testControl := &testcontrol.Server{
  274. Logf: tstest.WhileTestRunningLogger(t),
  275. RequireAuthKey: requiredAuthKey,
  276. }
  277. controlSrv := &http.Server{
  278. Handler: testControl,
  279. ErrorLog: logger.StdLogger(t.Logf),
  280. }
  281. go controlSrv.Serve(controlLn)
  282. const fakeControlIP = "1.2.3.4"
  283. const fakeProxyIP = "5.6.7.8"
  284. dialer := &tsdial.Dialer{}
  285. dialer.SetNetMon(netmon.NewStatic())
  286. dialer.SetBus(bus)
  287. dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
  288. host, _, err := net.SplitHostPort(addr)
  289. if err != nil {
  290. return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
  291. }
  292. var d net.Dialer
  293. if host == fakeControlIP {
  294. return d.DialContext(ctx, network, controlLn.Addr().String())
  295. }
  296. if host == fakeProxyIP {
  297. return d.DialContext(ctx, network, proxyLn.Addr().String())
  298. }
  299. return nil, fmt.Errorf("unexpected dial to %q", addr)
  300. })
  301. opts := Options{
  302. Persist: persist.Persist{},
  303. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  304. return key.NewMachine(), nil
  305. },
  306. AuthKey: requiredAuthKey,
  307. ServerURL: "https://controlplane.tstest",
  308. Clock: tstime.StdClock{},
  309. Hostinfo: &tailcfg.Hostinfo{
  310. BackendLogID: "test-backend-log-id",
  311. },
  312. DiscoPublicKey: key.NewDisco().Public(),
  313. Logf: t.Logf,
  314. HealthTracker: health.NewTracker(bus),
  315. PopBrowserURL: func(url string) {
  316. t.Logf("PopBrowserURL: %q", url)
  317. },
  318. Dialer: dialer,
  319. Bus: bus,
  320. }
  321. d, err := NewDirect(opts)
  322. if err != nil {
  323. t.Fatalf("NewDirect: %v", err)
  324. }
  325. d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
  326. switch host {
  327. case "controlplane.tstest":
  328. return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
  329. case "proxy.tstest":
  330. if !withProxy {
  331. t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
  332. return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
  333. }
  334. return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
  335. }
  336. t.Errorf("unexpected DNS query for %q", host)
  337. return []netip.Addr{}, nil
  338. }
  339. var proxyReqs atomic.Int64
  340. if withProxy {
  341. d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
  342. t.Logf("using proxy for %q", req.URL)
  343. u := &url.URL{
  344. Scheme: "https",
  345. Host: "proxy.tstest:443",
  346. User: url.UserPassword(someUsername, somePassword),
  347. }
  348. return u, nil
  349. }
  350. connectProxy := &http.Server{
  351. Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
  352. }
  353. go connectProxy.Serve(proxyLn)
  354. }
  355. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  356. defer cancel()
  357. url, err := d.TryLogin(ctx, LoginEphemeral)
  358. if err != nil {
  359. t.Fatalf("TryLogin: %v", err)
  360. }
  361. if url != "" {
  362. t.Errorf("got URL %q, want empty", url)
  363. }
  364. if withProxy {
  365. if got, want := proxyReqs.Load(), int64(1); got != want {
  366. t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
  367. }
  368. }
  369. }
  370. func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
  371. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  372. if r.RequestURI != target {
  373. t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
  374. http.Error(w, "bad target", http.StatusBadRequest)
  375. return
  376. }
  377. r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
  378. user, pass, ok := r.BasicAuth()
  379. if !ok || user != "testuser" || pass != "testpass" {
  380. t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
  381. http.Error(w, "bad auth", http.StatusUnauthorized)
  382. return
  383. }
  384. (&connectproxy.Handler{
  385. Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
  386. var d net.Dialer
  387. c, err := d.DialContext(ctx, network, backendAddrPort)
  388. if err == nil {
  389. reqs.Add(1)
  390. }
  391. return c, err
  392. },
  393. Logf: t.Logf,
  394. }).ServeHTTP(w, r)
  395. })
  396. }