tailssh.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux || (darwin && !ios)
  5. // +build linux darwin,!ios
  6. // Package tailssh is an SSH server integrated into Tailscale.
  7. package tailssh
  8. import (
  9. "context"
  10. "crypto/rand"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "os"
  18. "os/exec"
  19. "os/user"
  20. "path/filepath"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "time"
  25. "github.com/tailscale/ssh"
  26. "inet.af/netaddr"
  27. "tailscale.com/envknob"
  28. "tailscale.com/ipn/ipnlocal"
  29. "tailscale.com/logtail/backoff"
  30. "tailscale.com/net/tsaddr"
  31. "tailscale.com/tailcfg"
  32. "tailscale.com/types/logger"
  33. )
  34. // TODO(bradfitz): this is all very temporary as code is temporarily
  35. // being moved around; it will be restructured and documented in
  36. // following commits.
  37. // Handle handles an SSH connection from c.
  38. func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
  39. tsd, err := os.Executable()
  40. if err != nil {
  41. return err
  42. }
  43. srv := &server{
  44. lb: lb,
  45. logf: logf,
  46. tailscaledPath: tsd,
  47. }
  48. ss, err := srv.newSSHServer()
  49. if err != nil {
  50. return err
  51. }
  52. ss.HandleConn(c)
  53. return nil
  54. }
  55. func (srv *server) newSSHServer() (*ssh.Server, error) {
  56. ss := &ssh.Server{
  57. Handler: srv.handleSSH,
  58. RequestHandlers: map[string]ssh.RequestHandler{},
  59. SubsystemHandlers: map[string]ssh.SubsystemHandler{},
  60. // Note: the direct-tcpip channel handler and LocalPortForwardingCallback
  61. // only adds support for forwarding ports from the local machine.
  62. // TODO(maisem/bradfitz): add remote port forwarding support.
  63. ChannelHandlers: map[string]ssh.ChannelHandler{
  64. "direct-tcpip": ssh.DirectTCPIPHandler,
  65. },
  66. Version: "SSH-2.0-Tailscale",
  67. LocalPortForwardingCallback: srv.portForward,
  68. }
  69. for k, v := range ssh.DefaultRequestHandlers {
  70. ss.RequestHandlers[k] = v
  71. }
  72. for k, v := range ssh.DefaultChannelHandlers {
  73. ss.ChannelHandlers[k] = v
  74. }
  75. for k, v := range ssh.DefaultSubsystemHandlers {
  76. ss.SubsystemHandlers[k] = v
  77. }
  78. keys, err := srv.lb.GetSSH_HostKeys()
  79. if err != nil {
  80. return nil, err
  81. }
  82. for _, signer := range keys {
  83. ss.AddHostKey(signer)
  84. }
  85. return ss, nil
  86. }
