Просмотр исходного кода

sessionrecording: implement v2 recording endpoint support (#14105)

The v2 endpoint supports HTTP/2 bidirectional streaming and acks for
received bytes. The acks are used to detect when a recorder disappears,
so that the session can be terminated more quickly.

Updates https://github.com/tailscale/corp/issues/24023

Signed-off-by: Andrew Lytvynov <[email protected]>
Andrew Lytvynov 1 год назад
Родитель
Commit
c2a7f17f2b

+ 1 - 1
k8s-operator/sessionrecording/hijacker.go

@@ -102,7 +102,7 @@ type Hijacker struct {
 // connection succeeds. In case of success, returns a list with a single
 // successful recording attempt and an error channel. If the connection errors
 // after having been established, an error is sent down the channel.
-type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
+type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
 
 // Hijack hijacks a 'kubectl exec' session and configures for the session
 // contents to be sent to a recorder.

+ 2 - 2
k8s-operator/sessionrecording/hijacker_test.go

@@ -10,7 +10,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/netip"
 	"net/url"
@@ -20,6 +19,7 @@ import (
 	"go.uber.org/zap"
 	"tailscale.com/client/tailscale/apitype"
 	"tailscale.com/k8s-operator/sessionrecording/fakes"
+	"tailscale.com/sessionrecording"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tsnet"
 	"tailscale.com/tstest"
@@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) {
 			h := &Hijacker{
 				connectToRecorder: func(context.Context,
 					[]netip.AddrPort,
-					func(context.Context, string, string) (net.Conn, error),
+					sessionrecording.DialFunc,
 				) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) {
 					if tt.failRecorderConnect {
 						err = errors.New("test")

+ 260 - 60
sessionrecording/connect.go

@@ -7,6 +7,8 @@ package sessionrecording
 
 import (
 	"context"
+	"crypto/tls"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -14,12 +16,33 @@ import (
 	"net/http"
 	"net/http/httptrace"
 	"net/netip"
+	"sync/atomic"
 	"time"
 
+	"golang.org/x/net/http2"
 	"tailscale.com/tailcfg"
+	"tailscale.com/util/httpm"
 	"tailscale.com/util/multierr"
 )
 
+const (
+	// Timeout for an individual DialFunc call for a single recorder address.
+	perDialAttemptTimeout = 5 * time.Second
+	// Timeout for the V2 API HEAD probe request (supportsV2).
+	http2ProbeTimeout = 10 * time.Second
+	// Maximum timeout for trying all available recorders, including V2 API
+	// probes and dial attempts.
+	allDialAttemptsTimeout = 30 * time.Second
+)
+
+// uploadAckWindow is the period of time to wait for an ackFrame from recorder
+// before terminating the connection. This is a variable to allow overriding it
+// in tests.
+var uploadAckWindow = 30 * time.Second
+
+// DialFunc is a function for dialing the recorder.
+type DialFunc func(ctx context.Context, network, host string) (net.Conn, error)
+
 // ConnectToRecorder connects to the recorder at any of the provided addresses.
 // It returns the first successful response, or a multierr if all attempts fail.
 //
@@ -32,19 +55,15 @@ import (
 // attempts are in order the recorder(s) was attempted. If a
 // successful connection is made, the last attempt in the slice is the
 // attempt for connected recorder.
-func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
+func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
 	if len(recs) == 0 {
 		return nil, nil, nil, errors.New("no recorders configured")
 	}
 	// We use a special context for dialing the recorder, so that we can
 	// limit the time we spend dialing to 30 seconds and still have an
 	// unbounded context for the upload.
-	dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
+	dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout)
 	defer dialCancel()
-	hc, err := SessionRecordingClientForDialer(dialCtx, dial)
-	if err != nil {
-		return nil, nil, nil, err
-	}
 
 	var errs []error
 	var attempts []*tailcfg.SSHRecordingAttempt
@@ -54,74 +73,230 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
 		}
 		attempts = append(attempts, attempt)
 
-		// We dial the recorder and wait for it to send a 100-continue
-		// response before returning from this function. This ensures that
-		// the recorder is ready to accept the recording.
-
-		// got100 is closed when we receive the 100-continue response.
-		got100 := make(chan struct{})
-		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
-			Got100Continue: func() {
-				close(got100)
-			},
-		})
-
-		pr, pw := io.Pipe()
-		req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
+		var pw io.WriteCloser
+		var errChan <-chan error
+		var err error
+		hc := clientHTTP2(dialCtx, dial)
+		// We need to probe V2 support using a separate HEAD request. Sending
+		// an HTTP/2 POST request to an HTTP/1 server will just "hang" until the
+		// request body is closed (instead of returning a 404 as one would
+		// expect). Sending a HEAD request without a body does not have that
+		// problem.
+		if supportsV2(ctx, hc, ap) {
+			pw, errChan, err = connectV2(ctx, hc, ap)
+		} else {
+			pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap)
+		}
 		if err != nil {
-			err = fmt.Errorf("recording: error starting recording: %w", err)
+			err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err)
 			attempt.FailureMessage = err.Error()
 			errs = append(errs, err)
 			continue
 		}
-		// We set the Expect header to 100-continue, so that the recorder
-		// will send a 100-continue response before it starts reading the
-		// request body.
-		req.Header.Set("Expect", "100-continue")
+		return pw, attempts, errChan, nil
+	}
+	return nil, attempts, nil, multierr.New(errs...)
+}
 
