Browse Source

ssh/tailssh: break a method into half in prep for testing

And add a private context type in the process.

Updates #3802

Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 years ago
parent
commit
e2ed06c53c
2 changed files with 108 additions and 9 deletions
  1. 61 0
      ssh/tailssh/context.go
  2. 47 9
      ssh/tailssh/tailssh.go

+ 61 - 0
ssh/tailssh/context.go

@@ -0,0 +1,61 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tailssh
+
+import (
+	"sync"
+	"time"
+)
+
+// sshContext is the context.Context implementation we use for SSH
+// that adds a CloseWithError method. Otherwise it's just a normalish
+// Context.
+type sshContext struct {
+	mu     sync.Mutex
+	closed bool
+	done   chan struct{}
+	err    error
+}
+
+func newSSHContext() *sshContext {
+	return &sshContext{done: make(chan struct{})}
+}
+
+func (ctx *sshContext) CloseWithError(err error) {
+	ctx.mu.Lock()
+	defer ctx.mu.Unlock()
+	if ctx.closed {
+		return
+	}
+	ctx.closed = true
+	ctx.err = err
+	close(ctx.done)
+}
+
+func (ctx *sshContext) Err() error {
+	ctx.mu.Lock()
+	defer ctx.mu.Unlock()
+	return ctx.err
+}
+
+func (ctx *sshContext) Done() <-chan struct{}                   { return ctx.done }
+func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
+func (ctx *sshContext) Value(interface{}) interface{}           { return nil }
+
+// userVisibleError is a wrapper around an error that implements
+// SSHTerminationError, so msg is written to their session.
+type userVisibleError struct {
+	msg string
+	error
+}
+
+func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
+
+// SSHTerminationError is implemented by errors that terminate an SSH
+// session and should be written to user's sessions.
+type SSHTerminationError interface {
+	error
+	SSHTerminationMessage() string
+}

+ 47 - 9
ssh/tailssh/tailssh.go

@@ -9,6 +9,7 @@
 package tailssh
 package tailssh
 
 
 import (
 import (
+	"context"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
 func (srv *server) handleSSH(s ssh.Session) {
 func (srv *server) handleSSH(s ssh.Session) {
 	lb := srv.lb
 	lb := srv.lb
 	logf := srv.logf
 	logf := srv.logf
-
 	sshUser := s.User()
 	sshUser := s.User()
 	addr := s.RemoteAddr()
 	addr := s.RemoteAddr()
 	logf("Handling SSH from %v for user %v", addr, sshUser)
 	logf("Handling SSH from %v for user %v", addr, sshUser)
@@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
 		return
 		return
 	}
 	}
 
 
-	ptyReq, winCh, isPty := s.Pty()
 	srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
 	srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
 	node, uprof, ok := lb.WhoIs(srcIPP)
 	node, uprof, ok := lb.WhoIs(srcIPP)
 	if !ok {
 	if !ok {
@@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
 		s.Exit(1)
 		s.Exit(1)
 	}
 	}
 
 
-	logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
+	var ctx context.Context = context.Background()
+	if action.SesssionDuration != 0 {
+		sctx := newSSHContext()
+		ctx = sctx
+		t := time.AfterFunc(action.SesssionDuration, func() {
+			sctx.CloseWithError(userVisibleError{
+				fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
+				context.DeadlineExceeded,
+			})
+		})
+		defer t.Stop()
+	}
+	srv.handleAcceptedSSH(ctx, s, ci, lu)
+}
+
+// handleAcceptedSSH handles s once it's been accepted and determined
+// that it should run as local system user lu.
+//
+// When ctx is done, the session is forcefully terminated. If its Err
+// is an SSHTerminationError, its SSHTerminationMessage is sent to the
+// user.
+func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
+	logf := srv.logf
+	localUser := lu.Username
+
+	var err error
+	ptyReq, winCh, isPty := s.Pty()
+	logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
 	var cmd *exec.Cmd
 	var cmd *exec.Cmd
 	if euid := os.Geteuid(); euid != 0 {
 	if euid := os.Geteuid(); euid != 0 {
 		if lu.Uid != fmt.Sprint(euid) {
 		if lu.Uid != fmt.Sprint(euid) {
@@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
 		go func() { io.Copy(s.Stderr(), stderr) }()
 		go func() { io.Copy(s.Stderr(), stderr) }()
 	}
 	}
 
 
-	if action.SesssionDuration != 0 {
-		t := time.AfterFunc(action.SesssionDuration, func() {
-			logf("terminating SSH session from %v after max duration", srcIP)
-			cmd.Process.Kill()
-		})
-		defer t.Stop()
+	if ctx.Done() != nil {
+		done := make(chan struct{})
+		defer close(done)
+		go func() {
+			select {
+			case <-done:
+			case <-ctx.Done():
+				err := ctx.Err()
+				if serr, ok := err.(SSHTerminationError); ok {
+					msg := serr.SSHTerminationMessage()
+					if msg != "" {
+						io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
+					}
+				}
+				logf("terminating SSH session from %v: %v", ci.srcIP, err)
+				cmd.Process.Kill()
+			}
+		}()
 	}
 	}
 
 
 	go func() {
 	go func() {