connect_test.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package sessionrecording
  4. import (
  5. "bytes"
  6. "context"
  7. "crypto/rand"
  8. "crypto/sha256"
  9. "encoding/json"
  10. "fmt"
  11. "io"
  12. "net"
  13. "net/http"
  14. "net/http/httptest"
  15. "net/netip"
  16. "strings"
  17. "testing"
  18. "time"
  19. "golang.org/x/net/http2"
  20. "golang.org/x/net/http2/h2c"
  21. "tailscale.com/net/memnet"
  22. )
  23. func TestConnectToRecorder(t *testing.T) {
  24. tests := []struct {
  25. desc string
  26. http2 bool
  27. // setup returns a recorder server mux, and a channel which sends the
  28. // hash of the recording uploaded to it. The channel is expected to
  29. // fire only once.
  30. setup func(t *testing.T) (*http.ServeMux, <-chan []byte)
  31. wantErr bool
  32. }{
  33. {
  34. desc: "v1 recorder",
  35. setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
  36. uploadHash := make(chan []byte, 1)
  37. mux := http.NewServeMux()
  38. mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
  39. hash := sha256.New()
  40. if _, err := io.Copy(hash, r.Body); err != nil {
  41. t.Error(err)
  42. }
  43. uploadHash <- hash.Sum(nil)
  44. })
  45. return mux, uploadHash
  46. },
  47. },
  48. {
  49. desc: "v2 recorder",
  50. http2: true,
  51. setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
  52. uploadHash := make(chan []byte, 1)
  53. mux := http.NewServeMux()
  54. mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
  55. t.Error("received request to v1 endpoint")
  56. http.Error(w, "not found", http.StatusNotFound)
  57. })
  58. mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
  59. // Force the status to send to unblock the client waiting
  60. // for it.
  61. w.WriteHeader(http.StatusOK)
  62. w.(http.Flusher).Flush()
  63. body := &readCounter{r: r.Body}
  64. hash := sha256.New()
  65. ctx, cancel := context.WithCancel(r.Context())
  66. go func() {
  67. defer cancel()
  68. if _, err := io.Copy(hash, body); err != nil {
  69. t.Error(err)
  70. }
  71. }()
  72. // Send acks for received bytes.
  73. tick := time.NewTicker(time.Millisecond)
  74. defer tick.Stop()
  75. enc := json.NewEncoder(w)
  76. outer:
  77. for {
  78. select {
  79. case <-ctx.Done():
  80. break outer
  81. case <-tick.C:
  82. if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil {
  83. t.Errorf("writing ack frame: %v", err)
  84. break outer
  85. }
  86. }
  87. }
  88. uploadHash <- hash.Sum(nil)
  89. })
  90. // Probing HEAD endpoint which always returns 200 OK.
  91. mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
  92. return mux, uploadHash
  93. },
  94. },
  95. {
  96. desc: "v2 recorder no acks",
  97. http2: true,
  98. wantErr: true,
  99. setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
  100. // Make the client no-ack timeout quick for the test.
  101. oldAckWindow := uploadAckWindow
  102. uploadAckWindow = 100 * time.Millisecond
  103. t.Cleanup(func() { uploadAckWindow = oldAckWindow })
  104. uploadHash := make(chan []byte, 1)
  105. mux := http.NewServeMux()
  106. mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
  107. t.Error("received request to v1 endpoint")
  108. http.Error(w, "not found", http.StatusNotFound)
  109. })
  110. mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
  111. // Force the status to send to unblock the client waiting
  112. // for it.
  113. w.WriteHeader(http.StatusOK)
  114. w.(http.Flusher).Flush()
  115. // Consume the whole request body but don't send any acks
  116. // back.
  117. hash := sha256.New()
  118. if _, err := io.Copy(hash, r.Body); err != nil {
  119. t.Error(err)
  120. }
  121. // Goes in the channel buffer, non-blocking.
  122. uploadHash <- hash.Sum(nil)
  123. // Block until the parent test case ends to prevent the
  124. // request termination. We want to exercise the ack
  125. // tracking logic specifically.
  126. ctx, cancel := context.WithCancel(r.Context())
  127. t.Cleanup(cancel)
  128. <-ctx.Done()
  129. })
  130. mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
  131. return mux, uploadHash
  132. },
  133. },
  134. }
  135. for _, tt := range tests {
  136. t.Run(tt.desc, func(t *testing.T) {
  137. mux, uploadHash := tt.setup(t)
  138. memNet := &memnet.Network{}
  139. ln := memNet.NewLocalTCPListener()
  140. srv := &httptest.Server{
  141. Config: &http.Server{Handler: mux},
  142. Listener: ln,
  143. }
  144. if tt.http2 {
  145. // Wire up h2c-compatible HTTP/2 server. This is optional
  146. // because the v1 recorder didn't support HTTP/2 and we try to
  147. // mimic that.
  148. s := &http2.Server{}
  149. srv.Config.Handler = h2c.NewHandler(mux, s)
  150. if err := http2.ConfigureServer(srv.Config, s); err != nil {
  151. t.Errorf("configuring HTTP/2 support in server: %v", err)
  152. }
  153. }
  154. srv.Start()
  155. t.Cleanup(srv.Close)
  156. ctx := context.Background()
  157. w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(ln.Addr().String())}, memNet.Dial)
  158. if err != nil {
  159. t.Fatalf("ConnectToRecorder: %v", err)
  160. }
  161. // Send some random data and hash it to compare with the recorded
  162. // data hash.
  163. hash := sha256.New()
  164. const numBytes = 1 << 20 // 1MB
  165. if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil {
  166. t.Fatalf("writing recording data: %v", err)
  167. }
  168. if err := w.Close(); err != nil {
  169. t.Fatalf("closing recording stream: %v", err)
  170. }
  171. if err := <-errc; err != nil && !tt.wantErr {
  172. t.Fatalf("error from the channel: %v", err)
  173. } else if err == nil && tt.wantErr {
  174. t.Fatalf("did not receive expected error from the channel")
  175. }
  176. if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) {
  177. t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv)
  178. }
  179. })
  180. }
  181. }
  182. func TestSendEvent(t *testing.T) {
  183. t.Run("supported", func(t *testing.T) {
  184. eventBody := `{"foo":"bar"}`
  185. eventRecieved := make(chan []byte, 1)
  186. mux := http.NewServeMux()
  187. mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
  188. w.WriteHeader(http.StatusOK)
  189. })
  190. mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) {
  191. body, err := io.ReadAll(r.Body)
  192. if err != nil {
  193. t.Error(err)
  194. }
  195. eventRecieved <- body
  196. w.WriteHeader(http.StatusOK)
  197. })
  198. srv := httptest.NewUnstartedServer(mux)
  199. s := &http2.Server{}
  200. srv.Config.Handler = h2c.NewHandler(mux, s)
  201. if err := http2.ConfigureServer(srv.Config, s); err != nil {
  202. t.Fatalf("configuring HTTP/2 support in server: %v", err)
  203. }
  204. srv.Start()
  205. t.Cleanup(srv.Close)
  206. d := new(net.Dialer)
  207. addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
  208. err := SendEvent(addr, bytes.NewBufferString(eventBody), d.DialContext)
  209. if err != nil {
  210. t.Fatalf("SendEvent: %v", err)
  211. }
  212. if recv := string(<-eventRecieved); recv != eventBody {
  213. t.Errorf("mismatch in event body, sent %q, received %q", eventBody, recv)
  214. }
  215. })
  216. t.Run("not_supported", func(t *testing.T) {
  217. mux := http.NewServeMux()
  218. mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
  219. w.WriteHeader(http.StatusNotFound)
  220. })
  221. srv := httptest.NewUnstartedServer(mux)
  222. s := &http2.Server{}
  223. srv.Config.Handler = h2c.NewHandler(mux, s)
  224. if err := http2.ConfigureServer(srv.Config, s); err != nil {
  225. t.Fatalf("configuring HTTP/2 support in server: %v", err)
  226. }
  227. srv.Start()
  228. t.Cleanup(srv.Close)
  229. d := new(net.Dialer)
  230. addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
  231. err := SendEvent(addr, nil, d.DialContext)
  232. if err == nil {
  233. t.Fatal("expected an error, got nil")
  234. }
  235. if !strings.Contains(err.Error(), fmt.Sprintf(addressNotSupportEventv2, srv.Listener.Addr().String())) {
  236. t.Fatalf("unexpected error: %v", err)
  237. }
  238. })
  239. t.Run("server_error", func(t *testing.T) {
  240. mux := http.NewServeMux()
  241. mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
  242. w.WriteHeader(http.StatusOK)
  243. })
  244. mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) {
  245. w.WriteHeader(http.StatusInternalServerError)
  246. })
  247. srv := httptest.NewUnstartedServer(mux)
  248. s := &http2.Server{}
  249. srv.Config.Handler = h2c.NewHandler(mux, s)
  250. if err := http2.ConfigureServer(srv.Config, s); err != nil {
  251. t.Fatalf("configuring HTTP/2 support in server: %v", err)
  252. }
  253. srv.Start()
  254. t.Cleanup(srv.Close)
  255. d := new(net.Dialer)
  256. addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
  257. err := SendEvent(addr, nil, d.DialContext)
  258. if err == nil {
  259. t.Fatal("expected an error, got nil")
  260. }
  261. if !strings.Contains(err.Error(), "server returned non-OK status") {
  262. t.Fatalf("unexpected error: %v", err)
  263. }
  264. })
  265. }