direct_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package controlclient
  4. import (
  5. "encoding/json"
  6. "net/http"
  7. "net/http/httptest"
  8. "net/netip"
  9. "testing"
  10. "time"
  11. "tailscale.com/hostinfo"
  12. "tailscale.com/ipn/ipnstate"
  13. "tailscale.com/net/netmon"
  14. "tailscale.com/net/tsdial"
  15. "tailscale.com/tailcfg"
  16. "tailscale.com/types/key"
  17. )
  18. func TestNewDirect(t *testing.T) {
  19. hi := hostinfo.New()
  20. ni := tailcfg.NetInfo{LinkType: "wired"}
  21. hi.NetInfo = &ni
  22. k := key.NewMachine()
  23. opts := Options{
  24. ServerURL: "https://example.com",
  25. Hostinfo: hi,
  26. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  27. return k, nil
  28. },
  29. Dialer: tsdial.NewDialer(netmon.NewStatic()),
  30. }
  31. c, err := NewDirect(opts)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. if c.serverURL != opts.ServerURL {
  36. t.Errorf("c.serverURL got %v want %v", c.serverURL, opts.ServerURL)
  37. }
  38. // hi is stored without its NetInfo field.
  39. hiWithoutNi := *hi
  40. hiWithoutNi.NetInfo = nil
  41. if !hiWithoutNi.Equal(c.hostinfo) {
  42. t.Errorf("c.hostinfo got %v want %v", c.hostinfo, hi)
  43. }
  44. changed := c.SetNetInfo(&ni)
  45. if changed {
  46. t.Errorf("c.SetNetInfo(ni) want false got %v", changed)
  47. }
  48. ni = tailcfg.NetInfo{LinkType: "wifi"}
  49. changed = c.SetNetInfo(&ni)
  50. if !changed {
  51. t.Errorf("c.SetNetInfo(ni) want true got %v", changed)
  52. }
  53. changed = c.SetHostinfo(hi)
  54. if changed {
  55. t.Errorf("c.SetHostinfo(hi) want false got %v", changed)
  56. }
  57. hi = hostinfo.New()
  58. hi.Hostname = "different host name"
  59. changed = c.SetHostinfo(hi)
  60. if !changed {
  61. t.Errorf("c.SetHostinfo(hi) want true got %v", changed)
  62. }
  63. endpoints := fakeEndpoints(1, 2, 3)
  64. changed = c.newEndpoints(endpoints)
  65. if !changed {
  66. t.Errorf("c.newEndpoints want true got %v", changed)
  67. }
  68. changed = c.newEndpoints(endpoints)
  69. if changed {
  70. t.Errorf("c.newEndpoints want false got %v", changed)
  71. }
  72. endpoints = fakeEndpoints(4, 5, 6)
  73. changed = c.newEndpoints(endpoints)
  74. if !changed {
  75. t.Errorf("c.newEndpoints want true got %v", changed)
  76. }
  77. }
  78. func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) {
  79. for _, port := range ports {
  80. ret = append(ret, tailcfg.Endpoint{
  81. Addr: netip.AddrPortFrom(netip.Addr{}, port),
  82. })
  83. }
  84. return
  85. }
  86. func TestTsmpPing(t *testing.T) {
  87. hi := hostinfo.New()
  88. ni := tailcfg.NetInfo{LinkType: "wired"}
  89. hi.NetInfo = &ni
  90. k := key.NewMachine()
  91. opts := Options{
  92. ServerURL: "https://example.com",
  93. Hostinfo: hi,
  94. GetMachinePrivateKey: func() (key.MachinePrivate, error) {
  95. return k, nil
  96. },
  97. Dialer: tsdial.NewDialer(netmon.NewStatic()),
  98. }
  99. c, err := NewDirect(opts)
  100. if err != nil {
  101. t.Fatal(err)
  102. }
  103. pingRes := &tailcfg.PingResponse{
  104. Type: "TSMP",
  105. IP: "123.456.7890",
  106. Err: "",
  107. NodeName: "testnode",
  108. }
  109. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  110. defer r.Body.Close()
  111. body := new(ipnstate.PingResult)
  112. if err := json.NewDecoder(r.Body).Decode(body); err != nil {
  113. t.Fatal(err)
  114. }
  115. if pingRes.IP != body.IP {
  116. t.Fatalf("PingResult did not have the correct IP : got %v, expected : %v", body.IP, pingRes.IP)
  117. }
  118. w.WriteHeader(200)
  119. }))
  120. defer ts.Close()
  121. now := time.Now()
  122. pr := &tailcfg.PingRequest{
  123. URL: ts.URL,
  124. }
  125. err = postPingResult(now, t.Logf, c.httpc, pr, pingRes)
  126. if err != nil {
  127. t.Fatal(err)
  128. }
  129. }