| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- // Copyright (c) Tailscale Inc & contributors
- // SPDX-License-Identifier: BSD-3-Clause
- package controlclient
- import (
- "context"
- "crypto/tls"
- "errors"
- "flag"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/netip"
- "net/url"
- "reflect"
- "sync/atomic"
- "testing"
- "time"
- "tailscale.com/control/controlknobs"
- "tailscale.com/health"
- "tailscale.com/net/bakedroots"
- "tailscale.com/net/connectproxy"
- "tailscale.com/net/netmon"
- "tailscale.com/net/tsdial"
- "tailscale.com/tailcfg"
- "tailscale.com/tstest"
- "tailscale.com/tstest/integration/testcontrol"
- "tailscale.com/tstest/tlstest"
- "tailscale.com/tstime"
- "tailscale.com/types/key"
- "tailscale.com/types/logger"
- "tailscale.com/types/netmap"
- "tailscale.com/types/persist"
- "tailscale.com/util/eventbus/eventbustest"
- )
// fieldsOf returns the names of the fields of struct type t, in declaration
// order, skipping blank ("_") fields. The result is nil when t has no named
// fields. t must be a struct type: reflect's NumField panics otherwise.
func fieldsOf(t reflect.Type) []string {
	var names []string
	for i, n := 0, t.NumField(); i < n; i++ {
		name := t.Field(i).Name
		if name == "_" {
			continue
		}
		names = append(names, name)
	}
	return names
}
- func TestStatusEqual(t *testing.T) {
- // Verify that the Equal method stays in sync with reality
- equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
- if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
- t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
- have, equalHandles)
- }
- tests := []struct {
- a, b *Status
- want bool
- }{
- {
- &Status{},
- nil,
- false,
- },
- {
- nil,
- &Status{},
- false,
- },
- {
- nil,
- nil,
- true,
- },
- {
- &Status{},
- &Status{},
- true,
- },
- {
- &Status{},
- &Status{LoggedIn: true, Persist: new(persist.Persist).View()},
- false,
- },
- }
- for i, tt := range tests {
- got := tt.a.Equal(tt.b)
- if got != tt.want {
- t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
- }
- }
- }
- // tests [canSkipStatus].
- func TestCanSkipStatus(t *testing.T) {
- st := new(Status)
- nm1 := &netmap.NetworkMap{}
- nm2 := &netmap.NetworkMap{}
- commonPersist := new(persist.Persist).View()
- tests := []struct {
- name string
- s1, s2 *Status
- want bool
- }{
- {
- name: "nil-s2",
- s1: st,
- s2: nil,
- want: false,
- },
- {
- name: "equal",
- s1: st,
- s2: st,
- want: false,
- },
- {
- name: "s1-error",
- s1: &Status{Err: io.EOF, NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-url",
- s1: &Status{URL: "foo", NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-persist-diff",
- s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-login-finished-diff",
- s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-login-finished",
- s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-synced-diff",
- s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-no-netmap1",
- s1: &Status{NetMap: nil},
- s2: &Status{NetMap: nm2},
- want: false,
- },
- {
- name: "s1-no-netmap2",
- s1: &Status{NetMap: nm1},
- s2: &Status{NetMap: nil},
- want: false,
- },
- {
- name: "skip",
- s1: &Status{NetMap: nm1, LoggedIn: true, InMapPoll: true, Persist: commonPersist},
- s2: &Status{NetMap: nm2, LoggedIn: true, InMapPoll: true, Persist: commonPersist},
- want: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := canSkipStatus(tt.s1, tt.s2); got != tt.want {
- t.Errorf("canSkipStatus = %v, want %v", got, tt.want)
- }
- })
- }
- coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
- if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) {
- t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields)
- }
- }
- func TestRetryableErrors(t *testing.T) {
- errorTests := []struct {
- err error
- want bool
- }{
- {errNoNoiseClient, true},
- {errNoNodeKey, true},
- {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true},
- {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true},
- {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true},
- {errBadHTTPResponse(429, "too may requests"), true},
- {errBadHTTPResponse(500, "internal server error"), true},
- {errBadHTTPResponse(502, "bad gateway"), true},
- {errBadHTTPResponse(503, "service unavailable"), true},
- {errBadHTTPResponse(504, "gateway timeout"), true},
- {errBadHTTPResponse(1234, "random error"), false},
- }
- for _, tt := range errorTests {
- t.Run(tt.err.Error(), func(t *testing.T) {
- if isRetryableErrorForTest(tt.err) != tt.want {
- t.Fatalf("retriable: got %v, want %v", tt.err, tt.want)
- }
- })
- }
- }
// retryableForTest is implemented by errors that can report whether the
// operation that produced them may be retried.
type retryableForTest interface {
	Retryable() bool
}

