|
@@ -9,6 +9,7 @@
|
|
|
package tailssh
|
|
package tailssh
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "context"
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
@@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
|
|
|
func (srv *server) handleSSH(s ssh.Session) {
|
|
func (srv *server) handleSSH(s ssh.Session) {
|
|
|
lb := srv.lb
|
|
lb := srv.lb
|
|
|
logf := srv.logf
|
|
logf := srv.logf
|
|
|
-
|
|
|
|
|
sshUser := s.User()
|
|
sshUser := s.User()
|
|
|
addr := s.RemoteAddr()
|
|
addr := s.RemoteAddr()
|
|
|
logf("Handling SSH from %v for user %v", addr, sshUser)
|
|
logf("Handling SSH from %v for user %v", addr, sshUser)
|
|
@@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- ptyReq, winCh, isPty := s.Pty()
|
|
|
|
|
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
|
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
|
|
node, uprof, ok := lb.WhoIs(srcIPP)
|
|
node, uprof, ok := lb.WhoIs(srcIPP)
|
|
|
if !ok {
|
|
if !ok {
|
|
@@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|
|
s.Exit(1)
|
|
s.Exit(1)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
|
|
|
|
|
|
|
+ var ctx context.Context = context.Background()
|
|
|
|
|
+ if action.SesssionDuration != 0 {
|
|
|
|
|
+ sctx := newSSHContext()
|
|
|
|
|
+ ctx = sctx
|
|
|
|
|
+ t := time.AfterFunc(action.SesssionDuration, func() {
|
|
|
|
|
+ sctx.CloseWithError(userVisibleError{
|
|
|
|
|
+ fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
|
|
|
|
|
+ context.DeadlineExceeded,
|
|
|
|
|
+ })
|
|
|
|
|
+ })
|
|
|
|
|
+ defer t.Stop()
|
|
|
|
|
+ }
|
|
|
|
|
+ srv.handleAcceptedSSH(ctx, s, ci, lu)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// handleAcceptedSSH handles s once it's been accepted and determined
|
|
|
|
|
+// that it should run as local system user lu.
|
|
|
|
|
+//
|
|
|
|
|
+// When ctx is done, the session is forcefully terminated. If its Err
|
|
|
|
|
+// is an SSHTerminationError, its SSHTerminationMessage is sent to the
|
|
|
|
|
+// user.
|
|
|
|
|
+func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
|
|
|
|
|
+ logf := srv.logf
|
|
|
|
|
+ localUser := lu.Username
|
|
|
|
|
+
|
|
|
|
|
+ var err error
|
|
|
|
|
+ ptyReq, winCh, isPty := s.Pty()
|
|
|
|
|
+ logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
|
|
|
var cmd *exec.Cmd
|
|
var cmd *exec.Cmd
|
|
|
if euid := os.Geteuid(); euid != 0 {
|
|
if euid := os.Geteuid(); euid != 0 {
|
|
|
if lu.Uid != fmt.Sprint(euid) {
|
|
if lu.Uid != fmt.Sprint(euid) {
|
|
@@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|
|
go func() { io.Copy(s.Stderr(), stderr) }()
|
|
go func() { io.Copy(s.Stderr(), stderr) }()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if action.SesssionDuration != 0 {
|
|
|
|
|
- t := time.AfterFunc(action.SesssionDuration, func() {
|
|
|
|
|
- logf("terminating SSH session from %v after max duration", srcIP)
|
|
|
|
|
- cmd.Process.Kill()
|
|
|
|
|
- })
|
|
|
|
|
- defer t.Stop()
|
|
|
|
|
|
|
+ if ctx.Done() != nil {
|
|
|
|
|
+ done := make(chan struct{})
|
|
|
|
|
+ defer close(done)
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ select {
|
|
|
|
|
+ case <-done:
|
|
|
|
|
+ case <-ctx.Done():
|
|
|
|
|
+ err := ctx.Err()
|
|
|
|
|
+ if serr, ok := err.(SSHTerminationError); ok {
|
|
|
|
|
+ msg := serr.SSHTerminationMessage()
|
|
|
|
|
+ if msg != "" {
|
|
|
|
|
+ io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ logf("terminating SSH session from %v: %v", ci.srcIP, err)
|
|
|
|
|
+ cmd.Process.Kill()
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
go func() {
|
|
go func() {
|