controlclient_test.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "net/url"
  15. "reflect"
  16. "slices"
  17. "sync/atomic"
  18. "testing"
  19. "time"
  20. "tailscale.com/control/controlknobs"
  21. "tailscale.com/health"
  22. "tailscale.com/net/bakedroots"
  23. "tailscale.com/net/connectproxy"
  24. "tailscale.com/net/netmon"
  25. "tailscale.com/net/tsdial"
  26. "tailscale.com/tailcfg"
  27. "tailscale.com/tstest"
  28. "tailscale.com/tstest/integration/testcontrol"
  29. "tailscale.com/tstest/tlstest"
  30. "tailscale.com/tstime"
  31. "tailscale.com/types/key"
  32. "tailscale.com/types/logger"
  33. "tailscale.com/types/netmap"
  34. "tailscale.com/types/persist"
  35. )
  36. func fieldsOf(t reflect.Type) (fields []string) {
  37. for i := range t.NumField() {
  38. if name := t.Field(i).Name; name != "_" {
  39. fields = append(fields, name)
  40. }
  41. }
  42. return
  43. }
  44. func TestStatusEqual(t *testing.T) {
  45. // Verify that the Equal method stays in sync with reality
  46. equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"}
  47. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
  48. t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
  49. have, equalHandles)
  50. }
  51. tests := []struct {
  52. a, b *Status
  53. want bool
  54. }{
  55. {
  56. &Status{},
  57. nil,
  58. false,
  59. },
  60. {
  61. nil,
  62. &Status{},
  63. false,
  64. },
  65. {
  66. nil,
  67. nil,
  68. true,
  69. },
  70. {
  71. &Status{},
  72. &Status{},
  73. true,
  74. },
  75. {
  76. &Status{},
  77. &Status{state: StateAuthenticated},
  78. false,
  79. },
  80. }
  81. for i, tt := range tests {
  82. got := tt.a.Equal(tt.b)
  83. if got != tt.want {
  84. t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
  85. }
  86. }
  87. }
  88. // tests [canSkipStatus].
  89. func TestCanSkipStatus(t *testing.T) {
  90. st := new(Status)
  91. nm1 := &netmap.NetworkMap{}
  92. nm2 := &netmap.NetworkMap{}
  93. tests := []struct {
  94. name string
  95. s1, s2 *Status
  96. want bool
  97. }{
  98. {
  99. name: "nil-s2",
  100. s1: st,
  101. s2: nil,
  102. want: false,
  103. },
  104. {
  105. name: "equal",
  106. s1: st,
  107. s2: st,
  108. want: false,
  109. },
  110. {
  111. name: "s1-error",
  112. s1: &Status{Err: io.EOF, NetMap: nm1},
  113. s2: &Status{NetMap: nm2},
  114. want: false,
  115. },
  116. {
  117. name: "s1-url",
  118. s1: &Status{URL: "foo", NetMap: nm1},
  119. s2: &Status{NetMap: nm2},
  120. want: false,
  121. },
  122. {
  123. name: "s1-persist-diff",
  124. s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1},
  125. s2: &Status{NetMap: nm2},
  126. want: false,
  127. },
  128. {
  129. name: "s1-state-diff",
  130. s1: &Status{state: 123, NetMap: nm1},
  131. s2: &Status{NetMap: nm2},
  132. want: false,
  133. },
  134. {
  135. name: "s1-no-netmap1",
  136. s1: &Status{NetMap: nil},
  137. s2: &Status{NetMap: nm2},
  138. want: false,
  139. },
  140. {
  141. name: "s1-no-netmap2",
  142. s1: &Status{NetMap: nm1},
  143. s2: &Status{NetMap: nil},
  144. want: false,
  145. },
  146. {
  147. name: "skip",
  148. s1: &Status{NetMap: nm1},
  149. s2: &Status{NetMap: nm2},
  150. want: true,
  151. },
  152. }
  153. for _, tt := range tests {
  154. t.Run(tt.name, func(t *testing.T) {
  155. if got := canSkipStatus(tt.s1, tt.s2); got != tt.want {
  156. t.Errorf("canSkipStatus = %v, want %v", got, tt.want)
  157. }
  158. })
  159. }
  160. want := []string{"Err", "URL", "NetMap", "Persist", "state"}
  161. if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) {
  162. t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want)
  163. }
  164. }
  165. func TestRetryableErrors(t *testing.T) {
  166. errorTests := []struct {
  167. err error
  168. want bool
  169. }{
  170. {errNoNoiseClient, true},
  171. {errNoNodeKey, true},
  172. {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
  173. {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
  174. {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
  175. {errBadHTTPResponse(429, "too may requests"), true},
  176. {errBadHTTPResponse(500, "internal server eror"), true},
  177. {errBadHTTPResponse(502, "bad gateway"), true},
  178. {errBadHTTPResponse(503, "service unavailable"), true},
  179. {errBadHTTPResponse(504, "gateway timeout"), true},
  180. {errBadHTTPResponse(1234, "random error"), false},
  181. }
  182. for _, tt := range errorTests {
  183. t.Run(tt.err.Error(), func(t *testing.T) {
  184. if isRetryableErrorForTest(tt.err) != tt.want {
  185. t.Fatalf("retriable: got %v, want %v", tt.err, tt.want)
  186. }
  187. })
  188. }
  189. }
  190. type retryableForTest interface {
  191. Retryable() bool
  192. }
  193. func isRetryableErrorForTest(err error) bool {
  194. var ae retryableForTest
  195. if errors.As(err, &ae) {
  196. return ae.Retryable()
  197. }
  198. return false
  199. }
  200. var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
  201. func TestDirectProxyManual(t *testing.T) {
  202. if !*liveNetworkTest {
  203. t.Skip("skipping without --live-network-test")
  204. }
  205. dialer := &tsdial.Dialer{}
  206. dialer.SetNetMon(netmon.NewStatic())
  207. opts := Options{
  208. Persist: persist.Persist{},
  209. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  210. return key.NewMachine(), nil
  211. },
  212. ServerURL: "https://controlplane.tailscale.com",
  213. Clock: tstime.StdClock{},
  214. Hostinfo: &tailcfg.Hostinfo{
  215. BackendLogID: "test-backend-log-id",
  216. },
  217. DiscoPublicKey: key.NewDisco().Public(),
  218. Logf: t.Logf,
  219. HealthTracker: &health.Tracker{},
  220. PopBrowserURL: func(url string) {
  221. t.Logf("PopBrowserURL: %q", url)
  222. },
  223. Dialer: dialer,
  224. ControlKnobs: &controlknobs.Knobs{},
  225. }
  226. d, err := NewDirect(opts)
  227. if err != nil {
  228. t.Fatalf("NewDirect: %v", err)
  229. }
  230. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  231. defer cancel()
  232. url, err := d.TryLogin(ctx, LoginEphemeral)
  233. if err != nil {
  234. t.Fatalf("TryLogin: %v", err)
  235. }
  236. t.Logf("URL: %q", url)
  237. }
  238. func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
  239. // TestTLSWithProxy verifies we can connect to the control plane via
  240. // an HTTPS proxy.
  241. func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
  242. func testHTTPS(t *testing.T, withProxy bool) {
  243. bakedroots.ResetForTest(t, tlstest.TestRootCA())
  244. controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
  245. if err != nil {
  246. t.Fatal(err)
  247. }
  248. defer controlLn.Close()
  249. proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
  250. if err != nil {
  251. t.Fatal(err)
  252. }
  253. defer proxyLn.Close()
  254. const requiredAuthKey = "hunter2"
  255. const someUsername = "testuser"
  256. const somePassword = "testpass"
  257. testControl := &testcontrol.Server{
  258. Logf: tstest.WhileTestRunningLogger(t),
  259. RequireAuthKey: requiredAuthKey,
  260. }
  261. controlSrv := &http.Server{
  262. Handler: testControl,
  263. ErrorLog: logger.StdLogger(t.Logf),
  264. }
  265. go controlSrv.Serve(controlLn)
  266. const fakeControlIP = "1.2.3.4"
  267. const fakeProxyIP = "5.6.7.8"
  268. dialer := &tsdial.Dialer{}
  269. dialer.SetNetMon(netmon.NewStatic())
  270. dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
  271. host, _, err := net.SplitHostPort(addr)
  272. if err != nil {
  273. return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
  274. }
  275. var d net.Dialer
  276. if host == fakeControlIP {
  277. return d.DialContext(ctx, network, controlLn.Addr().String())
  278. }
  279. if host == fakeProxyIP {
  280. return d.DialContext(ctx, network, proxyLn.Addr().String())
  281. }
  282. return nil, fmt.Errorf("unexpected dial to %q", addr)
  283. })
  284. opts := Options{
  285. Persist: persist.Persist{},
  286. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  287. return key.NewMachine(), nil
  288. },
  289. AuthKey: requiredAuthKey,
  290. ServerURL: "https://controlplane.tstest",
  291. Clock: tstime.StdClock{},
  292. Hostinfo: &tailcfg.Hostinfo{
  293. BackendLogID: "test-backend-log-id",
  294. },
  295. DiscoPublicKey: key.NewDisco().Public(),
  296. Logf: t.Logf,
  297. HealthTracker: &health.Tracker{},
  298. PopBrowserURL: func(url string) {
  299. t.Logf("PopBrowserURL: %q", url)
  300. },
  301. Dialer: dialer,
  302. }
  303. d, err := NewDirect(opts)
  304. if err != nil {
  305. t.Fatalf("NewDirect: %v", err)
  306. }
  307. d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
  308. switch host {
  309. case "controlplane.tstest":
  310. return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
  311. case "proxy.tstest":
  312. if !withProxy {
  313. t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
  314. return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
  315. }
  316. return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
  317. }
  318. t.Errorf("unexpected DNS query for %q", host)
  319. return []netip.Addr{}, nil
  320. }
  321. var proxyReqs atomic.Int64
  322. if withProxy {
  323. d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
  324. t.Logf("using proxy for %q", req.URL)
  325. u := &url.URL{
  326. Scheme: "https",
  327. Host: "proxy.tstest:443",
  328. User: url.UserPassword(someUsername, somePassword),
  329. }
  330. return u, nil
  331. }
  332. connectProxy := &http.Server{
  333. Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
  334. }
  335. go connectProxy.Serve(proxyLn)
  336. }
  337. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  338. defer cancel()
  339. url, err := d.TryLogin(ctx, LoginEphemeral)
  340. if err != nil {
  341. t.Fatalf("TryLogin: %v", err)
  342. }
  343. if url != "" {
  344. t.Errorf("got URL %q, want empty", url)
  345. }
  346. if withProxy {
  347. if got, want := proxyReqs.Load(), int64(1); got != want {
  348. t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
  349. }
  350. }
  351. }
  352. func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
  353. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  354. if r.RequestURI != target {
  355. t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
  356. http.Error(w, "bad target", http.StatusBadRequest)
  357. return
  358. }
  359. r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
  360. user, pass, ok := r.BasicAuth()
  361. if !ok || user != "testuser" || pass != "testpass" {
  362. t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
  363. http.Error(w, "bad auth", http.StatusUnauthorized)
  364. return
  365. }
  366. (&connectproxy.Handler{
  367. Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
  368. var d net.Dialer
  369. c, err := d.DialContext(ctx, network, backendAddrPort)
  370. if err == nil {
  371. reqs.Add(1)
  372. }
  373. return c, err
  374. },
  375. Logf: t.Logf,
  376. }).ServeHTTP(w, r)
  377. })
  378. }