// server holds the state behind one SSH server instance: the backend it
// consults, its logger, and the registry of currently active sessions.
// Handle creates a new server for each connection it is given.
type server struct {
	lb             *ipnlocal.LocalBackend
	logf           logger.Logf
	tailscaledPath string // path of the running executable, as set by Handle

	// mu protects activeSessionByH and activeSessionBySharedID.
	mu                      sync.Mutex
	activeSessionByH        map[string]*sshSession // ssh.SessionID (DH H) => that session
	activeSessionBySharedID map[string]*sshSession // yyyymmddThhmmss-<10 hex digits> => session
}
// debugPolicyFile is the value of the TS_DEBUG_SSH_POLICY_FILE envknob. When
// it is non-empty, sshPolicy reads the named file as a JSON-encoded
// tailcfg.SSHPolicy if the netmap does not carry an SSHPolicy of its own.
var debugPolicyFile = envknob.String("TS_DEBUG_SSH_POLICY_FILE")
  97. // portForward reports whether the ctx should be allowed to port forward
  98. // to the specified host and port.
  99. // TODO(bradfitz/maisem): should we have more checks on host/port?
  100. func (srv *server) portForward(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
  101. return srv.isActiveSession(ctx)
  102. }
  103. // sshPolicy returns the SSHPolicy for current node.
  104. // If there is no SSHPolicy in the netmap, it returns a debugPolicy
  105. // if one is defined.
  106. func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
  107. lb := srv.lb
  108. nm := lb.NetMap()
  109. if nm == nil {
  110. return nil, false
  111. }
  112. if pol := nm.SSHPolicy; pol != nil {
  113. return pol, true
  114. }
  115. if debugPolicyFile != "" {
  116. f, err := os.ReadFile(debugPolicyFile)
  117. if err != nil {
  118. srv.logf("error reading debug SSH policy file: %v", err)
  119. return nil, false
  120. }
  121. p := new(tailcfg.SSHPolicy)
  122. if err := json.Unmarshal(f, p); err != nil {
  123. srv.logf("invalid JSON in %v: %v", debugPolicyFile, err)
  124. return nil, false
  125. }
  126. return p, true
  127. }
  128. return nil, false
  129. }
  130. func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) {
  131. ta, ok := a.(*net.TCPAddr)
  132. if !ok {
  133. return netaddr.IPPort{}, fmt.Errorf("non-TCP addr %T %v", a, a)
  134. }
  135. tanetaddr, ok := netaddr.FromStdIP(ta.IP)
  136. if !ok {
  137. return netaddr.IPPort{}, fmt.Errorf("unparseable addr %v", ta.IP)
  138. }
  139. if !tsaddr.IsTailscaleIP(tanetaddr) {
  140. return netaddr.IPPort{}, fmt.Errorf("non-Tailscale addr %v", ta.IP)
  141. }
  142. return netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)), nil
  143. }
  144. // evaluatePolicy returns the SSHAction, sshConnInfo and localUser
  145. // after evaluating the sshUser and remoteAddr against the SSHPolicy.
  146. // The remoteAddr and localAddr params must be Tailscale IPs.
  147. func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
  148. logf := srv.logf
  149. lb := srv.lb
  150. logf("Handling SSH from %v for user %v", remoteAddr, sshUser)
  151. pol, ok := srv.sshPolicy()
  152. if !ok {
  153. return nil, nil, "", fmt.Errorf("tsshd: rejecting connection; no SSH policy")
  154. }
  155. srcIPP, err := asTailscaleIPPort(remoteAddr)
  156. if err != nil {
  157. return nil, nil, "", fmt.Errorf("tsshd: rejecting: %w", err)
  158. }
  159. dstIPP, err := asTailscaleIPPort(localAddr)
  160. if err != nil {
  161. return nil, nil, "", err
  162. }
  163. node, uprof, ok := lb.WhoIs(srcIPP)
  164. if !ok {
  165. return nil, nil, "", fmt.Errorf("Hello, %v. I don't know who you are.\n", srcIPP)
  166. }
  167. ci := &sshConnInfo{
  168. now: time.Now(),
  169. sshUser: sshUser,
  170. src: srcIPP,
  171. dst: dstIPP,
  172. node: node,
  173. uprof: &uprof,
  174. }
  175. a, localUser, ok := evalSSHPolicy(pol, ci)
  176. if !ok {
  177. return nil, nil, "", fmt.Errorf("ssh: access denied for %q from %v", uprof.LoginName, ci.src.IP())
  178. }
  179. return a, ci, localUser, nil
  180. }
  181. // handleSSH is invoked when a new SSH connection attempt is made.
  182. func (srv *server) handleSSH(s ssh.Session) {
  183. logf := srv.logf
  184. sshUser := s.User()
  185. action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr())
  186. if err != nil {
  187. logf(err.Error())
  188. s.Exit(1)
  189. return
  190. }
  191. // Loop processing/fetching Actions until one reaches a
  192. // terminal state (Accept, Reject, or invalid Action), or
  193. // until fetchSSHAction times out due to the context being
  194. // done (client disconnect) or its 30 minute timeout passes.
  195. // (Which is a long time for somebody to see login
  196. // instructions and go to a URL to do something.)
  197. ProcessAction:
  198. for {
  199. if action.Message != "" {
  200. io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
  201. }
  202. if action.Reject {
  203. logf("ssh: access denied for %q from %v", ci.uprof.LoginName, ci.src.IP())
  204. s.Exit(1)
  205. return
  206. }
  207. if action.Accept {
  208. break ProcessAction
  209. }
  210. url := action.HoldAndDelegate
  211. if url == "" {
  212. logf("ssh: access denied; SSHAction has neither Reject, Accept, or next step URL")
  213. s.Exit(1)
  214. return
  215. }
  216. action, err = srv.fetchSSHAction(s.Context(), url)
  217. if err != nil {
  218. logf("ssh: fetching SSAction from %s: %v", url, err)
  219. s.Exit(1)
  220. return
  221. }
  222. }
  223. lu, err := user.Lookup(localUser)
  224. if err != nil {
  225. logf("ssh: user Lookup %q: %v", localUser, err)
  226. s.Exit(1)
  227. return
  228. }
  229. ss := srv.newSSHSession(s, ci, lu, action)
  230. ss.run()
  231. }
