| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- package sessionrecording
- import (
- "bytes"
- "context"
- "crypto/rand"
- "crypto/sha256"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/http/httptest"
- "net/netip"
- "strings"
- "testing"
- "time"
- "golang.org/x/net/http2"
- "golang.org/x/net/http2/h2c"
- "tailscale.com/net/memnet"
- )
- func TestConnectToRecorder(t *testing.T) {
- tests := []struct {
- desc string
- http2 bool
- // setup returns a recorder server mux, and a channel which sends the
- // hash of the recording uploaded to it. The channel is expected to
- // fire only once.
- setup func(t *testing.T) (*http.ServeMux, <-chan []byte)
- wantErr bool
- }{
- {
- desc: "v1 recorder",
- setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
- uploadHash := make(chan []byte, 1)
- mux := http.NewServeMux()
- mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
- hash := sha256.New()
- if _, err := io.Copy(hash, r.Body); err != nil {
- t.Error(err)
- }
- uploadHash <- hash.Sum(nil)
- })
- return mux, uploadHash
- },
- },
- {
- desc: "v2 recorder",
- http2: true,
- setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
- uploadHash := make(chan []byte, 1)
- mux := http.NewServeMux()
- mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
- t.Error("received request to v1 endpoint")
- http.Error(w, "not found", http.StatusNotFound)
- })
- mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
- // Force the status to send to unblock the client waiting
- // for it.
- w.WriteHeader(http.StatusOK)
- w.(http.Flusher).Flush()
- body := &readCounter{r: r.Body}
- hash := sha256.New()
- ctx, cancel := context.WithCancel(r.Context())
- go func() {
- defer cancel()
- if _, err := io.Copy(hash, body); err != nil {
- t.Error(err)
- }
- }()
- // Send acks for received bytes.
- tick := time.NewTicker(time.Millisecond)
- defer tick.Stop()
- enc := json.NewEncoder(w)
- outer:
- for {
- select {
- case <-ctx.Done():
- break outer
- case <-tick.C:
- if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil {
- t.Errorf("writing ack frame: %v", err)
- break outer
- }
- }
- }
- uploadHash <- hash.Sum(nil)
- })
- // Probing HEAD endpoint which always returns 200 OK.
- mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
- return mux, uploadHash
- },
- },
- {
- desc: "v2 recorder no acks",
- http2: true,
- wantErr: true,
- setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
- // Make the client no-ack timeout quick for the test.
- oldAckWindow := uploadAckWindow
- uploadAckWindow = 100 * time.Millisecond
- t.Cleanup(func() { uploadAckWindow = oldAckWindow })
- uploadHash := make(chan []byte, 1)
- mux := http.NewServeMux()
- mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
- t.Error("received request to v1 endpoint")
- http.Error(w, "not found", http.StatusNotFound)
- })
- mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
- // Force the status to send to unblock the client waiting
- // for it.
- w.WriteHeader(http.StatusOK)
- w.(http.Flusher).Flush()
- // Consume the whole request body but don't send any acks
- // back.
- hash := sha256.New()
- if _, err := io.Copy(hash, r.Body); err != nil {
- t.Error(err)
- }
- // Goes in the channel buffer, non-blocking.
- uploadHash <- hash.Sum(nil)
- // Block until the parent test case ends to prevent the
- // request termination. We want to exercise the ack
- // tracking logic specifically.
- ctx, cancel := context.WithCancel(r.Context())
- t.Cleanup(cancel)
- <-ctx.Done()
- })
- mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
- return mux, uploadHash
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.desc, func(t *testing.T) {
- mux, uploadHash := tt.setup(t)
- memNet := &memnet.Network{}
- ln := memNet.NewLocalTCPListener()
- srv := &httptest.Server{
- Config: &http.Server{Handler: mux},
- Listener: ln,
- }
- if tt.http2 {
- // Wire up h2c-compatible HTTP/2 server. This is optional
- // because the v1 recorder didn't support HTTP/2 and we try to
- // mimic that.
- s := &http2.Server{}
- srv.Config.Handler = h2c.NewHandler(mux, s)
- if err := http2.ConfigureServer(srv.Config, s); err != nil {
- t.Errorf("configuring HTTP/2 support in server: %v", err)
- }
- }
- srv.Start()
- t.Cleanup(srv.Close)
- ctx := context.Background()
- w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(ln.Addr().String())}, memNet.Dial)
- if err != nil {
- t.Fatalf("ConnectToRecorder: %v", err)
- }
- // Send some random data and hash it to compare with the recorded
- // data hash.
- hash := sha256.New()
- const numBytes = 1 << 20 // 1MB
- if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil {
- t.Fatalf("writing recording data: %v", err)
- }
- if err := w.Close(); err != nil {
- t.Fatalf("closing recording stream: %v", err)
- }
- if err := <-errc; err != nil && !tt.wantErr {
- t.Fatalf("error from the channel: %v", err)
- } else if err == nil && tt.wantErr {
- t.Fatalf("did not receive expected error from the channel")
- }
- if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) {
- t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv)
- }
- })
- }
- }
- func TestSendEvent(t *testing.T) {
- t.Run("supported", func(t *testing.T) {
- eventBody := `{"foo":"bar"}`
- eventRecieved := make(chan []byte, 1)
- mux := http.NewServeMux()
- mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- })
- mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) {
- body, err := io.ReadAll(r.Body)
- if err != nil {
- t.Error(err)
- }
- eventRecieved <- body
- w.WriteHeader(http.StatusOK)
- })
- srv := httptest.NewUnstartedServer(mux)
- s := &http2.Server{}
- srv.Config.Handler = h2c.NewHandler(mux, s)
- if err := http2.ConfigureServer(srv.Config, s); err != nil {
- t.Fatalf("configuring HTTP/2 support in server: %v", err)
- }
- srv.Start()
- t.Cleanup(srv.Close)
- d := new(net.Dialer)
- addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
- err := SendEvent(addr, bytes.NewBufferString(eventBody), d.DialContext)
- if err != nil {
- t.Fatalf("SendEvent: %v", err)
- }
- if recv := string(<-eventRecieved); recv != eventBody {
- t.Errorf("mismatch in event body, sent %q, received %q", eventBody, recv)
- }
- })
- t.Run("not_supported", func(t *testing.T) {
- mux := http.NewServeMux()
- mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- })
- srv := httptest.NewUnstartedServer(mux)
- s := &http2.Server{}
- srv.Config.Handler = h2c.NewHandler(mux, s)
- if err := http2.ConfigureServer(srv.Config, s); err != nil {
- t.Fatalf("configuring HTTP/2 support in server: %v", err)
- }
- srv.Start()
- t.Cleanup(srv.Close)
- d := new(net.Dialer)
- addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
- err := SendEvent(addr, nil, d.DialContext)
- if err == nil {
- t.Fatal("expected an error, got nil")
- }
- if !strings.Contains(err.Error(), fmt.Sprintf(addressNotSupportEventv2, srv.Listener.Addr().String())) {
- t.Fatalf("unexpected error: %v", err)
- }
- })
- t.Run("server_error", func(t *testing.T) {
- mux := http.NewServeMux()
- mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- })
- mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusInternalServerError)
- })
- srv := httptest.NewUnstartedServer(mux)
- s := &http2.Server{}
- srv.Config.Handler = h2c.NewHandler(mux, s)
- if err := http2.ConfigureServer(srv.Config, s); err != nil {
- t.Fatalf("configuring HTTP/2 support in server: %v", err)
- }
- srv.Start()
- t.Cleanup(srv.Close)
- d := new(net.Dialer)
- addr := netip.MustParseAddrPort(srv.Listener.Addr().String())
- err := SendEvent(addr, nil, d.DialContext)
- if err == nil {
- t.Fatal("expected an error, got nil")
- }
- if !strings.Contains(err.Error(), "server returned non-OK status") {
- t.Fatalf("unexpected error: %v", err)
- }
- })
- }
|