-		// errChan is used to indicate the result of the request.
-		errChan := make(chan error, 1)
-		go func() {
-			resp, err := hc.Do(req)
-			if err != nil {
-				errChan <- fmt.Errorf("recording: error starting recording: %w", err)
+// supportsV2 checks whether a recorder instance supports the /v2/record
+// endpoint.
+func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool {
+	ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil)
+	if err != nil {
+		return false
+	}
+	resp, err := hc.Do(req)
+	if err != nil {
+		return false
+	}
+	defer resp.Body.Close()
+	return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1
+}
+
+// connectV1 connects to the legacy /record endpoint on the recorder. It is
+// used for backwards-compatibility with older tsrecorder instances.
+//
+// On success, it returns a WriteCloser that can be used to upload the
+// recording, and a channel that will be sent an error (or nil) when the upload
+// fails or completes.
+func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
+	// We dial the recorder and wait for it to send a 100-continue
+	// response before returning from this function. This ensures that
+	// the recorder is ready to accept the recording.
+
+	// got100 is closed when we receive the 100-continue response.
+	got100 := make(chan struct{})
+	ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+		Got100Continue: func() {
+			close(got100)
+		},
+	})
+
+	pr, pw := io.Pipe()
+	req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr)
+	if err != nil {
+		return nil, nil, err
+	}
+	// We set the Expect header to 100-continue, so that the recorder
+	// will send a 100-continue response before it starts reading the
+	// request body.
+	req.Header.Set("Expect", "100-continue")
+
+	// errChan is used to indicate the result of the request.
+	errChan := make(chan error, 1)
+	go func() {
+		defer close(errChan)
+		resp, err := hc.Do(req)
+		if err != nil {
+			errChan <- err
+			return
+		}
+		defer resp.Body.Close()
+		if resp.StatusCode != 200 {
+			errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
+			return
+		}
+	}()
+	select {
+	case <-got100:
+		return pw, errChan, nil
+	case err := <-errChan:
+		// If we get an error before we get the 100-continue response,
+		// we need to try another recorder.
+		if err == nil {
+			// If the error is nil, we got a 200 response, which
+			// is unexpected as we haven't sent any data yet.
+			err = errors.New("recording: unexpected EOF")
+		}
+		return nil, nil, err
+	}
+}
+
+// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2.
+// It explicitly tracks ack frames sent in the response and terminates the
+// connection if sent recording data is un-acked for uploadAckWindow.
+//
+// On success, it returns a WriteCloser that can be used to upload the
+// recording, and a channel that will be sent an error (or nil) when the upload
+// fails or completes.
+func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
+	pr, pw := io.Pipe()
+	upload := &readCounter{r: pr}
+	req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// With HTTP/2, hc.Do will not block while the request body is being sent.
+	// It will return immediately and allow us to consume the response body at
+	// the same time.
+	resp, err := hc.Do(req)
+	if err != nil {
+		return nil, nil, err
+	}
+	if resp.StatusCode != http.StatusOK {
+		resp.Body.Close()
+		return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status)
+	}
+
+	errChan := make(chan error, 1)
+	acks := make(chan int64)
+	// Read acks from the response and send them to the acks channel.
+	go func() {
+		defer close(errChan)
+		defer close(acks)
+		defer resp.Body.Close()
+		defer pw.Close()
+		dec := json.NewDecoder(resp.Body)
+		for {
+			var frame v2ResponseFrame
+			if err := dec.Decode(&frame); err != nil {
+				if !errors.Is(err, io.EOF) {
+					errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err)
+				}
 				return
 			}
-			if resp.StatusCode != 200 {
-				errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
+			if frame.Error != "" {
+				errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error)
 				return
 			}
-			errChan <- nil
-		}()
-		select {
-		case <-got100:
-		case err := <-errChan:
-			// If we get an error before we get the 100-continue response,
-			// we need to try another recorder.
-			if err == nil {
-				// If the error is nil, we got a 200 response, which
-				// is unexpected as we haven't sent any data yet.
-				err = errors.New("recording: unexpected EOF")
+			select {
+			case acks <- frame.Ack:
+			case <-ctx.Done():
+				return
 			}
-			attempt.FailureMessage = err.Error()
-			errs = append(errs, err)
-			continue // try the next recorder
 		}
-		return pw, attempts, errChan, nil
-	}
-	return nil, attempts, nil, multierr.New(errs...)
+	}()
+	// Track acks from the acks channel.
+	go func() {
+		// Hack for tests: some tests modify uploadAckWindow and reset it when
+		// the test ends. This can race with t.Reset call below. Making a copy
+		// here is a lazy workaround to not wait for this goroutine to exit in
+		// the test cases.
+		uploadAckWindow := uploadAckWindow
+		// This timer fires if we didn't receive an ack for too long.
+		t := time.NewTimer(uploadAckWindow)
+		defer t.Stop()
+		for {
+			select {
+			case <-t.C:
+				// Close the pipe which terminates the connection and cleans up
+				// other goroutines. Note that tsrecorder will send us ack
+				// frames even if there is no new data to ack. This helps
+				// detect broken recorder connection if the session is idle.
+				pr.CloseWithError(errNoAcks)
+				resp.Body.Close()
+				return
+			case _, ok := <-acks:
+				if !ok {
+					// acks channel closed means that the goroutine reading them
+					// finished, which means that the request has ended.
+					return
+				}
+				// TODO(awly): limit how far behind the received acks can be. This
+				// should handle scenarios where a session suddenly dumps a lot of
+				// output.
+				t.Reset(uploadAckWindow)
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	return pw, errChan, nil
 }
 
-// SessionRecordingClientForDialer returns an http.Client that uses a clone of
-// the provided Dialer's PeerTransport to dial connections. This is used to make
-// requests to the session recording server to upload session recordings. It
-// uses the provided dialCtx to dial connections, and limits a single dial to 5
-// seconds.
-func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) {
-	tr := http.DefaultTransport.(*http.Transport).Clone()
+var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s")
+
+type v2ResponseFrame struct {
+	// Ack is the number of bytes received from the client so far. The bytes
+	// are not guaranteed to be durably stored yet.
+	Ack int64 `json:"ack,omitempty"`
+	// Error is an error encountered while storing the recording. Error is only
+	// ever set as the last frame in the response.
+	Error string `json:"error,omitempty"`
+}
 
+// readCounter is an io.Reader that counts how many bytes were read.
+type readCounter struct {
+	r    io.Reader
+	sent atomic.Int64
+}
+
+func (u *readCounter) Read(buf []byte) (int, error) {
+	n, err := u.r.Read(buf)
+	u.sent.Add(int64(n))
+	return n, err
+}
+
+// clientHTTP1 returns a classic http.Client with a per-dial context. It uses
+// dialCtx and adds a 5s timeout to it.
+func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client {
+	tr := http.DefaultTransport.(*http.Transport).Clone()
 	tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
-		perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+		perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout)
 		defer cancel()
 		go func() {
 			select {
@@ -132,7 +307,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.
 		}()
 		return dial(perAttemptCtx, network, addr)
 	}
+	return &http.Client{Transport: tr}
+}
+
+// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c
+// requests (HTTP/2 over plaintext). Unfortunately the same client does not
+// work for HTTP/1 so we need to split these up.
+func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client {
 	return &http.Client{
-		Transport: tr,
-	}, nil
+		Transport: &http2.Transport{
+			// Allow "http://" scheme in URLs.
+			AllowHTTP: true,
+			// Pretend like we're using TLS, but actually use the provided
+			// DialFunc underneath. This is necessary to convince the transport
+			// to actually dial.
+			DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
+				perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout)
+				defer cancel()
+				go func() {
+					select {
+					case <-perAttemptCtx.Done():
+					case <-dialCtx.Done():
+						cancel()
+					}
+				}()
+				return dial(perAttemptCtx, network, addr)
+			},
+		},
+	}
 }

