Browse Source

ssh/tailssh: do the full auth flow during ssh auth

Fixes #5091

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 years ago
parent
commit
f16b77de5d

+ 0 - 112
ssh/tailssh/ctxreader.go

@@ -1,112 +0,0 @@
-// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tailssh
-
-import (
-	"context"
-	"io"
-	"sync"
-
-	"tailscale.com/tempfork/gliderlabs/ssh"
-)
-
-// readResult is a result from a io.Reader.Read call,
-// as used by contextReader.
-type readResult struct {
-	buf []byte // ownership passed on chan send
-	err error
-}
-
-// contextReader wraps an io.Reader, providing a ReadContext method
-// that can be aborted before yielding bytes. If it's aborted, subsequent
-// reads can get those byte(s) later.
-type contextReader struct {
-	r io.Reader
-
-	// buffered is leftover data from a previous read call that wasn't entirely
-	// consumed.
-	buffered []byte
-	// readErr is a previous read error that was seen while filling buffered. It
-	// should be returned to the caller after buffered is consumed.
-	readErr error
-
-	mu sync.Mutex // guards ch only
-
-	// ch is non-nil if a goroutine had been started and has a result to be
-	// read. The goroutine may be either still running or done and has
-	// send to the channel.
-	ch chan readResult
-}
-
-// HasOutstandingRead reports whether there's an outstanding Read call that's
-// either currently blocked in a Read or whose result hasn't been consumed.
-func (w *contextReader) HasOutstandingRead() bool {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	return w.ch != nil
-}
-
-func (w *contextReader) setChan(c chan readResult) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	w.ch = c
-}
-
-// ReadContext is like Read, but takes a context permitting the read to be canceled.
-//
-// If the context becomes done, the underlying Read call continues and its result
-// will be given to the next caller to ReadContext.
-func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
-	if len(p) == 0 {
-		return 0, nil
-	}
-
-	n = copy(p, w.buffered)
-	if n > 0 {
-		w.buffered = w.buffered[n:]
-		if len(w.buffered) == 0 {
-			err = w.readErr
-		}
-		return n, err
-	}
-
-	if w.ch == nil {
-		ch := make(chan readResult, 1)
-		w.setChan(ch)
-		go func() {
-			rbuf := make([]byte, len(p))
-			n, err := w.r.Read(rbuf)
-			ch <- readResult{rbuf[:n], err}
-		}()
-	}
-
-	select {
-	case <-ctx.Done():
-		return 0, ctx.Err()
-	case rr := <-w.ch:
-		w.setChan(nil)
-		n = copy(p, rr.buf)
-		w.buffered = rr.buf[n:]
-		w.readErr = rr.err
-		if len(w.buffered) == 0 {
-			err = rr.err
-		}
-		return n, err
-	}
-}
-
-// contextReaderSession implements ssh.Session, wrapping another
-// ssh.Session but changing its Read method to use contextReader.
-type contextReaderSession struct {
-	ssh.Session
-	cr *contextReader
-}
-
-func (a contextReaderSession) Read(p []byte) (n int, err error) {
-	if a.cr.HasOutstandingRead() {
-		return a.cr.ReadContext(context.Background(), p)
-	}
-	return a.Session.Read(p)
-}

+ 0 - 2
ssh/tailssh/incubator.go

@@ -86,11 +86,9 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
 		// TODO(maisem): this doesn't work with sftp
 		return exec.CommandContext(ss.ctx, name, args...)
 	}
-	ss.conn.mu.Lock()
 	lu := ss.conn.localUser
 	ci := ss.conn.info
 	gids := strings.Join(ss.conn.userGroupIDs, ",")
-	ss.conn.mu.Unlock()
 	remoteUser := ci.uprof.LoginName
 	if len(ci.node.Tags) > 0 {
 		remoteUser = strings.Join(ci.node.Tags, ",")

+ 216 - 217
ssh/tailssh/tailssh.go

@@ -29,7 +29,6 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	gossh "github.com/tailscale/golang-x-crypto/ssh"
@@ -87,6 +86,21 @@ func init() {
 	})
 }
 
