direct_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright (c) Tailscale Inc & contributors
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "encoding/json"
  6. "net/http"
  7. "net/http/httptest"
  8. "net/netip"
  9. "testing"
  10. "time"
  11. "tailscale.com/hostinfo"
  12. "tailscale.com/ipn/ipnstate"
  13. "tailscale.com/net/netmon"
  14. "tailscale.com/net/tsdial"
  15. "tailscale.com/tailcfg"
  16. "tailscale.com/types/key"
  17. "tailscale.com/util/eventbus/eventbustest"
  18. )
  19. func TestSetDiscoPublicKey(t *testing.T) {
  20. initialKey := key.NewDisco().Public()
  21. c := &Direct{
  22. discoPubKey: initialKey,
  23. }
  24. c.mu.Lock()
  25. if c.discoPubKey != initialKey {
  26. t.Fatalf("initial disco key mismatch: got %v, want %v", c.discoPubKey, initialKey)
  27. }
  28. c.mu.Unlock()
  29. newKey := key.NewDisco().Public()
  30. c.SetDiscoPublicKey(newKey)
  31. c.mu.Lock()
  32. if c.discoPubKey != newKey {
  33. t.Fatalf("disco key not updated: got %v, want %v", c.discoPubKey, newKey)
  34. }
  35. if c.discoPubKey == initialKey {
  36. t.Fatal("disco key should have changed")
  37. }
  38. c.mu.Unlock()
  39. }
  40. func TestNewDirect(t *testing.T) {
  41. hi := hostinfo.New()
  42. ni := tailcfg.NetInfo{LinkType: "wired"}
  43. hi.NetInfo = &ni
  44. bus := eventbustest.NewBus(t)
  45. k := key.NewMachine()
  46. dialer := tsdial.NewDialer(netmon.NewStatic())
  47. dialer.SetBus(bus)
  48. opts := Options{
  49. ServerURL: "https://example.com",
  50. Hostinfo: hi,
  51. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  52. return k, nil
  53. },
  54. Dialer: dialer,
  55. Bus: bus,
  56. }
  57. c, err := NewDirect(opts)
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. if c.serverURL != opts.ServerURL {
  62. t.Errorf("c.serverURL got %v want %v", c.serverURL, opts.ServerURL)
  63. }
  64. // hi is stored without its NetInfo field.
  65. hiWithoutNi := *hi
  66. hiWithoutNi.NetInfo = nil
  67. if !hiWithoutNi.Equal(c.hostinfo) {
  68. t.Errorf("c.hostinfo got %v want %v", c.hostinfo, hi)
  69. }
  70. changed := c.SetNetInfo(&ni)
  71. if changed {
  72. t.Errorf("c.SetNetInfo(ni) want false got %v", changed)
  73. }
  74. ni = tailcfg.NetInfo{LinkType: "wifi"}
  75. changed = c.SetNetInfo(&ni)
  76. if !changed {
  77. t.Errorf("c.SetNetInfo(ni) want true got %v", changed)
  78. }
  79. changed = c.SetHostinfo(hi)
  80. if changed {
  81. t.Errorf("c.SetHostinfo(hi) want false got %v", changed)
  82. }
  83. hi = hostinfo.New()
  84. hi.Hostname = "different host name"
  85. changed = c.SetHostinfo(hi)
  86. if !changed {
  87. t.Errorf("c.SetHostinfo(hi) want true got %v", changed)
  88. }
  89. endpoints := fakeEndpoints(1, 2, 3)
  90. changed = c.newEndpoints(endpoints)
  91. if !changed {
  92. t.Errorf("c.newEndpoints want true got %v", changed)
  93. }
  94. changed = c.newEndpoints(endpoints)
  95. if changed {
  96. t.Errorf("c.newEndpoints want false got %v", changed)
  97. }
  98. endpoints = fakeEndpoints(4, 5, 6)
  99. changed = c.newEndpoints(endpoints)
  100. if !changed {
  101. t.Errorf("c.newEndpoints want true got %v", changed)
  102. }
  103. }
  104. func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) {
  105. for _, port := range ports {
  106. ret = append(ret, tailcfg.Endpoint{
  107. Addr: netip.AddrPortFrom(netip.Addr{}, port),
  108. })
  109. }
  110. return
  111. }
  112. func TestTsmpPing(t *testing.T) {
  113. hi := hostinfo.New()
  114. ni := tailcfg.NetInfo{LinkType: "wired"}
  115. hi.NetInfo = &ni
  116. bus := eventbustest.NewBus(t)
  117. k := key.NewMachine()
  118. dialer := tsdial.NewDialer(netmon.NewStatic())
  119. dialer.SetBus(bus)
  120. opts := Options{
  121. ServerURL: "https://example.com",
  122. Hostinfo: hi,
  123. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  124. return k, nil
  125. },
  126. Dialer: dialer,
  127. Bus: bus,
  128. }
  129. c, err := NewDirect(opts)
  130. if err != nil {
  131. t.Fatal(err)
  132. }
  133. pingRes := &tailcfg.PingResponse{
  134. Type: "TSMP",
  135. IP: "123.456.7890",
  136. Err: "",
  137. NodeName: "testnode",
  138. }
  139. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  140. defer r.Body.Close()
  141. body := new(ipnstate.PingResult)
  142. if err := json.NewDecoder(r.Body).Decode(body); err != nil {
  143. t.Fatal(err)
  144. }
  145. if pingRes.IP != body.IP {
  146. t.Fatalf("PingResult did not have the correct IP : got %v, expected : %v", body.IP, pingRes.IP)
  147. }
  148. w.WriteHeader(200)
  149. }))
  150. defer ts.Close()
  151. now := time.Now()
  152. pr := &tailcfg.PingRequest{
  153. URL: ts.URL,
  154. }
  155. err = postPingResult(now, t.Logf, c.httpc, pr, pingRes)
  156. if err != nil {
  157. t.Fatal(err)
  158. }
  159. }