Kaynağa Gözat

ssh/tailssh: handle Control-C during hold-and-delegate prompt

Fixes #4549

Change-Id: Iafc61af5e08cd03564d39cf667e940b2417714cc
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 3 yıl önce
ebeveyn
işleme
c1445155ef
2 değiştirilmiş dosya ile 152 ekleme ve 3 silme
  1. 112 0
      ssh/tailssh/ctxreader.go
  2. 40 3
      ssh/tailssh/tailssh.go

+ 112 - 0
ssh/tailssh/ctxreader.go

@@ -0,0 +1,112 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tailssh
+
+import (
+	"context"
+	"io"
+	"sync"
+
+	"tailscale.com/tempfork/gliderlabs/ssh"
+)
+
+// readResult is a result from a io.Reader.Read call,
+// as used by contextReader.
+type readResult struct {
+	buf []byte // ownership passed on chan send
+	err error
+}
+
+// contextReader wraps an io.Reader, providing a ReadContext method
+// that can be aborted before yielding bytes. If it's aborted, subsequent
+// reads can get those byte(s) later.
+type contextReader struct {
+	r io.Reader
+
+	// buffered is leftover data from a previous read call that wasn't entirely
+	// consumed.
+	buffered []byte
+	// readErr is a previous read error that was seen while filling buffered. It
+	// should be returned to the caller after bufffered is consumed.
+	readErr error
+
+	mu sync.Mutex // guards ch only
+
+	// ch is non-nil if a goroutine had been started and has a result to be
+	// read. The goroutine may be either still running or done and has
+	// send to the channel.
+	ch chan readResult
+}
+
+// HasOutstandingRead reports whether there's an oustanding Read call that's
+// either currently blocked in a Read or whose result hasn't been consumed.
+func (w *contextReader) HasOutstandingRead() bool {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	return w.ch != nil
+}
+
+func (w *contextReader) setChan(c chan readResult) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	w.ch = c
+}
+
+// ReadContext is like Read, but takes a context permitting the read to be canceled.
+//
+// If the context becomes done, the underlying Read call continues and its result
+// will be given to the next caller to ReadContext.
+func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	n = copy(p, w.buffered)
+	if n > 0 {
+		w.buffered = w.buffered[n:]
+		if len(w.buffered) == 0 {
+			err = w.readErr
+		}
+		return n, err
+	}
+
+	if w.ch == nil {
+		ch := make(chan readResult, 1)
+		w.setChan(ch)
+		go func() {
+			rbuf := make([]byte, len(p))
+			n, err := w.r.Read(rbuf)
+			ch <- readResult{rbuf[:n], err}
+		}()
+	}
+
+	select {
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	case rr := <-w.ch:
+		w.setChan(nil)
+		n = copy(p, rr.buf)
+		w.buffered = rr.buf[n:]
+		w.readErr = rr.err
+		if len(w.buffered) == 0 {
+			err = rr.err
+		}
+		return n, err
+	}
+}
+
+// contextReaderSesssion implements ssh.Session, wrapping another
+// ssh.Session but changing its Read method to use contextReader.
+type contextReaderSesssion struct {
+	ssh.Session
+	cr *contextReader
+}
+
+func (a contextReaderSesssion) Read(p []byte) (n int, err error) {
+	if a.cr.HasOutstandingRead() {
+		return a.cr.ReadContext(context.Background(), p)
+	}
+	return a.Session.Read(p)
+}

+ 40 - 3
ssh/tailssh/tailssh.go

@@ -37,6 +37,7 @@ import (
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/logtail/backoff"
 	"tailscale.com/net/tsaddr"
+	"tailscale.com/syncs"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tempfork/gliderlabs/ssh"
 	"tailscale.com/types/logger"
@@ -488,7 +489,8 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
 // completed. It also handles SFTP requests.
 func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
 	sshUser := s.User()
-	action, err := c.resolveTerminalAction(s)
+	cr := &contextReader{r: s}
+	action, err := c.resolveTerminalAction(s, cr)
 	if err != nil {
 		c.logf("resolveTerminalAction: %v", err)
 		io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
@@ -501,6 +503,10 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
 		return
 	}
 
+	if cr.HasOutstandingRead() {
+		s = contextReaderSesssion{s, cr}
+	}
+
 	// Do this check after auth, but before starting the session.
 	switch s.Subsystem() {
 	case "sftp", "":
@@ -522,8 +528,17 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
 // Any action with a Message in the chain will be printed to s.
 //
 // The returned SSHAction will be either Reject or Accept.
-func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error) {
+func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) {
 	action := c.action0
+
+	var awaitReadOnce sync.Once // to start Reads on cr
+	var sawInterrupt syncs.AtomicBool
+	var wg sync.WaitGroup
+	defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
+
+	ctx, cancel := context.WithCancel(s.Context())
+	defer cancel()
+
 	// Loop processing/fetching Actions until one reaches a
 	// terminal state (Accept, Reject, or invalid Action), or
 	// until fetchSSHAction times out due to the context being
@@ -541,10 +556,32 @@ func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error)
 		if url == "" {
 			return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate")
 		}
+		awaitReadOnce.Do(func() {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				buf := make([]byte, 1)
+				for {
+					n, err := cr.ReadContext(ctx, buf)
+					if err != nil {
+						return
+					}
+					if n > 0 && buf[0] == 0x03 { // Ctrl-C
+						sawInterrupt.Set(true)
+						s.Stderr().Write([]byte("Canceled.\r\n"))
+						s.Exit(1)
+						return
+					}
+				}
+			}()
+		})
 		url = c.expandDelegateURL(url)
 		var err error
-		action, err = c.fetchSSHAction(s.Context(), url)
+		action, err = c.fetchSSHAction(ctx, url)
 		if err != nil {
+			if sawInterrupt.Get() {
+				return nil, fmt.Errorf("aborted by user")
+			}
 			return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err)
 		}
 	}