+ 189 - 0
sessionrecording/connect_test.go

@@ -0,0 +1,189 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package sessionrecording
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/json"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/netip"
+	"testing"
+	"time"
+
+	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/h2c"
+)
+
+func TestConnectToRecorder(t *testing.T) {
+	tests := []struct {
+		desc  string
+		http2 bool
+		// setup returns a recorder server mux, and a channel which sends the
+		// hash of the recording uploaded to it. The channel is expected to
+		// fire only once.
+		setup   func(t *testing.T) (*http.ServeMux, <-chan []byte)
+		wantErr bool
+	}{
+		{
+			desc: "v1 recorder",
+			setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
+				uploadHash := make(chan []byte, 1)
+				mux := http.NewServeMux()
+				mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
+					hash := sha256.New()
+					if _, err := io.Copy(hash, r.Body); err != nil {
+						t.Error(err)
+					}
+					uploadHash <- hash.Sum(nil)
+				})
+				return mux, uploadHash
+			},
+		},
+		{
+			desc:  "v2 recorder",
+			http2: true,
+			setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
+				uploadHash := make(chan []byte, 1)
+				mux := http.NewServeMux()
+				mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
+					t.Error("received request to v1 endpoint")
+					http.Error(w, "not found", http.StatusNotFound)
+				})
+				mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
+					// Force the status to send to unblock the client waiting
+					// for it.
+					w.WriteHeader(http.StatusOK)
+					w.(http.Flusher).Flush()
+
+					body := &readCounter{r: r.Body}
+					hash := sha256.New()
+					ctx, cancel := context.WithCancel(r.Context())
+					go func() {
+						defer cancel()
+						if _, err := io.Copy(hash, body); err != nil {
+							t.Error(err)
+						}
+					}()
+
+					// Send acks for received bytes.
+					tick := time.NewTicker(time.Millisecond)
+					defer tick.Stop()
+					enc := json.NewEncoder(w)
+				outer:
+					for {
+						select {
+						case <-ctx.Done():
+							break outer
+						case <-tick.C:
+							if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil {
+								t.Errorf("writing ack frame: %v", err)
+								break outer
+							}
+						}
+					}
+
+					uploadHash <- hash.Sum(nil)
+				})
+				// Probing HEAD endpoint which always returns 200 OK.
+				mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
+				return mux, uploadHash
+			},
+		},
+		{
+			desc:    "v2 recorder no acks",
+			http2:   true,
+			wantErr: true,
+			setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
+				// Make the client no-ack timeout quick for the test.
+				oldAckWindow := uploadAckWindow
+				uploadAckWindow = 100 * time.Millisecond
+				t.Cleanup(func() { uploadAckWindow = oldAckWindow })
+
+				uploadHash := make(chan []byte, 1)
+				mux := http.NewServeMux()
+				mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
+					t.Error("received request to v1 endpoint")
+					http.Error(w, "not found", http.StatusNotFound)
+				})
+				mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
+					// Force the status to send to unblock the client waiting
+					// for it.
+					w.WriteHeader(http.StatusOK)
+					w.(http.Flusher).Flush()
+
+					// Consume the whole request body but don't send any acks
+					// back.
+					hash := sha256.New()
+					if _, err := io.Copy(hash, r.Body); err != nil {
+						t.Error(err)
+					}
+					// Goes in the channel buffer, non-blocking.
+					uploadHash <- hash.Sum(nil)
+
+					// Block until the parent test case ends to prevent the
+					// request termination. We want to exercise the ack
+					// tracking logic specifically.
+					ctx, cancel := context.WithCancel(r.Context())
+					t.Cleanup(cancel)
+					<-ctx.Done()
+				})
+				mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
+				return mux, uploadHash
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			mux, uploadHash := tt.setup(t)
+
+			srv := httptest.NewUnstartedServer(mux)
+			if tt.http2 {
+				// Wire up h2c-compatible HTTP/2 server. This is optional
+				// because the v1 recorder didn't support HTTP/2 and we try to
+				// mimic that.
+				h2s := &http2.Server{}
+				srv.Config.Handler = h2c.NewHandler(mux, h2s)
+				if err := http2.ConfigureServer(srv.Config, h2s); err != nil {
+					t.Errorf("configuring HTTP/2 support in server: %v", err)
+				}
+			}
+			srv.Start()
+			t.Cleanup(srv.Close)
+
+			d := new(net.Dialer)
+
+			ctx := context.Background()
+			w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext)
+			if err != nil {
+				t.Fatalf("ConnectToRecorder: %v", err)
+			}
+
+			// Send some random data and hash it to compare with the recorded
+			// data hash.
+			hash := sha256.New()
+			const numBytes = 1 << 20 // 1MB
+			if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil {
+				t.Fatalf("writing recording data: %v", err)
+			}
+			if err := w.Close(); err != nil {
+				t.Fatalf("closing recording stream: %v", err)
+			}
+			if err := <-errc; err != nil && !tt.wantErr {
+				t.Fatalf("error from the channel: %v", err)
+			} else if err == nil && tt.wantErr {
+				t.Fatalf("did not receive expected error from the channel")
+			}
+
+			if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) {
+				t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv)
+			}
+		})
+	}
+}