// sshSession is an accepted Tailscale SSH session.
//
// It embeds the underlying ssh.Session and carries the policy result, the
// local user to run as, and the handles of the process started for the
// session.
type sshSession struct {
	ssh.Session
	idH      string      // the RFC4253 sec8 hash H; don't share outside process
	sharedID string      // ID that's shared with control
	logf     logger.Logf // logger prefixed with the session's sharedID

	ctx           *sshContext        // implements context.Context
	srv           *server            // server the session is registered with
	connInfo      *sshConnInfo       // connection details used for policy evaluation
	action        *tailcfg.SSHAction // action the session was started with
	localUser     *user.User         // local account the session runs as
	agentListener net.Listener       // non-nil if agent-forwarding requested+allowed

	// initialized by launchProcess:
	cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout io.Reader
	stderr io.Reader // nil for pty sessions
	ptyReq *ssh.Pty  // non-nil for pty sessions

	// We use this sync.Once to ensure that we only terminate the process once,
	// either it exits itself or is terminated
	exitOnce sync.Once
}
  254. func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User, action *tailcfg.SSHAction) *sshSession {
  255. sharedID := fmt.Sprintf("%s-%02x", ci.now.UTC().Format("20060102T150405"), randBytes(5))
  256. return &sshSession{
  257. Session: s,
  258. idH: s.Context().(ssh.Context).SessionID(),
  259. sharedID: sharedID,
  260. ctx: newSSHContext(),
  261. srv: srv,
  262. action: action,
  263. localUser: lu,
  264. connInfo: ci,
  265. logf: logger.WithPrefix(srv.logf, "ssh-session("+sharedID+"): "),
  266. }
  267. }
  268. func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
  269. ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
  270. defer cancel()
  271. bo := backoff.NewBackoff("fetch-ssh-action", srv.logf, 10*time.Second)
  272. for {
  273. if err := ctx.Err(); err != nil {
  274. return nil, err
  275. }
  276. req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  277. if err != nil {
  278. return nil, err
  279. }
  280. res, err := srv.lb.DoNoiseRequest(req)
  281. if err != nil {
  282. bo.BackOff(ctx, err)
  283. continue
  284. }
  285. if res.StatusCode != 200 {
  286. res.Body.Close()
  287. bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status))
  288. continue
  289. }
  290. a := new(tailcfg.SSHAction)
  291. if err := json.NewDecoder(res.Body).Decode(a); err != nil {
  292. bo.BackOff(ctx, err)
  293. continue
  294. }
  295. return a, nil
  296. }
  297. }
  298. // killProcessOnContextDone waits for ss.ctx to be done and kills the process,
  299. // unless the process has already exited.
  300. func (ss *sshSession) killProcessOnContextDone() {
  301. <-ss.ctx.Done()
  302. // Either the process has already existed, in which case this does nothing.
  303. // Or, the process is still running in which case this will kill it.
  304. ss.exitOnce.Do(func() {
  305. err := ss.ctx.Err()
  306. if serr, ok := err.(SSHTerminationError); ok {
  307. msg := serr.SSHTerminationMessage()
  308. if msg != "" {
  309. io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
  310. }
  311. }
  312. ss.logf("terminating SSH session from %v: %v", ss.connInfo.src.IP(), err)
  313. ss.cmd.Process.Kill()
  314. })
  315. }
  316. // isActiveSession reports whether the ssh.Context corresponds
  317. // to an active session.
  318. func (srv *server) isActiveSession(sctx ssh.Context) bool {
  319. srv.mu.Lock()
  320. defer srv.mu.Unlock()
  321. _, ok := srv.activeSessionByH[sctx.SessionID()]
  322. return ok
  323. }
  324. // startSession registers ss as an active session.
  325. func (srv *server) startSession(ss *sshSession) {
  326. srv.mu.Lock()
  327. defer srv.mu.Unlock()
  328. if srv.activeSessionByH == nil {
  329. srv.activeSessionByH = make(map[string]*sshSession)
  330. }
  331. if srv.activeSessionBySharedID == nil {
  332. srv.activeSessionBySharedID = make(map[string]*sshSession)
  333. }
  334. if ss.idH == "" {
  335. panic("empty idH")
  336. }
  337. if _, dup := srv.activeSessionByH[ss.idH]; dup {
  338. panic("dup idH")
  339. }
  340. if ss.sharedID == "" {
  341. panic("empty sharedID")
  342. }
  343. if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
  344. panic("dup sharedID")
  345. }
  346. srv.activeSessionByH[ss.idH] = ss
  347. srv.activeSessionBySharedID[ss.sharedID] = ss
  348. }
  349. // endSession unregisters s from the list of active sessions.
  350. func (srv *server) endSession(ss *sshSession) {
  351. srv.mu.Lock()
  352. defer srv.mu.Unlock()
  353. delete(srv.activeSessionByH, ss.idH)
  354. delete(srv.activeSessionBySharedID, ss.sharedID)
  355. }
