controlclient_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "context"
  6. "crypto/tls"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "net/url"
  15. "reflect"
  16. "slices"
  17. "sync/atomic"
  18. "testing"
  19. "time"
  20. "tailscale.com/control/controlknobs"
  21. "tailscale.com/health"
  22. "tailscale.com/net/bakedroots"
  23. "tailscale.com/net/connectproxy"
  24. "tailscale.com/net/netmon"
  25. "tailscale.com/net/tsdial"
  26. "tailscale.com/tailcfg"
  27. "tailscale.com/tstest"
  28. "tailscale.com/tstest/integration/testcontrol"
  29. "tailscale.com/tstest/tlstest"
  30. "tailscale.com/tstime"
  31. "tailscale.com/types/key"
  32. "tailscale.com/types/logger"
  33. "tailscale.com/types/netmap"
  34. "tailscale.com/types/persist"
  35. "tailscale.com/util/eventbus/eventbustest"
  36. )
  37. func fieldsOf(t reflect.Type) (fields []string) {
  38. for i := range t.NumField() {
  39. if name := t.Field(i).Name; name != "_" {
  40. fields = append(fields, name)
  41. }
  42. }
  43. return
  44. }
  45. func TestStatusEqual(t *testing.T) {
  46. // Verify that the Equal method stays in sync with reality
  47. equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"}
  48. if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
  49. t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
  50. have, equalHandles)
  51. }
  52. tests := []struct {
  53. a, b *Status
  54. want bool
  55. }{
  56. {
  57. &Status{},
  58. nil,
  59. false,
  60. },
  61. {
  62. nil,
  63. &Status{},
  64. false,
  65. },
  66. {
  67. nil,
  68. nil,
  69. true,
  70. },
  71. {
  72. &Status{},
  73. &Status{},
  74. true,
  75. },
  76. {
  77. &Status{},
  78. &Status{state: StateAuthenticated},
  79. false,
  80. },
  81. }
  82. for i, tt := range tests {
  83. got := tt.a.Equal(tt.b)
  84. if got != tt.want {
  85. t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
  86. }
  87. }
  88. }
  89. // tests [canSkipStatus].
  90. func TestCanSkipStatus(t *testing.T) {
  91. st := new(Status)
  92. nm1 := &netmap.NetworkMap{}
  93. nm2 := &netmap.NetworkMap{}
  94. tests := []struct {
  95. name string
  96. s1, s2 *Status
  97. want bool
  98. }{
  99. {
  100. name: "nil-s2",
  101. s1: st,
  102. s2: nil,
  103. want: false,
  104. },
  105. {
  106. name: "equal",
  107. s1: st,
  108. s2: st,
  109. want: false,
  110. },
  111. {
  112. name: "s1-error",
  113. s1: &Status{Err: io.EOF, NetMap: nm1},
  114. s2: &Status{NetMap: nm2},
  115. want: false,
  116. },
  117. {
  118. name: "s1-url",
  119. s1: &Status{URL: "foo", NetMap: nm1},
  120. s2: &Status{NetMap: nm2},
  121. want: false,
  122. },
  123. {
  124. name: "s1-persist-diff",
  125. s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1},
  126. s2: &Status{NetMap: nm2},
  127. want: false,
  128. },
  129. {
  130. name: "s1-state-diff",
  131. s1: &Status{state: 123, NetMap: nm1},
  132. s2: &Status{NetMap: nm2},
  133. want: false,
  134. },
  135. {
  136. name: "s1-no-netmap1",
  137. s1: &Status{NetMap: nil},
  138. s2: &Status{NetMap: nm2},
  139. want: false,
  140. },
  141. {
  142. name: "s1-no-netmap2",
  143. s1: &Status{NetMap: nm1},
  144. s2: &Status{NetMap: nil},
  145. want: false,
  146. },
  147. {
  148. name: "skip",
  149. s1: &Status{NetMap: nm1},
  150. s2: &Status{NetMap: nm2},
  151. want: true,
  152. },
  153. }
  154. for _, tt := range tests {
  155. t.Run(tt.name, func(t *testing.T) {
  156. if got := canSkipStatus(tt.s1, tt.s2); got != tt.want {
  157. t.Errorf("canSkipStatus = %v, want %v", got, tt.want)
  158. }
  159. })
  160. }
  161. want := []string{"Err", "URL", "NetMap", "Persist", "state"}
  162. if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) {
  163. t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want)
  164. }
  165. }
  166. func TestRetryableErrors(t *testing.T) {
  167. errorTests := []struct {
  168. err error
  169. want bool
  170. }{
  171. {errNoNoiseClient, true},
  172. {errNoNodeKey, true},
  173. {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
  174. {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
  175. {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
  176. {errBadHTTPResponse(429, "too may requests"), true},
  177. {errBadHTTPResponse(500, "internal server eror"), true},
  178. {errBadHTTPResponse(502, "bad gateway"), true},
  179. {errBadHTTPResponse(503, "service unavailable"), true},
  180. {errBadHTTPResponse(504, "gateway timeout"), true},
  181. {errBadHTTPResponse(1234, "random error"), false},
  182. }
  183. for _, tt := range errorTests {
  184. t.Run(tt.err.Error(), func(t *testing.T) {
  185. if isRetryableErrorForTest(tt.err) != tt.want {
  186. t.Fatalf("retriable: got %v, want %v", tt.err, tt.want)
  187. }
  188. })
  189. }
  190. }
  191. type retryableForTest interface {
  192. Retryable() bool
  193. }
  194. func isRetryableErrorForTest(err error) bool {
  195. var ae retryableForTest
  196. if errors.As(err, &ae) {
  197. return ae.Retryable()
  198. }
  199. return false
  200. }
  201. var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
  202. func TestDirectProxyManual(t *testing.T) {
  203. if !*liveNetworkTest {
  204. t.Skip("skipping without --live-network-test")
  205. }
  206. bus := eventbustest.NewBus(t)
  207. dialer := &tsdial.Dialer{}
  208. dialer.SetNetMon(netmon.NewStatic())
  209. opts := Options{
  210. Persist: persist.Persist{},
  211. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  212. return key.NewMachine(), nil
  213. },
  214. ServerURL: "https://controlplane.tailscale.com",
  215. Clock: tstime.StdClock{},
  216. Hostinfo: &tailcfg.Hostinfo{
  217. BackendLogID: "test-backend-log-id",
  218. },
  219. DiscoPublicKey: key.NewDisco().Public(),
  220. Logf: t.Logf,
  221. HealthTracker: health.NewTracker(bus),
  222. PopBrowserURL: func(url string) {
  223. t.Logf("PopBrowserURL: %q", url)
  224. },
  225. Dialer: dialer,
  226. ControlKnobs: &controlknobs.Knobs{},
  227. Bus: bus,
  228. }
  229. d, err := NewDirect(opts)
  230. if err != nil {
  231. t.Fatalf("NewDirect: %v", err)
  232. }
  233. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  234. defer cancel()
  235. url, err := d.TryLogin(ctx, LoginEphemeral)
  236. if err != nil {
  237. t.Fatalf("TryLogin: %v", err)
  238. }
  239. t.Logf("URL: %q", url)
  240. }
  241. func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
  242. // TestTLSWithProxy verifies we can connect to the control plane via
  243. // an HTTPS proxy.
  244. func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
  245. func testHTTPS(t *testing.T, withProxy bool) {
  246. bakedroots.ResetForTest(t, tlstest.TestRootCA())
  247. bus := eventbustest.NewBus(t)
  248. controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
  249. if err != nil {
  250. t.Fatal(err)
  251. }
  252. defer controlLn.Close()
  253. proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
  254. if err != nil {
  255. t.Fatal(err)
  256. }
  257. defer proxyLn.Close()
  258. const requiredAuthKey = "hunter2"
  259. const someUsername = "testuser"
  260. const somePassword = "testpass"
  261. testControl := &testcontrol.Server{
  262. Logf: tstest.WhileTestRunningLogger(t),
  263. RequireAuthKey: requiredAuthKey,
  264. }
  265. controlSrv := &http.Server{
  266. Handler: testControl,
  267. ErrorLog: logger.StdLogger(t.Logf),
  268. }
  269. go controlSrv.Serve(controlLn)
  270. const fakeControlIP = "1.2.3.4"
  271. const fakeProxyIP = "5.6.7.8"
  272. dialer := &tsdial.Dialer{}
  273. dialer.SetNetMon(netmon.NewStatic())
  274. dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
  275. host, _, err := net.SplitHostPort(addr)
  276. if err != nil {
  277. return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
  278. }
  279. var d net.Dialer
  280. if host == fakeControlIP {
  281. return d.DialContext(ctx, network, controlLn.Addr().String())
  282. }
  283. if host == fakeProxyIP {
  284. return d.DialContext(ctx, network, proxyLn.Addr().String())
  285. }
  286. return nil, fmt.Errorf("unexpected dial to %q", addr)
  287. })
  288. opts := Options{
  289. Persist: persist.Persist{},
  290. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  291. return key.NewMachine(), nil
  292. },
  293. AuthKey: requiredAuthKey,
  294. ServerURL: "https://controlplane.tstest",
  295. Clock: tstime.StdClock{},
  296. Hostinfo: &tailcfg.Hostinfo{
  297. BackendLogID: "test-backend-log-id",
  298. },
  299. DiscoPublicKey: key.NewDisco().Public(),
  300. Logf: t.Logf,
  301. HealthTracker: health.NewTracker(bus),
  302. PopBrowserURL: func(url string) {
  303. t.Logf("PopBrowserURL: %q", url)
  304. },
  305. Dialer: dialer,
  306. Bus: bus,
  307. }
  308. d, err := NewDirect(opts)
  309. if err != nil {
  310. t.Fatalf("NewDirect: %v", err)
  311. }
  312. d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
  313. switch host {
  314. case "controlplane.tstest":
  315. return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
  316. case "proxy.tstest":
  317. if !withProxy {
  318. t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
  319. return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
  320. }
  321. return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
  322. }
  323. t.Errorf("unexpected DNS query for %q", host)
  324. return []netip.Addr{}, nil
  325. }
  326. var proxyReqs atomic.Int64
  327. if withProxy {
  328. d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
  329. t.Logf("using proxy for %q", req.URL)
  330. u := &url.URL{
  331. Scheme: "https",
  332. Host: "proxy.tstest:443",
  333. User: url.UserPassword(someUsername, somePassword),
  334. }
  335. return u, nil
  336. }
  337. connectProxy := &http.Server{
  338. Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
  339. }
  340. go connectProxy.Serve(proxyLn)
  341. }
  342. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  343. defer cancel()
  344. url, err := d.TryLogin(ctx, LoginEphemeral)
  345. if err != nil {
  346. t.Fatalf("TryLogin: %v", err)
  347. }
  348. if url != "" {
  349. t.Errorf("got URL %q, want empty", url)
  350. }
  351. if withProxy {
  352. if got, want := proxyReqs.Load(), int64(1); got != want {
  353. t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
  354. }
  355. }
  356. }
  357. func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
  358. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  359. if r.RequestURI != target {
  360. t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
  361. http.Error(w, "bad target", http.StatusBadRequest)
  362. return
  363. }
  364. r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
  365. user, pass, ok := r.BasicAuth()
  366. if !ok || user != "testuser" || pass != "testpass" {
  367. t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
  368. http.Error(w, "bad auth", http.StatusUnauthorized)
  369. return
  370. }
  371. (&connectproxy.Handler{
  372. Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
  373. var d net.Dialer
  374. c, err := d.DialContext(ctx, network, backendAddrPort)
  375. if err == nil {
  376. reqs.Add(1)
  377. }
  378. return c, err
  379. },
  380. Logf: t.Logf,
  381. }).ServeHTTP(w, r)
  382. })
  383. }