+ 9 - 4
ssh/tailssh/tailssh.go

@@ -1170,7 +1170,7 @@ func (ss *sshSession) run() {
 		if err != nil && !errors.Is(err, io.EOF) {
 			isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
 			if !isErrBecauseProcessExited {
-				logf("stdout copy: %v, %T", err)
+				logf("stdout copy: %v", err)
 				ss.cancelCtx(err)
 			}
 		}
@@ -1520,9 +1520,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
 		go func() {
 			err := <-errChan
 			if err == nil {
-				// Success.
-				ss.logf("recording: finished uploading recording")
-				return
+				select {
+				case <-ss.ctx.Done():
+					// Success.
+					ss.logf("recording: finished uploading recording")
+					return
+				default:
+					err = errors.New("recording upload ended before the SSH session")
+				}
 			}
 			if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 {
 				lastAttempt := attempts[len(attempts)-1]

+ 39 - 22
ssh/tailssh/tailssh_test.go

@@ -33,6 +33,8 @@ import (
 	"time"
 
 	gossh "github.com/tailscale/golang-x-crypto/ssh"
+	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/h2c"
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/ipn/store/mem"
 	"tailscale.com/net/memnet"
@@ -481,10 +483,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
 	}
 
 	var handler http.HandlerFunc
-	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
 		handler(w, r)
-	}))
-	defer recordingServer.Close()
+	})
 
 	s := &server{
 		logf: t.Logf,
@@ -533,9 +534,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
 		{
 			name: "upload-fails-after-starting",
 			handler: func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+				w.(http.Flusher).Flush()
 				r.Body.Read(make([]byte, 1))
 				time.Sleep(100 * time.Millisecond)
-				w.WriteHeader(http.StatusInternalServerError)
 			},
 			sshCommand:       "echo hello && sleep 1 && echo world",
 			wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
@@ -548,6 +550,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			s.logf = t.Logf
 			tstest.Replace(t, &handler, tt.handler)
 			sc, dc := memnet.NewTCPConn(src, dst, 1024)
 			var wg sync.WaitGroup
@@ -597,12 +600,12 @@ func TestMultipleRecorders(t *testing.T) {
 		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
 	}
 	done := make(chan struct{})
-	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
 		defer close(done)
-		io.ReadAll(r.Body)
 		w.WriteHeader(http.StatusOK)
-	}))
-	defer recordingServer.Close()
+		w.(http.Flusher).Flush()
+		io.ReadAll(r.Body)
+	})
 	badRecorder, err := net.Listen("tcp", ":0")
 	if err != nil {
 		t.Fatal(err)
@@ -610,15 +613,9 @@ func TestMultipleRecorders(t *testing.T) {
 	badRecorderAddr := badRecorder.Addr().String()
 	badRecorder.Close()
 
-	badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(500)
-	}))
-	defer badRecordingServer500.Close()
-
-	badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(200)
-	}))
-	defer badRecordingServer200.Close()
+	badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+	})
 
 	s := &server{
 		logf: t.Logf,
@@ -630,7 +627,6 @@ func TestMultipleRecorders(t *testing.T) {
 					Recorders: []netip.AddrPort{
 						netip.MustParseAddrPort(badRecorderAddr),
 						netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
-						netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
 						netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
 					},
 					OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
@@ -701,19 +697,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
 	}
 	var recording []byte
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
 		defer cancel()
+		w.WriteHeader(http.StatusOK)
+		w.(http.Flusher).Flush()
+
 		var err error
 		recording, err = io.ReadAll(r.Body)
 		if err != nil {
 			t.Error(err)
 			return
 		}
-	}))
-	defer recordingServer.Close()
+	})
 
 	s := &server{
-		logf: logger.Discard,
+		logf: t.Logf,
 		lb: &localState{
 			sshEnabled: true,
 			matchingRule: newSSHRule(
@@ -1299,3 +1297,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
 		t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
 	}
 }
+
+func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
+	t.Helper()
+	mux := http.NewServeMux()
+	mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) {
+		t.Errorf("v1 recording endpoint called")
+	})
+	mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
+	mux.HandleFunc("POST /v2/record", handleRecord)
+
+	h2s := &http2.Server{}
+	srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s))
+	if err := http2.ConfigureServer(srv.Config, h2s); err != nil {
+		t.Errorf("configuring HTTP/2 support in recording server: %v", err)
+	}
+	srv.Start()
+	t.Cleanup(srv.Close)
+	return srv
+}