// errSessionDone is the error that run closes a session's context with when
// run returns.
var errSessionDone = errors.New("session is done")
  357. // handleSSHAgentForwarding starts a Unix socket listener and in the background
  358. // forwards agent connections between the listenr and the ssh.Session.
  359. // On success, it assigns ss.agentListener.
  360. func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error {
  361. if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding {
  362. return nil
  363. }
  364. ss.logf("ssh: agent forwarding requested")
  365. ln, err := ssh.NewAgentListener()
  366. if err != nil {
  367. return err
  368. }
  369. defer func() {
  370. if err != nil && ln != nil {
  371. ln.Close()
  372. }
  373. }()
  374. uid, err := strconv.ParseUint(lu.Uid, 10, 32)
  375. if err != nil {
  376. return err
  377. }
  378. gid, err := strconv.ParseUint(lu.Gid, 10, 32)
  379. if err != nil {
  380. return err
  381. }
  382. socket := ln.Addr().String()
  383. dir := filepath.Dir(socket)
  384. // Make sure the socket is accessible by the user.
  385. if err := os.Chown(socket, int(uid), int(gid)); err != nil {
  386. return err
  387. }
  388. if err := os.Chmod(dir, 0755); err != nil {
  389. return err
  390. }
  391. go ssh.ForwardAgentConnections(ln, s)
  392. ss.agentListener = ln
  393. return nil
  394. }
  395. // run is the entrypoint for a newly accepted SSH session.
  396. //
  397. // When ctx is done, the session is forcefully terminated. If its Err
  398. // is an SSHTerminationError, its SSHTerminationMessage is sent to the
  399. // user.
  400. func (ss *sshSession) run() {
  401. srv := ss.srv
  402. srv.startSession(ss)
  403. defer srv.endSession(ss)
  404. defer ss.ctx.CloseWithError(errSessionDone)
  405. if ss.action.SesssionDuration != 0 {
  406. t := time.AfterFunc(ss.action.SesssionDuration, func() {
  407. ss.ctx.CloseWithError(userVisibleError{
  408. fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SesssionDuration),
  409. context.DeadlineExceeded,
  410. })
  411. })
  412. defer t.Stop()
  413. }
  414. logf := srv.logf
  415. lu := ss.localUser
  416. localUser := lu.Username
  417. if euid := os.Geteuid(); euid != 0 {
  418. if lu.Uid != fmt.Sprint(euid) {
  419. logf("ssh: can't switch to user %q from process euid %v", localUser, euid)
  420. fmt.Fprintf(ss, "can't switch user\n")
  421. ss.Exit(1)
  422. return
  423. }
  424. }
  425. // Take control of the PTY so that we can configure it below.
  426. // See https://github.com/tailscale/tailscale/issues/4146
  427. ss.DisablePTYEmulation()
  428. if err := ss.handleSSHAgentForwarding(ss, lu); err != nil {
  429. logf("ssh: agent forwarding failed: %v", err)
  430. } else if ss.agentListener != nil {
  431. // TODO(maisem/bradfitz): add a way to close all session resources
  432. defer ss.agentListener.Close()
  433. }
  434. err := ss.launchProcess(ss.ctx)
  435. if err != nil {
  436. logf("start failed: %v", err.Error())
  437. ss.Exit(1)
  438. return
  439. }
  440. go ss.killProcessOnContextDone()
  441. go func() {
  442. _, err := io.Copy(ss.stdin, ss)
  443. if err != nil {
  444. // TODO: don't log in the success case.
  445. logf("ssh: stdin copy: %v", err)
  446. }
  447. ss.stdin.Close()
  448. }()
  449. go func() {
  450. _, err := io.Copy(ss, ss.stdout)
  451. if err != nil {
  452. // TODO: don't log in the success case.
  453. logf("ssh: stdout copy: %v", err)
  454. }
  455. }()
  456. // stderr is nil for ptys.
  457. if ss.stderr != nil {
  458. go func() {
  459. _, err := io.Copy(ss.Stderr(), ss.stderr)
  460. if err != nil {
  461. // TODO: don't log in the success case.
  462. logf("ssh: stderr copy: %v", err)
  463. }
  464. }()
  465. }
  466. err = ss.cmd.Wait()
  467. // This will either make the SSH Termination goroutine be a no-op,
  468. // or itself will be a no-op because the process was killed by the
  469. // aforementioned goroutine.
  470. ss.exitOnce.Do(func() {})
  471. if err == nil {
  472. logf("ssh: Wait: ok")
  473. ss.Exit(0)
  474. return
  475. }
  476. if ee, ok := err.(*exec.ExitError); ok {
  477. code := ee.ProcessState.ExitCode()
  478. logf("ssh: Wait: code=%v", code)
  479. ss.Exit(code)
  480. return
  481. }
  482. logf("ssh: Wait: %v", err)
  483. ss.Exit(1)
  484. return
  485. }