// isRetryableErrorForTest reports whether the first error in err's chain that
// implements retryableForTest (found via errors.As) says it is retryable.
// Errors with no such implementation in their chain, including nil, are
// treated as not retryable.
func isRetryableErrorForTest(err error) bool {
	var r retryableForTest
	return errors.As(err, &r) && r.Retryable()
}
- var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
- func TestDirectProxyManual(t *testing.T) {
- if !*liveNetworkTest {
- t.Skip("skipping without --live-network-test")
- }
- bus := eventbustest.NewBus(t)
- dialer := &tsdial.Dialer{}
- dialer.SetNetMon(netmon.NewStatic())
- dialer.SetBus(bus)
- opts := Options{
- Persist: persist.Persist{},
- GetMachinePrivateKey: func() (key.MachinePrivate, error) {
- return key.NewMachine(), nil
- },
- ServerURL: "https://controlplane.tailscale.com",
- Clock: tstime.StdClock{},
- Hostinfo: &tailcfg.Hostinfo{
- BackendLogID: "test-backend-log-id",
- },
- DiscoPublicKey: key.NewDisco().Public(),
- Logf: t.Logf,
- HealthTracker: health.NewTracker(bus),
- PopBrowserURL: func(url string) {
- t.Logf("PopBrowserURL: %q", url)
- },
- Dialer: dialer,
- ControlKnobs: &controlknobs.Knobs{},
- Bus: bus,
- }
- d, err := NewDirect(opts)
- if err != nil {
- t.Fatalf("NewDirect: %v", err)
- }
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- url, err := d.TryLogin(ctx, LoginEphemeral)
- if err != nil {
- t.Fatalf("TryLogin: %v", err)
- }
- t.Logf("URL: %q", url)
- }
- func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
- // TestTLSWithProxy verifies we can connect to the control plane via
- // an HTTPS proxy.
- func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
- func testHTTPS(t *testing.T, withProxy bool) {
- bakedroots.ResetForTest(t, tlstest.TestRootCA())
- bus := eventbustest.NewBus(t)
- controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig())
- if err != nil {
- t.Fatal(err)
- }
- defer controlLn.Close()
- proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig())
- if err != nil {
- t.Fatal(err)
- }
- defer proxyLn.Close()
- const requiredAuthKey = "hunter2"
- const someUsername = "testuser"
- const somePassword = "testpass"
- testControl := &testcontrol.Server{
- Logf: tstest.WhileTestRunningLogger(t),
- RequireAuthKey: requiredAuthKey,
- }
- controlSrv := &http.Server{
- Handler: testControl,
- ErrorLog: logger.StdLogger(t.Logf),
- }
- go controlSrv.Serve(controlLn)
- const fakeControlIP = "1.2.3.4"
- const fakeProxyIP = "5.6.7.8"
- dialer := &tsdial.Dialer{}
- dialer.SetNetMon(netmon.NewStatic())
- dialer.SetBus(bus)
- dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
- }
- var d net.Dialer
- if host == fakeControlIP {
- return d.DialContext(ctx, network, controlLn.Addr().String())
- }
- if host == fakeProxyIP {
- return d.DialContext(ctx, network, proxyLn.Addr().String())
- }
- return nil, fmt.Errorf("unexpected dial to %q", addr)
- })
- opts := Options{
- Persist: persist.Persist{},
- GetMachinePrivateKey: func() (key.MachinePrivate, error) {
- return key.NewMachine(), nil
- },
- AuthKey: requiredAuthKey,
- ServerURL: "https://controlplane.tstest",
- Clock: tstime.StdClock{},
- Hostinfo: &tailcfg.Hostinfo{
- BackendLogID: "test-backend-log-id",
- },
- DiscoPublicKey: key.NewDisco().Public(),
- Logf: t.Logf,
- HealthTracker: health.NewTracker(bus),
- PopBrowserURL: func(url string) {
- t.Logf("PopBrowserURL: %q", url)
- },
- Dialer: dialer,
- Bus: bus,
- }
- d, err := NewDirect(opts)
- if err != nil {
- t.Fatalf("NewDirect: %v", err)
- }
- d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
- switch host {
- case "controlplane.tstest":
- return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
- case "proxy.tstest":
- if !withProxy {
- t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
- return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
- }
- return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
- }
- t.Errorf("unexpected DNS query for %q", host)
- return []netip.Addr{}, nil
- }
- var proxyReqs atomic.Int64
- if withProxy {
- d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
- t.Logf("using proxy for %q", req.URL)
- u := &url.URL{
- Scheme: "https",
- Host: "proxy.tstest:443",
- User: url.UserPassword(someUsername, somePassword),
- }
- return u, nil
- }
- connectProxy := &http.Server{
- Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
- }
- go connectProxy.Serve(proxyLn)
- }
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- url, err := d.TryLogin(ctx, LoginEphemeral)
- if err != nil {
- t.Fatalf("TryLogin: %v", err)
- }
- if url != "" {
- t.Errorf("got URL %q, want empty", url)
- }
- if withProxy {
- if got, want := proxyReqs.Load(), int64(1); got != want {
- t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
- }
- }
- }
- func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI != target {
- t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
- http.Error(w, "bad target", http.StatusBadRequest)
- return
- }
- r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
- user, pass, ok := r.BasicAuth()
- if !ok || user != "testuser" || pass != "testpass" {
- t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
- http.Error(w, "bad auth", http.StatusUnauthorized)
- return
- }
- (&connectproxy.Handler{
- Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
- var d net.Dialer
- c, err := d.DialContext(ctx, network, backendAddrPort)
- if err == nil {
- reqs.Add(1)
- }
- return c, err
- },
- Logf: t.Logf,
- }).ServeHTTP(w, r)
- })
- }
|