Преглед изворни кода

ssh/tailssh: also handle recording upload failure during writes

Previously we would error out when the recording server disappeared after the in memory
buffer filled up for the io.Copy. This makes it so that we handle failing open correctly
in that path.

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali пре 2 година
родитељ
комит
1b8a0dfe5e
1 измењених фајлова са 30 додато и 15 уклоњено
  1. 30 15
      ssh/tailssh/tailssh.go

+ 30 - 15
ssh/tailssh/tailssh.go

@@ -1519,8 +1519,9 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
 
 	now := time.Now()
 	rec := &recording{
-		ss:    ss,
-		start: now,
+		ss:       ss,
+		start:    now,
+		failOpen: onFailure == nil || onFailure.TerminateSessionWithMessage == "",
 	}
 
 	// We want to use a background context for uploading and not ss.ctx.
@@ -1611,6 +1612,10 @@ type recording struct {
 	ss    *sshSession
 	start time.Time
 
+	// failOpen specifies whether the session should be allowed to
+	// continue if writing to the recording fails.
+	failOpen bool
+
 	mu  sync.Mutex // guards writes to, close of out
 	out io.WriteCloser
 }
@@ -1642,7 +1647,7 @@ func (r *recording) writer(dir string, w io.Writer) io.Writer {
 		// passwords.
 		return w
 	}
-	return &loggingWriter{r, dir, w}
+	return &loggingWriter{r: r, dir: dir, w: w}
 }
 
 // loggingWriter is an io.Writer wrapper that writes first an
@@ -1651,20 +1656,30 @@ type loggingWriter struct {
 	r   *recording
 	dir string    // "i" or "o" (input or output)
 	w   io.Writer // underlying Writer, after writing to r.out
+
+	// recordingFailedOpen specifies whether we've failed to write to
+	// r.out and should stop trying. It is set to true if we fail to write
+	// to r.out and r.failOpen is set.
+	recordingFailedOpen bool
 }
 
-func (w loggingWriter) Write(p []byte) (n int, err error) {
-	j, err := json.Marshal([]any{
-		time.Since(w.r.start).Seconds(),
-		w.dir,
-		string(p),
-	})
-	if err != nil {
-		return 0, err
-	}
-	j = append(j, '\n')
-	if err := w.writeCastLine(j); err != nil {
-		return 0, err
+func (w *loggingWriter) Write(p []byte) (n int, err error) {
+	if !w.recordingFailedOpen {
+		j, err := json.Marshal([]any{
+			time.Since(w.r.start).Seconds(),
+			w.dir,
+			string(p),
+		})
+		if err != nil {
+			return 0, err
+		}
+		j = append(j, '\n')
+		if err := w.writeCastLine(j); err != nil {
+			if !w.r.failOpen {
+				return 0, err
+			}
+			w.recordingFailedOpen = true
+		}
 	}
 	return w.w.Write(p)
 }