// sshConnInfo describes one incoming SSH connection for the purposes of
// policy evaluation. It is built by evaluatePolicy and passed to the
// rule-matching functions.
type sshConnInfo struct {
	// now is the time to consider the present moment for the
	// purposes of rule evaluation.
	now time.Time

	// sshUser is the requested local SSH username ("root", "alice", etc).
	sshUser string

	// src is the Tailscale IP and port that the connection came from.
	src netaddr.IPPort

	// dst is the Tailscale IP and port that the connection came for.
	dst netaddr.IPPort

	// node is the node that the backend's WhoIs reports for src.
	node *tailcfg.Node

	// uprof is node's UserProfile.
	uprof *tailcfg.UserProfile
}
  501. func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) {
  502. for _, r := range pol.Rules {
  503. if a, localUser, err := matchRule(r, ci); err == nil {
  504. return a, localUser, true
  505. }
  506. }
  507. return nil, "", false
  508. }
// internal errors for testing; they don't escape to callers or logs.
// Each names the reason matchRule declined a rule.
var (
	errNilRule        = errors.New("nil rule")              // the rule itself is nil
	errNilAction      = errors.New("nil action")            // the rule has no Action
	errRuleExpired    = errors.New("rule expired")          // RuleExpires is before the connection's now
	errPrincipalMatch = errors.New("principal didn't match") // no principal matched the connection
	errUserMatch      = errors.New("user didn't match")     // SSHUsers maps the requested user to nothing
)
  517. func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, err error) {
  518. if r == nil {
  519. return nil, "", errNilRule
  520. }
  521. if r.Action == nil {
  522. return nil, "", errNilAction
  523. }
  524. if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) {
  525. return nil, "", errRuleExpired
  526. }
  527. if !matchesPrincipal(r.Principals, ci) {
  528. return nil, "", errPrincipalMatch
  529. }
  530. if !r.Action.Reject || r.SSHUsers != nil {
  531. localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
  532. if localUser == "" {
  533. return nil, "", errUserMatch
  534. }
  535. }
  536. return r.Action, localUser, nil
  537. }
  538. func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) {
  539. if v, ok := ruleSSHUsers[reqSSHUser]; ok {
  540. return v
  541. }
  542. return ruleSSHUsers["*"]
  543. }
  544. func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
  545. for _, p := range ps {
  546. if p == nil {
  547. continue
  548. }
  549. if p.Any {
  550. return true
  551. }
  552. if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
  553. return true
  554. }
  555. if p.NodeIP != "" {
  556. if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
  557. return true
  558. }
  559. }
  560. if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
  561. return true
  562. }
  563. }
  564. return false
  565. }
  566. func randBytes(n int) []byte {
  567. b := make([]byte, n)
  568. if _, err := rand.Read(b); err != nil {
  569. panic(err)
  570. }
  571. return b
  572. }