Browse Source

k8s-operator,sessionrecording: fixing race condition between resize (#16454)

messages and cast headers when recording `kubectl attach` sessions

Updates #16490

Signed-off-by: chaosinthecrd <[email protected]>
Tom Meadows 7 months ago
parent
commit
bcaea4f245

+ 34 - 17
k8s-operator/api-proxy/proxy.go

@@ -22,6 +22,7 @@ import (
 	"k8s.io/client-go/transport"
 	"tailscale.com/client/local"
 	"tailscale.com/client/tailscale/apitype"
+	"tailscale.com/k8s-operator/sessionrecording"
 	ksr "tailscale.com/k8s-operator/sessionrecording"
 	"tailscale.com/kube/kubetypes"
 	"tailscale.com/tailcfg"
@@ -49,6 +50,7 @@ func NewAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, ts *tsn
 	if !authMode {
 		restConfig = rest.AnonymousClientConfig(restConfig)
 	}
+
 	cfg, err := restConfig.TransportConfig()
 	if err != nil {
 		return nil, fmt.Errorf("could not get rest.TransportConfig(): %w", err)
@@ -111,6 +113,8 @@ func (ap *APIServerProxy) Run(ctx context.Context) error {
 	mux.HandleFunc("/", ap.serveDefault)
 	mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY)
 	mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS)
+	mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachSPDY)
+	mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachWS)
 
 	ap.hs = &http.Server{
 		// Kubernetes uses SPDY for exec and port-forward, however SPDY is
@@ -165,19 +169,31 @@ func (ap *APIServerProxy) serveDefault(w http.ResponseWriter, r *http.Request) {
 	ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
 }
 
-// serveExecSPDY serves 'kubectl exec' requests for sessions streamed over SPDY,
+// serveExecSPDY serves '/exec' requests for sessions streamed over SPDY,
 // optionally configuring the kubectl exec sessions to be recorded.
 func (ap *APIServerProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) {
-	ap.execForProto(w, r, ksr.SPDYProtocol)
+	ap.sessionForProto(w, r, ksr.ExecSessionType, ksr.SPDYProtocol)
 }
 
-// serveExecWS serves 'kubectl exec' requests for sessions streamed over WebSocket,
+// serveExecWS serves '/exec' requests for sessions streamed over WebSocket,
 // optionally configuring the kubectl exec sessions to be recorded.
 func (ap *APIServerProxy) serveExecWS(w http.ResponseWriter, r *http.Request) {
-	ap.execForProto(w, r, ksr.WSProtocol)
+	ap.sessionForProto(w, r, ksr.ExecSessionType, ksr.WSProtocol)
+}
+
+// serveExecSPDY serves '/attach' requests for sessions streamed over SPDY,
+// optionally configuring the kubectl exec sessions to be recorded.
+func (ap *APIServerProxy) serveAttachSPDY(w http.ResponseWriter, r *http.Request) {
+	ap.sessionForProto(w, r, ksr.AttachSessionType, ksr.SPDYProtocol)
+}
+
+// serveExecWS serves '/attach' requests for sessions streamed over WebSocket,
+// optionally configuring the kubectl exec sessions to be recorded.
+func (ap *APIServerProxy) serveAttachWS(w http.ResponseWriter, r *http.Request) {
+	ap.sessionForProto(w, r, ksr.AttachSessionType, ksr.WSProtocol)
 }
 
-func (ap *APIServerProxy) execForProto(w http.ResponseWriter, r *http.Request, proto ksr.Protocol) {
+func (ap *APIServerProxy) sessionForProto(w http.ResponseWriter, r *http.Request, sessionType sessionrecording.SessionType, proto ksr.Protocol) {
 	const (
 		podNameKey       = "pod"
 		namespaceNameKey = "namespace"
@@ -192,7 +208,7 @@ func (ap *APIServerProxy) execForProto(w http.ResponseWriter, r *http.Request, p
 	counterNumRequestsProxied.Add(1)
 	failOpen, addrs, err := determineRecorderConfig(who)
 	if err != nil {
-		ap.log.Errorf("error trying to determine whether the 'kubectl exec' session needs to be recorded: %v", err)
+		ap.log.Errorf("error trying to determine whether the 'kubectl %s' session needs to be recorded: %v", sessionType, err)
 		return
 	}
 	if failOpen && len(addrs) == 0 { // will not record
@@ -201,7 +217,7 @@ func (ap *APIServerProxy) execForProto(w http.ResponseWriter, r *http.Request, p
 	}
 	ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
 	if !failOpen && len(addrs) == 0 {
-		msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available."
+		msg := fmt.Sprintf("forbidden: 'kubectl %s' session must be recorded, but no recorders are available.", sessionType)
 		ap.log.Error(msg)
 		http.Error(w, msg, http.StatusForbidden)
 		return
@@ -223,16 +239,17 @@ func (ap *APIServerProxy) execForProto(w http.ResponseWriter, r *http.Request, p
 	}
 
 	opts := ksr.HijackerOpts{
-		Req:       r,
-		W:         w,
-		Proto:     proto,
-		TS:        ap.ts,
-		Who:       who,
-		Addrs:     addrs,
-		FailOpen:  failOpen,
-		Pod:       r.PathValue(podNameKey),
-		Namespace: r.PathValue(namespaceNameKey),
-		Log:       ap.log,
+		Req:         r,
+		W:           w,
+		Proto:       proto,
+		SessionType: sessionType,
+		TS:          ap.ts,
+		Who:         who,
+		Addrs:       addrs,
+		FailOpen:    failOpen,
+		Pod:         r.PathValue(podNameKey),
+		Namespace:   r.PathValue(namespaceNameKey),
+		Log:         ap.log,
 	}
 	h := ksr.New(opts)
 

+ 9 - 3
k8s-operator/sessionrecording/fakes/fakes.go

@@ -10,13 +10,13 @@ package fakes
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
+	"math/rand"
 	"net"
 	"sync"
 	"testing"
 	"time"
 
-	"math/rand"
-
 	"tailscale.com/sessionrecording"
 	"tailscale.com/tstime"
 )
@@ -107,7 +107,13 @@ func CastLine(t *testing.T, p []byte, clock tstime.Clock) []byte {
 	return append(j, '\n')
 }
 
-func AsciinemaResizeMsg(t *testing.T, width, height int) []byte {
+func AsciinemaCastResizeMsg(t *testing.T, width, height int) []byte {
+	msg := fmt.Sprintf(`[0,"r","%dx%d"]`, height, width)
+
+	return append([]byte(msg), '\n')
+}
+
+func AsciinemaCastHeaderMsg(t *testing.T, width, height int) []byte {
 	t.Helper()
 	ch := sessionrecording.CastHeader{
 		Width:  width,

+ 39 - 24
k8s-operator/sessionrecording/hijacker.go

@@ -4,7 +4,7 @@
 //go:build !plan9
 
 // Package sessionrecording contains functionality for recording Kubernetes API
-// server proxy 'kubectl exec' sessions.
+// server proxy 'kubectl exec/attach' sessions.
 package sessionrecording
 
 import (
@@ -35,14 +35,20 @@ import (
 )
 
 const (
-	SPDYProtocol Protocol = "SPDY"
-	WSProtocol   Protocol = "WebSocket"
+	SPDYProtocol      Protocol    = "SPDY"
+	WSProtocol        Protocol    = "WebSocket"
+	ExecSessionType   SessionType = "exec"
+	AttachSessionType SessionType = "attach"
 )
 
 // Protocol is the streaming protocol of the hijacked session. Supported
 // protocols are SPDY and WebSocket.
 type Protocol string
 
+// SessionType is the type of session initiated with `kubectl`
+// (`exec` or `attach`)
+type SessionType string
+
 var (
 	// CounterSessionRecordingsAttempted counts the number of session recording attempts.
 	CounterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_attempted")
@@ -63,25 +69,27 @@ func New(opts HijackerOpts) *Hijacker {
 		failOpen:          opts.FailOpen,
 		proto:             opts.Proto,
 		log:               opts.Log,
+		sessionType:       opts.SessionType,
 		connectToRecorder: sessionrecording.ConnectToRecorder,
 	}
 }
 
 type HijackerOpts struct {
-	TS        *tsnet.Server
-	Req       *http.Request
-	W         http.ResponseWriter
-	Who       *apitype.WhoIsResponse
-	Addrs     []netip.AddrPort
-	Log       *zap.SugaredLogger
-	Pod       string
-	Namespace string
-	FailOpen  bool
-	Proto     Protocol
+	TS          *tsnet.Server
+	Req         *http.Request
+	W           http.ResponseWriter
+	Who         *apitype.WhoIsResponse
+	Addrs       []netip.AddrPort
+	Log         *zap.SugaredLogger
+	Pod         string
+	Namespace   string
+	FailOpen    bool
+	Proto       Protocol
+	SessionType SessionType
 }
 
 // Hijacker implements [net/http.Hijacker] interface.
-// It must be configured with an http request for a 'kubectl exec' session that
+// It must be configured with an http request for a 'kubectl exec/attach' session that
 // needs to be recorded. It knows how to hijack the connection and configure for
 // the session contents to be sent to a tsrecorder instance.
 type Hijacker struct {
@@ -90,12 +98,13 @@ type Hijacker struct {
 	req               *http.Request
 	who               *apitype.WhoIsResponse
 	log               *zap.SugaredLogger
-	pod               string           // pod being exec-d
-	ns                string           // namespace of the pod being exec-d
+	pod               string           // pod being exec/attach-d
+	ns                string           // namespace of the pod being exec/attach-d
 	addrs             []netip.AddrPort // tsrecorder addresses
 	failOpen          bool             // whether to fail open if recording fails
 	connectToRecorder RecorderDialFn
-	proto             Protocol // streaming protocol
+	proto             Protocol    // streaming protocol
+	sessionType       SessionType // subcommand, e.g., "exec, attach"
 }
 
 // RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder
@@ -105,7 +114,7 @@ type Hijacker struct {
 // after having been established, an error is sent down the channel.
 type RecorderDialFn func(context.Context, []netip.AddrPort, netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
 
-// Hijack hijacks a 'kubectl exec' session and configures for the session
+// Hijack hijacks a 'kubectl exec/attach' session and configures for the session
 // contents to be sent to a recorder.
 func (h *Hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 	h.log.Infof("recorder addrs: %v, failOpen: %v", h.addrs, h.failOpen)
@@ -114,7 +123,7 @@ func (h *Hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 		return nil, nil, fmt.Errorf("error hijacking connection: %w", err)
 	}
 
-	conn, err := h.setUpRecording(context.Background(), reqConn)
+	conn, err := h.setUpRecording(h.req.Context(), reqConn)
 	if err != nil {
 		return nil, nil, fmt.Errorf("error setting up session recording: %w", err)
 	}
@@ -138,7 +147,7 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
 		err     error
 		errChan <-chan error
 	)
-	h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen)
+	h.log.Infof("kubectl %s session will be recorded, recorders: %v, fail open policy: %t", h.sessionType, h.addrs, h.failOpen)
 	qp := h.req.URL.Query()
 	container := strings.Join(qp[containerKey], "")
 	var recorderAddr net.Addr
@@ -161,7 +170,7 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
 		}
 		return nil, errors.New(msg)
 	} else {
-		h.log.Infof("exec session to container %q in Pod %q namespace %q will be recorded, the recording will be sent to a tsrecorder instance at %q", container, h.pod, h.ns, recorderAddr)
+		h.log.Infof("%s session to container %q in Pod %q namespace %q will be recorded, the recording will be sent to a tsrecorder instance at %q", h.sessionType, container, h.pod, h.ns, recorderAddr)
 	}
 
 	cl := tstime.DefaultClock{}
@@ -190,9 +199,15 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
 	var lc net.Conn
 	switch h.proto {
 	case SPDYProtocol:
-		lc = spdy.New(conn, rec, ch, hasTerm, h.log)
+		lc, err = spdy.New(ctx, conn, rec, ch, hasTerm, h.log)
+		if err != nil {
+			return nil, fmt.Errorf("failed to initialize spdy connection: %w", err)
+		}
 	case WSProtocol:
-		lc = ws.New(conn, rec, ch, hasTerm, h.log)
+		lc, err = ws.New(ctx, conn, rec, ch, hasTerm, h.log)
+		if err != nil {
+			return nil, fmt.Errorf("failed to initialize websocket connection: %w", err)
+		}
 	default:
 		return nil, fmt.Errorf("unknown protocol: %s", h.proto)
 	}
@@ -209,7 +224,7 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
 			h.log.Info("finished uploading the recording")
 			return
 		}
-		msg := fmt.Sprintf("connection to the session recorder errorred: %v;", err)
+		msg := fmt.Sprintf("connection to the session recorder errored: %v;", err)
 		if h.failOpen {
 			msg += msg + "; failure mode is 'fail open'; continuing session without recording."
 			h.log.Info(msg)

+ 1 - 1
k8s-operator/sessionrecording/hijacker_test.go

@@ -91,7 +91,7 @@ func Test_Hijacker(t *testing.T) {
 				who:      &apitype.WhoIsResponse{Node: &tailcfg.Node{}, UserProfile: &tailcfg.UserProfile{}},
 				log:      zl.Sugar(),
 				ts:       &tsnet.Server{},
-				req:      &http.Request{URL: &url.URL{}},
+				req:      &http.Request{URL: &url.URL{RawQuery: "tty=true"}},
 				proto:    tt.proto,
 			}
 			ctx := context.Background()

+ 57 - 41
k8s-operator/sessionrecording/spdy/conn.go

@@ -4,11 +4,12 @@
 //go:build !plan9
 
 // Package spdy contains functionality for parsing SPDY streaming sessions. This
-// is used for 'kubectl exec' session recording.
+// is used for 'kubectl exec/attach' session recording.
 package spdy
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
@@ -24,29 +25,50 @@ import (
 )
 
 // New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
-// The connection must be a hijacked connection for a 'kubectl exec' session using SPDY.
+// The connection must be a hijacked connection for a 'kubectl exec/attach' session using SPDY.
 // The hijacked connection is used to transmit SPDY streams between Kubernetes client ('kubectl') and the destination container.
 // Data read from the underlying network connection is data sent via one of the SPDY streams from the client to the container.
 // Data written to the underlying connection is data sent from the container to the client.
 // We parse the data and send everything for the stdout/stderr streams to the configured tsrecorder as an asciinema recording with the provided header.
 // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#background-remotecommand-subprotocol
-func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) net.Conn {
-	return &conn{
-		Conn:               nc,
-		rec:                rec,
-		ch:                 ch,
-		log:                log,
-		hasTerm:            hasTerm,
-		initialTermSizeSet: make(chan struct{}),
+func New(ctx context.Context, nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) (net.Conn, error) {
+	lc := &conn{
+		Conn:                  nc,
+		ctx:                   ctx,
+		rec:                   rec,
+		ch:                    ch,
+		log:                   log,
+		hasTerm:               hasTerm,
+		initialCastHeaderSent: make(chan struct{}, 1),
 	}
+
+	// if there is no term, we don't need to wait for a resize message
+	if !hasTerm {
+		var err error
+		lc.writeCastHeaderOnce.Do(func() {
+			// If this is a session with a terminal attached,
+			// we must wait for the terminal width and
+			// height to be parsed from a resize message
+			// before sending CastHeader, else tsrecorder
+			// will not be able to play this recording.
+			err = lc.rec.WriteCastHeader(ch)
+			close(lc.initialCastHeaderSent)
+		})
+		if err != nil {
+			return nil, fmt.Errorf("error writing CastHeader: %w", err)
+		}
+	}
+
+	return lc, nil
 }
 
 // conn is a wrapper around net.Conn. It reads the bytestream for a 'kubectl
-// exec' session streamed using SPDY protocol, sends session recording data to
+// exec/attach' session streamed using SPDY protocol, sends session recording data to
 // the configured recorder and forwards the raw bytes to the original
 // destination.
 type conn struct {
 	net.Conn
+	ctx context.Context
 	// rec knows how to send data written to it to a tsrecorder instance.
 	rec *tsrecorder.Client
 
@@ -63,7 +85,7 @@ type conn struct {
 	// CastHeader must be sent before any payload. If the session has a
 	// terminal attached, the CastHeader must have '.Width' and '.Height'
 	// fields set for the tsrecorder UI to be able to play the recording.
-	// For 'kubectl exec' sessions, terminal width and height are sent as a
+	// For 'kubectl exec/attach' sessions, terminal width and height are sent as a
 	// resize message on resize stream from the client when the session
 	// starts as well as at any time the client detects a terminal change.
 	// We can intercept the resize message on Read calls. As there is no
@@ -79,15 +101,10 @@ type conn struct {
 	// writeCastHeaderOnce is used to ensure CastHeader gets sent to tsrecorder once.
 	writeCastHeaderOnce sync.Once
 	hasTerm             bool // whether the session had TTY attached
-	// initialTermSizeSet channel gets sent a value once, when the Read has
-	// received a resize message and set the initial terminal size. It must
-	// be set to a buffered channel to prevent Reads being blocked on the
-	// first stdout/stderr write reading from the channel.
-	initialTermSizeSet chan struct{}
-	// sendInitialTermSizeSetOnce is used to ensure that a value is sent to
-	// initialTermSizeSet channel only once, when the initial resize message
-	// is received.
-	sendinitialTermSizeSetOnce sync.Once
+	// initialCastHeaderSent is a channel to ensure that the cast
+	// header is the first thing that is streamed to the session recorder.
+	// Otherwise the stream will fail.
+	initialCastHeaderSent chan struct{}
 
 	zlibReqReader zlibReader
 	// writeBuf is used to store data written to the connection that has not
@@ -124,7 +141,7 @@ func (c *conn) Read(b []byte) (int, error) {
 	}
 	c.readBuf.Next(len(sf.Raw)) // advance buffer past the parsed frame
 
-	if !sf.Ctrl { // data frame
+	if !sf.Ctrl && c.hasTerm { // data frame
 		switch sf.StreamID {
 		case c.resizeStreamID.Load():
 
@@ -140,10 +157,19 @@ func (c *conn) Read(b []byte) (int, error) {
 			// subsequent resize message, we need to send asciinema
 			// resize message.
 			var isInitialResize bool
-			c.sendinitialTermSizeSetOnce.Do(func() {
+			c.writeCastHeaderOnce.Do(func() {
 				isInitialResize = true
-				close(c.initialTermSizeSet) // unblock sending of CastHeader
+				// If this is a session with a terminal attached,
+				// we must wait for the terminal width and
+				// height to be parsed from a resize message
+				// before sending CastHeader, else tsrecorder
+				// will not be able to play this recording.
+				err = c.rec.WriteCastHeader(c.ch)
+				close(c.initialCastHeaderSent)
 			})
+			if err != nil {
+				return 0, fmt.Errorf("error writing CastHeader: %w", err)
+			}
 			if !isInitialResize {
 				if err := c.rec.WriteResize(c.ch.Height, c.ch.Width); err != nil {
 					return 0, fmt.Errorf("error writing resize message: %w", err)
@@ -190,24 +216,14 @@ func (c *conn) Write(b []byte) (int, error) {
 	if !sf.Ctrl {
 		switch sf.StreamID {
 		case c.stdoutStreamID.Load(), c.stderrStreamID.Load():
-			var err error
-			c.writeCastHeaderOnce.Do(func() {
-				// If this is a session with a terminal attached,
-				// we must wait for the terminal width and
-				// height to be parsed from a resize message
-				// before sending CastHeader, else tsrecorder
-				// will not be able to play this recording.
-				if c.hasTerm {
-					c.log.Debugf("write: waiting for the initial terminal size to be set before proceeding with sending the first payload")
-					<-c.initialTermSizeSet
+			// we must wait for confirmation that the initial cast header was sent before proceeding with any more writes
+			select {
+			case <-c.ctx.Done():
+				return 0, c.ctx.Err()
+			case <-c.initialCastHeaderSent:
+				if err := c.rec.WriteOutput(sf.Payload); err != nil {
+					return 0, fmt.Errorf("error sending payload to session recorder: %w", err)
 				}
-				err = c.rec.WriteCastHeader(c.ch)
-			})
-			if err != nil {
-				return 0, fmt.Errorf("error writing CastHeader: %w", err)
-			}
-			if err := c.rec.WriteOutput(sf.Payload); err != nil {
-				return 0, fmt.Errorf("error sending payload to session recorder: %w", err)
 			}
 		}
 	}

+ 55 - 43
k8s-operator/sessionrecording/spdy/conn_test.go

@@ -6,10 +6,12 @@
 package spdy
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"reflect"
 	"testing"
+	"time"
 
 	"go.uber.org/zap"
 	"tailscale.com/k8s-operator/sessionrecording/fakes"
@@ -29,15 +31,11 @@ func Test_Writes(t *testing.T) {
 	}
 	cl := tstest.NewClock(tstest.ClockOpts{})
 	tests := []struct {
-		name              string
-		inputs            [][]byte
-		wantForwarded     []byte
-		wantRecorded      []byte
-		firstWrite        bool
-		width             int
-		height            int
-		sendInitialResize bool
-		hasTerm           bool
+		name          string
+		inputs        [][]byte
+		wantForwarded []byte
+		wantRecorded  []byte
+		hasTerm       bool
 	}{
 		{
 			name:          "single_write_control_frame_with_payload",
@@ -78,24 +76,17 @@ func Test_Writes(t *testing.T) {
 			wantRecorded:  fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
 		},
 		{
-			name:              "single_first_write_stdout_data_frame_with_payload_sess_has_terminal",
-			inputs:            [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
-			wantForwarded:     []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
-			wantRecorded:      append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...),
-			width:             10,
-			height:            20,
-			hasTerm:           true,
-			firstWrite:        true,
-			sendInitialResize: true,
+			name:          "single_first_write_stdout_data_frame_with_payload_sess_has_terminal",
+			inputs:        [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
+			wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
+			wantRecorded:  fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+			hasTerm:       true,
 		},
 		{
 			name:          "single_first_write_stdout_data_frame_with_payload_sess_does_not_have_terminal",
 			inputs:        [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}},
 			wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5},
-			wantRecorded:  append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...),
-			width:         10,
-			height:        20,
-			firstWrite:    true,
+			wantRecorded:  fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl),
 		},
 	}
 	for _, tt := range tests {
@@ -104,29 +95,25 @@ func Test_Writes(t *testing.T) {
 			sr := &fakes.TestSessionRecorder{}
 			rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
 
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
 			c := &conn{
-				Conn: tc,
-				log:  zl.Sugar(),
-				rec:  rec,
-				ch: sessionrecording.CastHeader{
-					Width:  tt.width,
-					Height: tt.height,
-				},
-				initialTermSizeSet: make(chan struct{}),
-				hasTerm:            tt.hasTerm,
-			}
-			if !tt.firstWrite {
-				// this test case does not intend to test that cast header gets written once
-				c.writeCastHeaderOnce.Do(func() {})
-			}
-			if tt.sendInitialResize {
-				close(c.initialTermSizeSet)
+				ctx:                   ctx,
+				Conn:                  tc,
+				log:                   zl.Sugar(),
+				rec:                   rec,
+				ch:                    sessionrecording.CastHeader{},
+				initialCastHeaderSent: make(chan struct{}),
+				hasTerm:               tt.hasTerm,
 			}
 
+			c.writeCastHeaderOnce.Do(func() {
+				close(c.initialCastHeaderSent)
+			})
+
 			c.stdoutStreamID.Store(stdoutStreamID)
 			c.stderrStreamID.Store(stderrStreamID)
 			for i, input := range tt.inputs {
-				c.hasTerm = tt.hasTerm
 				if _, err := c.Write(input); err != nil {
 					t.Errorf("[%d] spdyRemoteConnRecorder.Write() unexpected error %v", i, err)
 				}
@@ -171,11 +158,25 @@ func Test_Reads(t *testing.T) {
 		wantResizeStreamID       uint32
 		wantWidth                int
 		wantHeight               int
+		wantRecorded             []byte
 		resizeStreamIDBeforeRead uint32
 	}{
 		{
 			name:                     "resize_data_frame_single_read",
 			inputs:                   [][]byte{append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...)},
+			wantRecorded:             fakes.AsciinemaCastHeaderMsg(t, 10, 20),
+			resizeStreamIDBeforeRead: 1,
+			wantWidth:                10,
+			wantHeight:               20,
+		},
+		{
+			name: "resize_data_frame_many",
+			inputs: [][]byte{
+				append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...),
+				append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...),
+			},
+			wantRecorded: append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...),
+
 			resizeStreamIDBeforeRead: 1,
 			wantWidth:                10,
 			wantHeight:               20,
@@ -183,6 +184,7 @@ func Test_Reads(t *testing.T) {
 		{
 			name:                     "resize_data_frame_two_reads",
 			inputs:                   [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg},
+			wantRecorded:             fakes.AsciinemaCastHeaderMsg(t, 10, 20),
 			resizeStreamIDBeforeRead: 1,
 			wantWidth:                10,
 			wantHeight:               20,
@@ -215,11 +217,15 @@ func Test_Reads(t *testing.T) {
 			tc := &fakes.TestConn{}
 			sr := &fakes.TestSessionRecorder{}
 			rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
 			c := &conn{
-				Conn:               tc,
-				log:                zl.Sugar(),
-				rec:                rec,
-				initialTermSizeSet: make(chan struct{}),
+				ctx:                   ctx,
+				Conn:                  tc,
+				log:                   zl.Sugar(),
+				rec:                   rec,
+				initialCastHeaderSent: make(chan struct{}),
+				hasTerm:               true,
 			}
 			c.resizeStreamID.Store(tt.resizeStreamIDBeforeRead)
 
@@ -251,6 +257,12 @@ func Test_Reads(t *testing.T) {
 					t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
 				}
 			}
+
+			// Assert that the expected bytes have been forwarded to the session recorder.
+			gotRecorded := sr.Bytes()
+			if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
+				t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded)
+			}
 		})
 	}
 }

+ 1 - 0
k8s-operator/sessionrecording/tsrecorder/tsrecorder.go

@@ -25,6 +25,7 @@ func New(conn io.WriteCloser, clock tstime.Clock, start time.Time, failOpen bool
 		clock:    clock,
 		conn:     conn,
 		failOpen: failOpen,
+		logger:   logger,
 	}
 }
 

+ 68 - 47
k8s-operator/sessionrecording/ws/conn.go

@@ -3,12 +3,13 @@
 
 //go:build !plan9
 
-// package ws has functionality to parse 'kubectl exec' sessions streamed using
+// package ws has functionality to parse 'kubectl exec/attach' sessions streamed using
 // WebSocket protocol.
 package ws
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -24,31 +25,53 @@ import (
 )
 
 // New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
-// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
+// The connection must be a hijacked connection for a 'kubectl exec/attach' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
 // The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes.
 // Data read from the underlying network connection is data sent via one of the streams from the client to the container.
 // Data written to the underlying connection is data sent from the container to the client.
 // We parse the data and send everything for the stdout/stderr streams to the configured tsrecorder as an asciinema recording with the provided header.
 // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio
-func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) net.Conn {
-	return &conn{
-		Conn:               c,
-		rec:                rec,
-		ch:                 ch,
-		hasTerm:            hasTerm,
-		log:                log,
-		initialTermSizeSet: make(chan struct{}, 1),
+func New(ctx context.Context, c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) (net.Conn, error) {
+	lc := &conn{
+		Conn:                  c,
+		ctx:                   ctx,
+		rec:                   rec,
+		ch:                    ch,
+		hasTerm:               hasTerm,
+		log:                   log,
+		initialCastHeaderSent: make(chan struct{}, 1),
 	}
+
+	// if there is no term, we don't need to wait for a resize message
+	if !hasTerm {
+		var err error
+		lc.writeCastHeaderOnce.Do(func() {
+			// If this is a session with a terminal attached,
+			// we must wait for the terminal width and
+			// height to be parsed from a resize message
+			// before sending CastHeader, else tsrecorder
+			// will not be able to play this recording.
+			err = lc.rec.WriteCastHeader(ch)
+			close(lc.initialCastHeaderSent)
+		})
+		if err != nil {
+			return nil, fmt.Errorf("error writing CastHeader: %w", err)
+		}
+	}
+
+	return lc, nil
 }
 
 // conn is a wrapper around net.Conn. It reads the bytestream
-// for a 'kubectl exec' session, sends session recording data to the configured
+// for a 'kubectl exec/attach' session, sends session recording data to the configured
 // recorder and forwards the raw bytes to the original destination.
 // A new conn is created per session.
-// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
+// conn only knows to how to read a 'kubectl exec/attach' session that is streamed using WebSocket protocol.
 // https://www.rfc-editor.org/rfc/rfc6455
 type conn struct {
 	net.Conn
+
+	ctx context.Context
 	// rec knows how to send data to a tsrecorder instance.
 	rec *tsrecorder.Client
 
@@ -56,7 +79,7 @@ type conn struct {
 	// CastHeader must be sent before any payload. If the session has a
 	// terminal attached, the CastHeader must have '.Width' and '.Height'
 	// fields set for the tsrecorder UI to be able to play the recording.
-	// For 'kubectl exec' sessions, terminal width and height are sent as a
+	// For 'kubectl exec/attach' sessions, terminal width and height are sent as a
 	// resize message on resize stream from the client when the session
 	// starts as well as at any time the client detects a terminal change.
 	// We can intercept the resize message on Read calls. As there is no
@@ -72,15 +95,10 @@ type conn struct {
 	// writeCastHeaderOnce is used to ensure CastHeader gets sent to tsrecorder once.
 	writeCastHeaderOnce sync.Once
 	hasTerm             bool // whether the session has TTY attached
-	// initialTermSizeSet channel gets sent a value once, when the Read has
-	// received a resize message and set the initial terminal size. It must
-	// be set to a buffered channel to prevent Reads being blocked on the
-	// first stdout/stderr write reading from the channel.
-	initialTermSizeSet chan struct{}
-	// sendInitialTermSizeSetOnce is used to ensure that a value is sent to
-	// initialTermSizeSet channel only once, when the initial resize message
-	// is received.
-	sendInitialTermSizeSetOnce sync.Once
+	// initialCastHeaderSent is a boolean that is set to ensure that the cast
+	// header is the first thing that is streamed to the session recorder.
+	// Otherwise the stream will fail.
+	initialCastHeaderSent chan struct{}
 
 	log *zap.SugaredLogger
 
@@ -171,9 +189,10 @@ func (c *conn) Read(b []byte) (int, error) {
 	c.readBuf.Next(len(readMsg.raw))
 
 	if readMsg.isFinalized && !c.readMsgIsIncomplete() {
+		// we want to send stream resize messages for terminal sessions
 		// Stream IDs for websocket streams are static.
 		// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
-		if readMsg.streamID.Load() == remotecommand.StreamResize {
+		if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm {
 			var msg tsrecorder.ResizeMsg
 			if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
 				return 0, fmt.Errorf("error umarshalling resize message: %w", err)
@@ -182,22 +201,29 @@ func (c *conn) Read(b []byte) (int, error) {
 			c.ch.Width = msg.Width
 			c.ch.Height = msg.Height
 
-			// If this is initial resize message, the width and
-			// height will be sent in the CastHeader. If this is a
-			// subsequent resize message, we need to send asciinema
-			// resize message.
 			var isInitialResize bool
-			c.sendInitialTermSizeSetOnce.Do(func() {
+			c.writeCastHeaderOnce.Do(func() {
 				isInitialResize = true
-				close(c.initialTermSizeSet) // unblock sending of CastHeader
+				// If this is a session with a terminal attached,
+				// we must wait for the terminal width and
+				// height to be parsed from a resize message
+				// before sending CastHeader, else tsrecorder
+				// will not be able to play this recording.
+				err = c.rec.WriteCastHeader(c.ch)
+				close(c.initialCastHeaderSent)
 			})
+			if err != nil {
+				return 0, fmt.Errorf("error writing CastHeader: %w", err)
+			}
+
 			if !isInitialResize {
-				if err := c.rec.WriteResize(c.ch.Height, c.ch.Width); err != nil {
+				if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil {
 					return 0, fmt.Errorf("error writing resize message: %w", err)
 				}
 			}
 		}
 	}
+
 	c.currentReadMsg = readMsg
 	return n, nil
 }
@@ -244,39 +270,33 @@ func (c *conn) Write(b []byte) (int, error) {
 		c.log.Errorf("write: parsing a message errored: %v", err)
 		return 0, fmt.Errorf("write: error parsing message: %v", err)
 	}
+
 	c.currentWriteMsg = writeMsg
 	if !ok { // incomplete fragment
 		return len(b), nil
 	}
+
 	c.writeBuf.Next(len(writeMsg.raw)) // advance frame
 
 	if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
 		if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
-			var err error
-			c.writeCastHeaderOnce.Do(func() {
-				// If this is a session with a terminal attached,
-				// we must wait for the terminal width and
-				// height to be parsed from a resize message
-				// before sending CastHeader, else tsrecorder
-				// will not be able to play this recording.
-				if c.hasTerm {
-					c.log.Debug("waiting for terminal size to be set before starting to send recorded data")
-					<-c.initialTermSizeSet
+			// we must wait for confirmation that the initial cast header was sent before proceeding with any more writes
+			select {
+			case <-c.ctx.Done():
+				return 0, c.ctx.Err()
+			case <-c.initialCastHeaderSent:
+				if err := c.rec.WriteOutput(writeMsg.payload); err != nil {
+					return 0, fmt.Errorf("error writing message to recorder: %w", err)
 				}
-				err = c.rec.WriteCastHeader(c.ch)
-			})
-			if err != nil {
-				return 0, fmt.Errorf("error writing CastHeader: %w", err)
-			}
-			if err := c.rec.WriteOutput(writeMsg.payload); err != nil {
-				return 0, fmt.Errorf("error writing message to recorder: %v", err)
 			}
 		}
 	}
+
 	_, err = c.Conn.Write(c.currentWriteMsg.raw)
 	if err != nil {
 		c.log.Errorf("write: error writing to conn: %v", err)
 	}
+
 	return len(b), nil
 }
 
@@ -321,6 +341,7 @@ func (c *conn) writeMsgIsIncomplete() bool {
 func (c *conn) readMsgIsIncomplete() bool {
 	return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized
 }
+
 func (c *conn) curReadMsgType() (messageType, error) {
 	if c.currentReadMsg != nil {
 		return c.currentReadMsg.typ, nil

+ 81 - 63
k8s-operator/sessionrecording/ws/conn_test.go

@@ -6,9 +6,11 @@
 package ws
 
 import (
+	"context"
 	"fmt"
 	"reflect"
 	"testing"
+	"time"
 
 	"go.uber.org/zap"
 	"k8s.io/apimachinery/pkg/util/remotecommand"
@@ -26,46 +28,69 @@ func Test_conn_Read(t *testing.T) {
 	// Resize stream ID + {"width": 10, "height": 20}
 	testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}
 	lenResizeMsgPayload := byte(len(testResizeMsg))
-
+	cl := tstest.NewClock(tstest.ClockOpts{})
 	tests := []struct {
-		name       string
-		inputs     [][]byte
-		wantWidth  int
-		wantHeight int
+		name                 string
+		inputs               [][]byte
+		wantCastHeaderWidth  int
+		wantCastHeaderHeight int
+		wantRecorded         []byte
 	}{
 		{
 			name:   "single_read_control_message",
 			inputs: [][]byte{{0x88, 0x0}},
 		},
 		{
-			name:       "single_read_resize_message",
-			inputs:     [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
-			wantWidth:  10,
-			wantHeight: 20,
+			name:                 "single_read_resize_message",
+			inputs:               [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
+			wantCastHeaderWidth:  10,
+			wantCastHeaderHeight: 20,
+			wantRecorded:         fakes.AsciinemaCastHeaderMsg(t, 10, 20),
 		},
 		{
-			name:       "two_reads_resize_message",
-			inputs:     [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
-			wantWidth:  10,
-			wantHeight: 20,
+			name: "resize_data_frame_many",
+			inputs: [][]byte{
+				append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...),
+				append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...),
+			},
+			wantRecorded:         append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...),
+			wantCastHeaderWidth:  10,
+			wantCastHeaderHeight: 20,
 		},
 		{
-			name:       "three_reads_resize_message_with_split_fragment",
-			inputs:     [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}},
-			wantWidth:  10,
-			wantHeight: 20,
+			name:                 "two_reads_resize_message",
+			inputs:               [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
+			wantCastHeaderWidth:  10,
+			wantCastHeaderHeight: 20,
+			wantRecorded:         fakes.AsciinemaCastHeaderMsg(t, 10, 20),
+		},
+		{
+			name:                 "three_reads_resize_message_with_split_fragment",
+			inputs:               [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}},
+			wantCastHeaderWidth:  10,
+			wantCastHeaderHeight: 20,
+			wantRecorded:         fakes.AsciinemaCastHeaderMsg(t, 10, 20),
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			l := zl.Sugar()
 			tc := &fakes.TestConn{}
+			sr := &fakes.TestSessionRecorder{}
+			rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
 			tc.ResetReadBuf()
+
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
 			c := &conn{
-				Conn: tc,
-				log:  zl.Sugar(),
+				ctx:                   ctx,
+				Conn:                  tc,
+				log:                   l,
+				hasTerm:               true,
+				initialCastHeaderSent: make(chan struct{}),
+				rec:                   rec,
 			}
 			for i, input := range tt.inputs {
-				c.initialTermSizeSet = make(chan struct{})
 				if err := tc.WriteReadBufBytes(input); err != nil {
 					t.Fatalf("writing bytes to test conn: %v", err)
 				}
@@ -75,14 +100,20 @@ func Test_conn_Read(t *testing.T) {
 					return
 				}
 			}
-			if tt.wantHeight != 0 || tt.wantWidth != 0 {
-				if tt.wantWidth != c.ch.Width {
-					t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width)
+
+			if tt.wantCastHeaderHeight != 0 || tt.wantCastHeaderWidth != 0 {
+				if tt.wantCastHeaderWidth != c.ch.Width {
+					t.Errorf("wants width: %v, got %v", tt.wantCastHeaderWidth, c.ch.Width)
 				}
-				if tt.wantHeight != c.ch.Height {
-					t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
+				if tt.wantCastHeaderHeight != c.ch.Height {
+					t.Errorf("want height: %v, got %v", tt.wantCastHeaderHeight, c.ch.Height)
 				}
 			}
+
+			gotRecorded := sr.Bytes()
+			if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
+				t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", string(tt.wantRecorded), string(gotRecorded))
+			}
 		})
 	}
 }
@@ -94,15 +125,11 @@ func Test_conn_Write(t *testing.T) {
 	}
 	cl := tstest.NewClock(tstest.ClockOpts{})
 	tests := []struct {
-		name              string
-		inputs            [][]byte
-		wantForwarded     []byte
-		wantRecorded      []byte
-		firstWrite        bool
-		width             int
-		height            int
-		hasTerm           bool
-		sendInitialResize bool
+		name          string
+		inputs        [][]byte
+		wantForwarded []byte
+		wantRecorded  []byte
+		hasTerm       bool
 	}{
 		{
 			name:          "single_write_control_frame",
@@ -130,10 +157,7 @@ func Test_conn_Write(t *testing.T) {
 			name:          "single_write_stdout_data_message_with_cast_header",
 			inputs:        [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
 			wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
-			wantRecorded:  append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...),
-			width:         10,
-			height:        20,
-			firstWrite:    true,
+			wantRecorded:  fakes.CastLine(t, []byte{0x7, 0x8}, cl),
 		},
 		{
 			name:          "two_writes_stdout_data_message",
@@ -148,15 +172,11 @@ func Test_conn_Write(t *testing.T) {
 			wantRecorded:  fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
 		},
 		{
-			name:              "three_writes_stdout_data_message_with_split_fragment_cast_header_with_terminal",
-			inputs:            [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}},
-			wantForwarded:     []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
-			wantRecorded:      append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl)...),
-			height:            20,
-			width:             10,
-			hasTerm:           true,
-			firstWrite:        true,
-			sendInitialResize: true,
+			name:          "three_writes_stdout_data_message_with_split_fragment_cast_header_with_terminal",
+			inputs:        [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}},
+			wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
+			wantRecorded:  fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
+			hasTerm:       true,
 		},
 	}
 	for _, tt := range tests {
@@ -164,24 +184,22 @@ func Test_conn_Write(t *testing.T) {
 			tc := &fakes.TestConn{}
 			sr := &fakes.TestSessionRecorder{}
 			rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
 			c := &conn{
-				Conn: tc,
-				log:  zl.Sugar(),
-				ch: sessionrecording.CastHeader{
-					Width:  tt.width,
-					Height: tt.height,
-				},
-				rec:                rec,
-				initialTermSizeSet: make(chan struct{}),
-				hasTerm:            tt.hasTerm,
-			}
-			if !tt.firstWrite {
-				// This test case does not intend to test that cast header gets written once.
-				c.writeCastHeaderOnce.Do(func() {})
-			}
-			if tt.sendInitialResize {
-				close(c.initialTermSizeSet)
+				Conn:                  tc,
+				ctx:                   ctx,
+				log:                   zl.Sugar(),
+				ch:                    sessionrecording.CastHeader{},
+				rec:                   rec,
+				initialCastHeaderSent: make(chan struct{}),
+				hasTerm:               tt.hasTerm,
 			}
+
+			c.writeCastHeaderOnce.Do(func() {
+				close(c.initialCastHeaderSent)
+			})
+
 			for i, input := range tt.inputs {
 				_, err := c.Write(input)
 				if err != nil {

+ 6 - 4
sessionrecording/header.go

@@ -66,13 +66,15 @@ type CastHeader struct {
 	Kubernetes *Kubernetes `json:"kubernetes,omitempty"`
 }
 
-// Kubernetes contains 'kubectl exec' session specific information for
+// Kubernetes contains 'kubectl exec/attach' session specific information for
 // tsrecorder.
 type Kubernetes struct {
-	// PodName is the name of the Pod being exec-ed.
+	// PodName is the name of the Pod the session was recorded for.
 	PodName string
-	// Namespace is the namespace in which is the Pod that is being exec-ed.
+	// Namespace is the namespace in which the Pod the session was recorded for exists in.
 	Namespace string
-	// Container is the container being exec-ed.
+	// Container is the container the session was recorded for.
 	Container string
+	// SessionType is the type of session that was executed (e.g., exec, attach)
+	SessionType string
 }