+// attachSessionToConnIfNotShutdown ensures that srv is not shut down before
+// attaching the session to the conn. This ensures that once Shutdown is called,
+// new sessions are not allowed and existing ones are cleaned up.
+// It reports whether ss was attached to the conn.
+func (srv *server) attachSessionToConnIfNotShutdown(ss *sshSession) bool {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	if srv.shutdownCalled {
+		// Do not start any new sessions.
+		return false
+	}
+	ss.conn.attachSession(ss)
+	return true
+}
+
 func (srv *server) trackActiveConn(c *conn, add bool) {
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
@@ -121,12 +135,7 @@ func (srv *server) Shutdown() {
 	srv.mu.Lock()
 	srv.shutdownCalled = true
 	for c := range srv.activeConns {
-		for _, s := range c.sessions {
-			s.ctx.CloseWithError(userVisibleError{
-				fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
-				context.Canceled,
-			})
-		}
+		c.Close()
 	}
 	srv.mu.Unlock()
 	srv.sessionWaitGroup.Wait()
@@ -138,10 +147,7 @@ func (srv *server) OnPolicyChange() {
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
 	for c := range srv.activeConns {
-		c.mu.Lock()
-		ci := c.info
-		c.mu.Unlock()
-		if ci == nil {
+		if c.info == nil {
 			// c.info is nil when the connection hasn't been authenticated yet.
 			// In that case, the connection will be terminated when it is.
 			continue
@@ -152,28 +158,53 @@ func (srv *server) OnPolicyChange() {
 
 // conn represents a single SSH connection and its associated
 // ssh.Server.
+//
+// During the lifecycle of a connection, the following are called in order:
+// Setup and discover server info
+//   - ServerConfigCallback
+//
+// Do the user auth
+//   - BannerHandler
+//   - NoClientAuthHandler
+//   - PublicKeyHandler (only if NoClientAuthHandler returns errPubKeyRequired)
+//
+// Once auth is done, the conn can be multiplexed with multiple sessions and
+// channels concurrently, at which point any of the following can be called
+// in any order.
+//   - c.handleSessionPostSSHAuth
+//   - c.mayForwardLocalPortTo followed by ssh.DirectTCPIPHandler
 type conn struct {
 	*ssh.Server
+	srv *server
 
 	insecureSkipTailscaleAuth bool // used by tests.
 
-	connID  string             // ID that's shared with control
-	action0 *tailcfg.SSHAction // first matching action
-	srv     *server
-
-	mu           sync.Mutex   // protects the following
-	localUser    *user.User   // set by checkAuth
-	userGroupIDs []string     // set by checkAuth
-	info         *sshConnInfo // set by setInfo
 	// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
 	// and is shared among all sessions. It should not be shared outside
 	// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
 	// library.
-	idH            string
-	pubKey         gossh.PublicKey    // set by authorizeSession
-	finalAction    *tailcfg.SSHAction // set by authorizeSession
-	finalActionErr error              // set by authorizeSession
-	sessions       []*sshSession
+	idH    string
+	connID string // ID that's shared with control
+
+	noPubKeyPolicyAuthError error // set by BannerCallback
+
+	action0        *tailcfg.SSHAction // set by doPolicyAuth; first matching action
+	currentAction  *tailcfg.SSHAction // set by doPolicyAuth, updated by resolveNextAction
+	finalAction    *tailcfg.SSHAction // set by doPolicyAuth or resolveNextAction
+	finalActionErr error              // set by doPolicyAuth or resolveNextAction
+
+	info         *sshConnInfo    // set by setInfo
+	localUser    *user.User      // set by doPolicyAuth
+	userGroupIDs []string        // set by doPolicyAuth
+	pubKey       gossh.PublicKey // set by doPolicyAuth
+
+	// mu protects the following fields.
+	//
+	// srv.mu should be acquired prior to mu.
+	// It is safe to just acquire mu, but unsafe to
+	// acquire mu and then srv.mu.
+	mu       sync.Mutex // protects the following
+	sessions []*sshSession
 }
 
 func (c *conn) logf(format string, args ...any) {
@@ -181,49 +212,108 @@ func (c *conn) logf(format string, args ...any) {
 	c.srv.logf(format, args...)
 }
 
-// PublicKeyHandler implements ssh.PublicKeyHandler is called by the
-// ssh.Server when the client presents a public key.
-func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
-	c.mu.Lock()
-	ci := c.info
-	c.mu.Unlock()
-	if ci == nil {
-		return gossh.ErrDenied
-	}
-
-	if err := c.checkAuth(pubKey); err != nil {
-		// TODO(maisem/bradfitz): surface the error here.
-		c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
-		return err
+// isAuthorized returns nil if the connection is authorized to proceed.
+func (c *conn) isAuthorized(ctx ssh.Context) error {
+	action := c.currentAction
+	for {
+		if action.Accept {
+			if c.pubKey != nil {
+				metricPublicKeyAccepts.Add(1)
+			}
+			return nil
+		}
+		if action.Reject || action.HoldAndDelegate == "" {
+			return gossh.ErrDenied
+		}
+		var err error
+		action, err = c.resolveNextAction(ctx)
+		if err != nil {
+			return err
+		}
 	}
-	c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
-	return nil
 }
 
 // errPubKeyRequired is returned by NoClientAuthCallback to make the client
 // resort to public-key auth; not user visible.
 var errPubKeyRequired = errors.New("ssh publickey required")
 
+// BannerCallback implements ssh.BannerCallback.
+// It is responsible for starting the policy evaluation, and returns
+// the first message found in the action chain. It stops the evaluation
+// on the first "accept" or "reject" action, and returns the message
+// associated with that action (if any).
+func (c *conn) BannerCallback(ctx ssh.Context) string {
+	if err := c.setInfo(ctx); err != nil {
+		c.logf("failed to get conninfo: %v", err)
+		return gossh.ErrDenied.Error()
+	}
+	if err := c.doPolicyAuth(ctx, nil /* no pub key */); err != nil {
+		// Stash the error for NoClientAuthCallback to return it.
+		c.noPubKeyPolicyAuthError = err
+		return ""
+	}
+	action := c.currentAction
+	for {
+		if action.Reject || action.Accept || action.Message != "" {
+			return action.Message
+		}
+		if action.HoldAndDelegate == "" {
+			// Do not send user-visible messages to the user.
+			// Let the SSH level authentication fail instead.
+			return ""
+		}
+		var err error
+		action, err = c.resolveNextAction(ctx)
+		if err != nil {
+			return ""
+		}
+	}
+}
+
 // NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by
 // the ssh.Server when the client first connects with the "none"
 // authentication method.
-func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions, error) {
+//
+// It is responsible for continuing policy evaluation from BannerCallback (or
+// starting it afresh). It returns an error if the policy evaluation fails, or
+// if the decision is "reject".
+//
+// It either returns nil (accept) or errPubKeyRequired or gossh.ErrDenied
+// (reject). The errors may be wrapped.
+func (c *conn) NoClientAuthCallback(ctx ssh.Context) error {
 	if c.insecureSkipTailscaleAuth {
-		return nil, nil
+		return nil
 	}
-	if err := c.setInfo(cm); err != nil {
-		c.logf("failed to get conninfo: %v", err)
-		return nil, gossh.ErrDenied
+	if c.noPubKeyPolicyAuthError != nil {
+		return c.noPubKeyPolicyAuthError
+	} else if c.currentAction == nil {
+		// This should never happen, but if it does, we want to know.
+		panic("no current action")
 	}
-	return nil, c.checkAuth(nil /* no pub key */)
+	return c.isAuthorized(ctx)
 }
 
-// checkAuth verifies that conn can proceed with the specified (optional)
+// PublicKeyHandler implements ssh.PublicKeyHandler and is called by the
+// ssh.Server when the client presents a public key.
+func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
+	if err := c.doPolicyAuth(ctx, pubKey); err != nil {
+		// TODO(maisem/bradfitz): surface the error here.
+		c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
+		return err
+	}
+	if err := c.isAuthorized(ctx); err != nil {
+		return err
+	}
+	c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
+	return nil
+}
+
+// doPolicyAuth verifies that conn can proceed with the specified (optional)
 // pubKey. It returns nil if the matching policy action is Accept or
 // HoldAndDelegate. If pubKey is nil, there was no policy match but there is a
 // policy that might match a public key it returns errPubKeyRequired. Otherwise,
 // it returns gossh.ErrDenied possibly wrapped in gossh.WithBannerError.
-func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
+func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
 	a, localUser, err := c.evaluatePolicy(pubKey)
 	if err != nil {
 		if pubKey == nil && c.havePubKeyPolicy() {
@@ -232,7 +322,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
 		return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
 	}
 	c.action0 = a
+	c.currentAction = a
+	c.pubKey = pubKey
 	if a.Accept || a.HoldAndDelegate != "" {
+		if a.Accept {
+			c.finalAction = a
+		}
 		lu, err := user.Lookup(localUser)
 		if err != nil {
 			c.logf("failed to lookup %v: %v", localUser, err)
@@ -245,13 +340,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
 		if err != nil {
 			return err
 		}
-		c.mu.Lock()
-		defer c.mu.Unlock()
 		c.userGroupIDs = gids
 		c.localUser = lu
 		return nil
 	}
 	if a.Reject {
+		c.finalAction = a
 		err := gossh.ErrDenied
 		if a.Message != "" {
 			err = gossh.WithBannerError{
@@ -269,9 +363,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
 func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
 	return &gossh.ServerConfig{
 		// OpenSSH presents this on failure as `Permission denied (tailscale).`
-		ImplicitAuthMethod:   "tailscale",
-		NoClientAuth:         true, // required for the NoClientAuthCallback to run
-		NoClientAuthCallback: c.NoClientAuthCallback,
+		ImplicitAuthMethod: "tailscale",
+		NoClientAuth:       true, // required for the NoClientAuthCallback to run
 	}
 }
 
@@ -289,23 +382,25 @@ func (srv *server) newConn() (*conn, error) {
 	now := srv.now()
 	c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5))
 	c.Server = &ssh.Server{
-		Version:         "Tailscale",
-		Handler:         c.handleSessionPostSSHAuth,
-		RequestHandlers: map[string]ssh.RequestHandler{},
+		Version:              "Tailscale",
+		ServerConfigCallback: c.ServerConfig,
+
+		BannerHandler:       c.BannerCallback,
+		NoClientAuthHandler: c.NoClientAuthCallback,
+		PublicKeyHandler:    c.PublicKeyHandler,
+
+		Handler:                     c.handleSessionPostSSHAuth,
+		LocalPortForwardingCallback: c.mayForwardLocalPortTo,
 		SubsystemHandlers: map[string]ssh.SubsystemHandler{
 			"sftp": c.handleSessionPostSSHAuth,
 		},
-
 		// Note: the direct-tcpip channel handler and LocalPortForwardingCallback
 		// only adds support for forwarding ports from the local machine.
 		// TODO(maisem/bradfitz): add remote port forwarding support.
 		ChannelHandlers: map[string]ssh.ChannelHandler{
 			"direct-tcpip": ssh.DirectTCPIPHandler,
 		},
-		LocalPortForwardingCallback: c.mayForwardLocalPortTo,
-
-		PublicKeyHandler:     c.PublicKeyHandler,
-		ServerConfigCallback: c.ServerConfig,
+		RequestHandlers: map[string]ssh.RequestHandler{},
 	}
 	ss := c.Server
 	for k, v := range ssh.DefaultRequestHandlers {
@@ -341,10 +436,7 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de
 // havePubKeyPolicy reports whether any policy rule may provide access by means
 // of a ssh.PublicKey.
 func (c *conn) havePubKeyPolicy() bool {
-	c.mu.Lock()
-	ci := c.info
-	c.mu.Unlock()
-	if ci == nil {
+	if c.info == nil {
 		panic("havePubKeyPolicy called before setInfo")
 	}
 	// Is there any rule that looks like it'd require a public key for this
@@ -357,7 +449,7 @@ func (c *conn) havePubKeyPolicy() bool {
 		if c.ruleExpired(r) {
 			continue
 		}
-		if mapLocalUser(r.SSHUsers, ci.sshUser) == "" {
+		if mapLocalUser(r.SSHUsers, c.info.sshUser) == "" {
 			continue
 		}
 		for _, p := range r.Principals {
@@ -416,11 +508,11 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
 
 // connInfo returns a populated sshConnInfo from the provided arguments,
 // validating only that they represent a known Tailscale identity.
-func (c *conn) setInfo(cm gossh.ConnMetadata) error {
+func (c *conn) setInfo(ctx ssh.Context) error {
 	ci := &sshConnInfo{
-		sshUser: cm.User(),
-		src:     toIPPort(cm.RemoteAddr()),
-		dst:     toIPPort(cm.LocalAddr()),
+		sshUser: ctx.User(),
+		src:     toIPPort(ctx.RemoteAddr()),
+		dst:     toIPPort(ctx.LocalAddr()),
 	}
 	if !tsaddr.IsTailscaleIP(ci.dst.Addr()) {
 		return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst)
@@ -432,11 +524,10 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
 	if !ok {
 		return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
 	}
-	c.mu.Lock()
-	defer c.mu.Unlock()
 	ci.node = node
 	ci.uprof = &uprof
 
+	c.idH = ctx.SessionID()
 	c.info = ci
 	c.logf("handling conn: %v", ci.String())
 	return nil
@@ -554,50 +645,10 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
 	return lines, err
 }
 
-func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	idH := s.Context().(ssh.Context).SessionID()
-	if c.idH == "" {
-		c.idH = idH
-	} else if c.idH != idH {
-		c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
-		s.Exit(1)
-		return nil, false
-	}
-	cr := &contextReader{r: s}
-	action, err := c.resolveTerminalActionLocked(s, cr)
-	if err != nil {
-		c.logf("resolveTerminalAction: %v", err)
-		io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
-		s.Exit(1)
-		return nil, false
-	}
-	if action.Reject || !action.Accept {
-		c.logf("access denied for %v", c.info.uprof.LoginName)
-		s.Exit(1)
-		return nil, false
-	}
-	return cr, true
-}
-
 // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
 // but not necessarily before all the Tailscale-level extra verification has
 // completed. It also handles SFTP requests.
 func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
-	// Now that we have passed the SSH-level authentication, we can start the
-	// Tailscale-level extra verification. This means that we are going to
-	// evaluate the policy provided by control against the incoming SSH session.
-	cr, ok := c.authorizeSession(s)
-	if !ok {
-		return
-	}
-	if cr.HasOutstandingRead() {
-		// There was some buffered input while we were waiting for the policy
-		// decision.
-		s = contextReaderSession{s, cr}
-	}
-
 	// Do this check after auth, but before starting the session.
 	switch s.Subsystem() {
 	case "sftp", "":
@@ -609,45 +660,35 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
 	}
 
 	ss := c.newSSHSession(s)
-	c.mu.Lock()
 	ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
 	ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
-	c.mu.Unlock()
 	ss.run()
 }
 
-// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or
-// else loops, fetching new SSHActions from the control plane.
-//
-// Any action with a Message in the chain will be printed to s.
-//
-// The returned SSHAction will be either Reject or Accept.
-//
-// c.mu must be held.
-func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
+// resolveNextAction starts at c.currentAction and makes its way through the
+// action chain one step at a time. An action without a HoldAndDelegate is
+// considered the final action. Once a final action is reached, this function
+// will keep returning that action. It updates c.currentAction to the next
+// action in the chain. When the final action is reached, it also sets
+// c.finalAction to the final action.
+func (c *conn) resolveNextAction(sctx ssh.Context) (action *tailcfg.SSHAction, err error) {
 	if c.finalAction != nil || c.finalActionErr != nil {
 		return c.finalAction, c.finalActionErr
 	}
 
-	if s.PublicKey() != nil {
-		metricPublicKeyConnections.Add(1)
-	}
 	defer func() {
-		c.finalAction = action
-		c.finalActionErr = err
-		c.pubKey = s.PublicKey()
-		if c.pubKey != nil && action.Accept {
-			metricPublicKeyAccepts.Add(1)
+		if action != nil {
+			c.currentAction = action
+			if action.Accept || action.Reject {
+				c.finalAction = action
+			}
+		}
+		if err != nil {
+			c.finalActionErr = err
 		}
 	}()
-	action = c.action0
-
-	var awaitReadOnce sync.Once // to start Reads on cr
-	var sawInterrupt atomic.Bool
-	var wg sync.WaitGroup
-	defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
 
-	ctx, cancel := context.WithCancel(s.Context())
+	ctx, cancel := context.WithCancel(sctx)
 	defer cancel()
 
 	// Loop processing/fetching Actions until one reaches a
@@ -656,56 +697,28 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
 	// done (client disconnect) or its 30 minute timeout passes.
 	// (Which is a long time for somebody to see login
 	// instructions and go to a URL to do something.)
-	for {
-		if action.Message != "" {
-			io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
-		}
-		if action.Accept || action.Reject {
-			if action.Reject {
-				metricTerminalReject.Add(1)
-			} else {
-				metricTerminalAccept.Add(1)
-			}
-			return action, nil
-		}
-		url := action.HoldAndDelegate
-		if url == "" {
-			metricTerminalMalformed.Add(1)
-			return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate")
-		}
-		metricHolds.Add(1)
-		awaitReadOnce.Do(func() {
-			wg.Add(1)
-			go func() {
-				defer wg.Done()
-				buf := make([]byte, 1)
-				for {
-					n, err := cr.ReadContext(ctx, buf)
-					if err != nil {
-						return
-					}
-					if n > 0 && buf[0] == 0x03 { // Ctrl-C
-						sawInterrupt.Store(true)
-						s.Stderr().Write([]byte("Canceled.\r\n"))
-						s.Exit(1)
-						return
-					}
-				}
-			}()
-		})
-		url = c.expandDelegateURLLocked(url)
-		var err error
-		action, err = c.fetchSSHAction(ctx, url)
-		if err != nil {
-			if sawInterrupt.Load() {
-				metricTerminalInterrupt.Add(1)
-				return nil, fmt.Errorf("aborted by user")
-			} else {
-				metricTerminalFetchError.Add(1)
-			}
-			return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err)
+	action = c.currentAction
+	if action.Accept || action.Reject {
+		if action.Reject {
+			metricTerminalReject.Add(1)
+		} else {
+			metricTerminalAccept.Add(1)
 		}
+		return action, nil
+	}
+	url := action.HoldAndDelegate
+	if url == "" {
+		metricTerminalMalformed.Add(1)
+		return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate")
 	}
+	metricHolds.Add(1)
+	url = c.expandDelegateURLLocked(url)
+	nextAction, err := c.fetchSSHAction(ctx, url)
+	if err != nil {
+		metricTerminalFetchError.Add(1)
+		return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err)
+	}
+	return nextAction, nil
 }
 
 func (c *conn) expandDelegateURLLocked(actionURL string) string {
@@ -732,12 +745,10 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
 	}
 	var localPart string
 	var loginName string
-	c.mu.Lock()
 	if c.info.uprof != nil {
 		loginName = c.info.uprof.LoginName
 		localPart, _, _ = strings.Cut(loginName, "@")
 	}
-	c.mu.Unlock()
 	return strings.NewReplacer(
 		"$LOGINNAME_EMAIL", loginName,
 		"$LOGINNAME_LOCALPART", localPart,
@@ -793,8 +804,6 @@ func (c *conn) isStillValid() bool {
 	if !a.Accept && a.HoldAndDelegate == "" {
 		return false
 	}
-	c.mu.Lock()
-	defer c.mu.Unlock()
 	return c.localUser.Username == localUser
 }
 
@@ -806,6 +815,8 @@ func (c *conn) checkStillValid() {
 	}
 	metricPolicyChangeKick.Add(1)
 	c.logf("session no longer valid per new SSH policy; closing")
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	for _, s := range c.sessions {
 		s.ctx.CloseWithError(userVisibleError{
 			fmt.Sprintf("Access revoked.\r\n"),
@@ -876,21 +887,22 @@ func (ss *sshSession) killProcessOnContextDone() {
 	})
 }
 
-// startSessionLocked registers ss as an active session.
-// It must be called with srv.mu held.
-func (c *conn) startSessionLocked(ss *sshSession) {
+// attachSession registers ss as an active session.
+func (c *conn) attachSession(ss *sshSession) {
 	c.srv.sessionWaitGroup.Add(1)
 	if ss.sharedID == "" {
 		panic("empty sharedID")
 	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	c.sessions = append(c.sessions, ss)
 }
 
-// endSession unregisters s from the list of active sessions.
-func (c *conn) endSession(ss *sshSession) {
+// detachSession unregisters ss from the list of active sessions.
+func (c *conn) detachSession(ss *sshSession) {
 	defer c.srv.sessionWaitGroup.Done()
-	c.srv.mu.Lock()
-	defer c.srv.mu.Unlock()
+	c.mu.Lock()
+	defer c.mu.Unlock()
 	for i, s := range c.sessions {
 		if s == ss {
 			c.sessions = append(c.sessions[:i], c.sessions[i+1:]...)
@@ -960,22 +972,16 @@ func (ss *sshSession) run() {
 	metricActiveSessions.Add(1)
 	defer metricActiveSessions.Add(-1)
 	defer ss.ctx.CloseWithError(errSessionDone)
-	srv := ss.conn.srv
 
-	srv.mu.Lock()
-	if srv.shutdownCalled {
-		srv.mu.Unlock()
-		// Do not start any new sessions.
+	if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
 		fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
 		ss.Exit(1)
 		return
 	}
-	ss.conn.startSessionLocked(ss)
-	lu := ss.conn.localUser
-	localUser := lu.Username
-	srv.mu.Unlock()
+	defer ss.conn.detachSession(ss)
 
-	defer ss.conn.endSession(ss)
+	lu := ss.conn.localUser
+	logf := ss.logf
 
 	if ss.conn.finalAction.SessionDuration != 0 {
 		t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
@@ -987,11 +993,9 @@ func (ss *sshSession) run() {
 		defer t.Stop()
 	}
 
-	logf := ss.logf
-
 	if euid := os.Geteuid(); euid != 0 {
 		if lu.Uid != fmt.Sprint(euid) {
-			ss.logf("can't switch to user %q from process euid %v", localUser, euid)
+			ss.logf("can't switch to user %q from process euid %v", lu.Username, euid)
 			fmt.Fprintf(ss, "can't switch user\r\n")
 			ss.Exit(1)
 			return
@@ -1141,10 +1145,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
 	if c == nil {
 		return nil, "", errInvalidConn
 	}
-	c.mu.Lock()
-	ci := c.info
-	c.mu.Unlock()
-	if ci == nil {
+	if c.info == nil {
 		c.logf("invalid connection state")
 		return nil, "", errInvalidConn
 	}
@@ -1161,7 +1162,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
 		// For all but Reject rules, SSHUsers is required.
 		// If SSHUsers is nil or empty, mapLocalUser will return an
 		// empty string anyway.
-		localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
+		localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
 		if localUser == "" {
 			return nil, "", errUserMatch
 		}
@@ -1210,9 +1211,7 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey)
 // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
 // This function does not consider PubKeys.
 func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
-	c.mu.Lock()
 	ci := c.info
-	c.mu.Unlock()
 	if p.Any {
 		return true
 	}

+ 18 - 1
tempfork/gliderlabs/ssh/server.go

@@ -38,9 +38,11 @@ type Server struct {
 	HostSigners []Signer // private keys for the host key, must have at least one
 	Version     string   // server version to be sent before the initial handshake
 
-	KeyboardInteractiveHandler    KeyboardInteractiveHandler    // keyboard-interactive authentication handler
+	KeyboardInteractiveHandler    KeyboardInteractiveHandler // keyboard-interactive authentication handler
+	BannerHandler                 BannerHandler
 	PasswordHandler               PasswordHandler               // password authentication handler
 	PublicKeyHandler              PublicKeyHandler              // public key authentication handler
+	NoClientAuthHandler           NoClientAuthHandler           // no client authentication handler
 	PtyCallback                   PtyCallback                   // callback for allowing PTY sessions, allows all if nil
 	ConnCallback                  ConnCallback                  // optional callback for wrapping net.Conn before handling
 	LocalPortForwardingCallback   LocalPortForwardingCallback   // callback for allowing local port forwarding, denies all if nil
@@ -160,6 +162,21 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
 			return ctx.Permissions().Permissions, nil
 		}
 	}
+	if srv.NoClientAuthHandler != nil {
+		config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) {
+			applyConnMetadata(ctx, conn)
+			if err := srv.NoClientAuthHandler(ctx); err != nil {
+				return ctx.Permissions().Permissions, err
+			}
+			return ctx.Permissions().Permissions, nil
+		}
+	}
+	if srv.BannerHandler != nil {
+		config.BannerCallback = func(conn gossh.ConnMetadata) string {
+			applyConnMetadata(ctx, conn)
+			return srv.BannerHandler(ctx)
+		}
+	}
 	return config
 }
 

+ 4 - 0
tempfork/gliderlabs/ssh/ssh.go

@@ -38,6 +38,10 @@ type Handler func(Session)
 // PublicKeyHandler is a callback for performing public key authentication.
 type PublicKeyHandler func(ctx Context, key PublicKey) error
 
+type NoClientAuthHandler func(ctx Context) error
+
+type BannerHandler func(ctx Context) string
+
 // PasswordHandler is a callback for performing password authentication.
 type PasswordHandler func(ctx Context, password string) bool