tailssh.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux || (darwin && !ios)
  5. // +build linux darwin,!ios
  6. // Package tailssh is an SSH server integrated into Tailscale.
  7. package tailssh
  8. import (
  9. "context"
  10. "crypto/rand"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "os"
  18. "os/exec"
  19. "os/user"
  20. "path/filepath"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "time"
  25. "github.com/tailscale/ssh"
  26. "inet.af/netaddr"
  27. "tailscale.com/envknob"
  28. "tailscale.com/ipn/ipnlocal"
  29. "tailscale.com/logtail/backoff"
  30. "tailscale.com/net/tsaddr"
  31. "tailscale.com/tailcfg"
  32. "tailscale.com/types/logger"
  33. )
  34. // TODO(bradfitz): this is all very temporary as code is temporarily
  35. // being moved around; it will be restructured and documented in
  36. // following commits.
  // Handle serves the SSH connection c using the LocalBackend lb.
//
// It builds a new server (recording the path of the running executable as
// tailscaledPath), derives an ssh.Server from it, and passes c to that
// server's HandleConn. Setup failures (locating the executable, building the
// ssh.Server) are returned; once HandleConn returns, Handle returns nil.
func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
	exe, err := os.Executable()
	if err != nil {
		return err
	}
	srv := &server{lb: lb, logf: logf, tailscaledPath: exe}

	sshServer, err := srv.newSSHServer()
	if err != nil {
		return err
	}
	sshServer.HandleConn(c)
	return nil
}
  55. func (srv *server) newSSHServer() (*ssh.Server, error) {
  56. ss := &ssh.Server{
  57. Handler: srv.handleSSH,
  58. RequestHandlers: map[string]ssh.RequestHandler{},
  59. SubsystemHandlers: map[string]ssh.SubsystemHandler{},
  60. // Note: the direct-tcpip channel handler and LocalPortForwardingCallback
  61. // only adds support for forwarding ports from the local machine.
  62. // TODO(maisem/bradfitz): add remote port forwarding support.
  63. ChannelHandlers: map[string]ssh.ChannelHandler{
  64. "direct-tcpip": ssh.DirectTCPIPHandler,
  65. },
  66. Version: "SSH-2.0-Tailscale",
  67. LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
  68. }
  69. for k, v := range ssh.DefaultRequestHandlers {
  70. ss.RequestHandlers[k] = v
  71. }
  72. for k, v := range ssh.DefaultChannelHandlers {
  73. ss.ChannelHandlers[k] = v
  74. }
  75. for k, v := range ssh.DefaultSubsystemHandlers {
  76. ss.SubsystemHandlers[k] = v
  77. }
  78. keys, err := srv.lb.GetSSH_HostKeys()
  79. if err != nil {
  80. return nil, err
  81. }
  82. for _, signer := range keys {
  83. ss.AddHostKey(signer)
  84. }
  85. return ss, nil
  86. }
  87. type server struct {
  88. lb *ipnlocal.LocalBackend
  89. logf logger.Logf
  90. tailscaledPath string
  91. // mu protects activeSessions.
  92. mu sync.Mutex
  93. activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session
  94. activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
  95. }
  96. var debugPolicyFile = envknob.String("TS_DEBUG_SSH_POLICY_FILE")
  97. // mayForwardLocalPortTo reports whether the ctx should be allowed to port forward
  98. // to the specified host and port.
  99. // TODO(bradfitz/maisem): should we have more checks on host/port?
  100. func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
  101. ss, ok := srv.getSessionForContext(ctx)
  102. if !ok {
  103. return false
  104. }
  105. return ss.action.AllowLocalPortForwarding
  106. }
  // sshPolicy returns the SSHPolicy to enforce for the current node.
//
// The netmap's policy takes precedence. If a netmap exists but has no policy
// and debugPolicyFile is set, that file is read and parsed as JSON instead.
// ok is false when there is no netmap, or when no policy is available. It is
// also false if the debug file cannot be read or parsed; those failures are
// logged.
func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
	nm := srv.lb.NetMap()
	switch {
	case nm == nil:
		return nil, false
	case nm.SSHPolicy != nil:
		return nm.SSHPolicy, true
	case debugPolicyFile == "":
		return nil, false
	}

	raw, err := os.ReadFile(debugPolicyFile)
	if err != nil {
		srv.logf("error reading debug SSH policy file: %v", err)
		return nil, false
	}
	pol := new(tailcfg.SSHPolicy)
	if err := json.Unmarshal(raw, pol); err != nil {
		srv.logf("invalid JSON in %v: %v", debugPolicyFile, err)
		return nil, false
	}
	return pol, true
}
  134. func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) {
  135. ta, ok := a.(*net.TCPAddr)
  136. if !ok {
  137. return netaddr.IPPort{}, fmt.Errorf("non-TCP addr %T %v", a, a)
  138. }
  139. tanetaddr, ok := netaddr.FromStdIP(ta.IP)
  140. if !ok {
  141. return netaddr.IPPort{}, fmt.Errorf("unparseable addr %v", ta.IP)
  142. }
  143. if !tsaddr.IsTailscaleIP(tanetaddr) {
  144. return netaddr.IPPort{}, fmt.Errorf("non-Tailscale addr %v", ta.IP)
  145. }
  146. return netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)), nil
  147. }
  // evaluatePolicy matches an incoming connection against the node's SSHPolicy.
// The connection is sshUser connecting from remoteAddr to localAddr.
//
// On success it returns:
//   - the action of the first matching rule,
//   - the connection info the rules were evaluated against,
//   - the local user that rule maps sshUser to.
//
// Both localAddr and remoteAddr must be Tailscale IPs (TCP); any other
// address is rejected. The connection is also rejected when:
//   - there is no policy,
//   - the LocalBackend cannot identify the source,
//   - no rule matches.
func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
	logf := srv.logf
	lb := srv.lb
	logf("Handling SSH from %v for user %v", remoteAddr, sshUser)
	pol, ok := srv.sshPolicy()
	if !ok {
		return nil, nil, "", errors.New("tsshd: rejecting connection; no SSH policy")
	}
	srcIPP, err := asTailscaleIPPort(remoteAddr)
	if err != nil {
		return nil, nil, "", fmt.Errorf("tsshd: rejecting: %w", err)
	}
	dstIPP, err := asTailscaleIPPort(localAddr)
	if err != nil {
		// Wrapped the same way as the source-address failure so both are
		// attributable to this rejection path in logs.
		return nil, nil, "", fmt.Errorf("tsshd: rejecting: %w", err)
	}
	node, uprof, ok := lb.WhoIs(srcIPP)
	if !ok {
		// No trailing newline: this error is logged and may be wrapped.
		return nil, nil, "", fmt.Errorf("Hello, %v. I don't know who you are.", srcIPP)
	}
	ci := &sshConnInfo{
		now:     time.Now(),
		sshUser: sshUser,
		src:     srcIPP,
		dst:     dstIPP,
		node:    node,
		uprof:   &uprof,
	}
	a, localUser, ok := evalSSHPolicy(pol, ci)
	if !ok {
		return nil, nil, "", fmt.Errorf("ssh: access denied for %q from %v", uprof.LoginName, ci.src.IP())
	}
	return a, ci, localUser, nil
}
  // handleSSH is the ssh.Server Handler installed by newSSHServer; it runs once
// per new SSH session.
//
// It evaluates the SSH policy for the session's user and addresses. It then
// follows the resulting SSHAction chain, fetching HoldAndDelegate URLs as
// needed, until an action accepts or rejects. On acceptance it looks up the
// mapped local user and runs the session. Every failure path logs and exits
// the session with status 1.
func (srv *server) handleSSH(s ssh.Session) {
	logf := srv.logf
	sshUser := s.User()
	action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr())
	if err != nil {
		// err's text embeds login names and addresses and may contain '%',
		// so it must never be used as the format string itself.
		logf("%v", err)
		s.Exit(1)
		return
	}

	// Loop processing/fetching Actions until one reaches a
	// terminal state (Accept, Reject, or invalid Action), or
	// until fetchSSHAction times out due to the context being
	// done (client disconnect) or its 30 minute timeout passes.
	// (Which is a long time for somebody to see login
	// instructions and go to a URL to do something.)
	for {
		if action.Message != "" {
			// Emit CRLF line endings to the remote side.
			io.WriteString(s.Stderr(), strings.ReplaceAll(action.Message, "\n", "\r\n"))
		}
		if action.Reject {
			logf("ssh: access denied for %q from %v", ci.uprof.LoginName, ci.src.IP())
			s.Exit(1)
			return
		}
		if action.Accept {
			break
		}
		url := action.HoldAndDelegate
		if url == "" {
			logf("ssh: access denied; SSHAction has neither Reject, Accept, or next step URL")
			s.Exit(1)
			return
		}
		action, err = srv.fetchSSHAction(s.Context(), url)
		if err != nil {
			logf("ssh: fetching SSHAction from %s: %v", url, err)
			s.Exit(1)
			return
		}
	}

	lu, err := user.Lookup(localUser)
	if err != nil {
		logf("ssh: user Lookup %q: %v", localUser, err)
		s.Exit(1)
		return
	}
	ss := srv.newSSHSession(s, ci, lu, action)
	ss.run()
}
  236. // sshSession is an accepted Tailscale SSH session.
  237. type sshSession struct {
  238. ssh.Session
  239. idH string // the RFC4253 sec8 hash H; don't share outside process
  240. sharedID string // ID that's shared with control
  241. logf logger.Logf
  242. ctx *sshContext // implements context.Context
  243. srv *server
  244. connInfo *sshConnInfo
  245. action *tailcfg.SSHAction
  246. localUser *user.User
  247. agentListener net.Listener // non-nil if agent-forwarding requested+allowed
  248. // initialized by launchProcess:
  249. cmd *exec.Cmd
  250. stdin io.WriteCloser
  251. stdout io.Reader
  252. stderr io.Reader // nil for pty sessions
  253. ptyReq *ssh.Pty // non-nil for pty sessions
  254. // We use this sync.Once to ensure that we only terminate the process once,
  255. // either it exits itself or is terminated
  256. exitOnce sync.Once
  257. }
  // newSSHSession wraps the accepted session s into an sshSession, together
// with its connection info ci, resolved local user lu, and accepted action.
//
// The session gets a sharedID built from ci.now (formatted in UTC) followed by
// 5 random bytes in hex, and a logger prefixed with that ID.
func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User, action *tailcfg.SSHAction) *sshSession {
	stamp := ci.now.UTC().Format("20060102T150405")
	sharedID := fmt.Sprintf("%s-%02x", stamp, randBytes(5))
	ss := &sshSession{
		Session:   s,
		idH:       s.Context().(ssh.Context).SessionID(),
		sharedID:  sharedID,
		ctx:       newSSHContext(),
		srv:       srv,
		action:    action,
		localUser: lu,
		connInfo:  ci,
		logf:      logger.WithPrefix(srv.logf, "ssh-session("+sharedID+"): "),
	}
	return ss
}
  // fetchSSHAction fetches the SSHAction at url (the HoldAndDelegate URL of a
// previous action) using the LocalBackend's noise request path.
//
// Transport errors, non-200 responses and undecodable bodies are retried with
// backoff. Retrying continues until an action is decoded or ctx is done, and
// ctx is additionally capped at 30 minutes. A request that cannot be
// constructed fails immediately.
func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
	defer cancel()
	bo := backoff.NewBackoff("fetch-ssh-action", srv.logf, 10*time.Second)
	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
		if err != nil {
			return nil, err
		}
		res, err := srv.lb.DoNoiseRequest(req)
		if err != nil {
			bo.BackOff(ctx, err)
			continue
		}
		if res.StatusCode != http.StatusOK {
			res.Body.Close()
			bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status))
			continue
		}
		a := new(tailcfg.SSHAction)
		err = json.NewDecoder(res.Body).Decode(a)
		// Close the body on both the success and the decode-failure paths
		// so the response (and its connection) is not leaked.
		res.Body.Close()
		if err != nil {
			bo.BackOff(ctx, err)
			continue
		}
		return a, nil
	}
}
  // killProcessOnContextDone blocks until ss.ctx is done, then kills ss.cmd's
// process unless the session's exit path has already run.
//
// If the context's error is an SSHTerminationError with a non-empty message,
// that message is written to the session's stderr before the kill.
func (ss *sshSession) killProcessOnContextDone() {
	<-ss.ctx.Done()
	// exitOnce is shared with run, which consumes it after the process exits.
	// If that already happened this is a no-op; otherwise the process is
	// still running and gets killed here.
	ss.exitOnce.Do(func() {
		cause := ss.ctx.Err()
		if te, ok := cause.(SSHTerminationError); ok {
			if msg := te.SSHTerminationMessage(); msg != "" {
				io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
			}
		}
		ss.logf("terminating SSH session from %v: %v", ss.connInfo.src.IP(), cause)
		ss.cmd.Process.Kill()
	})
}
  320. // sessionAction returns the SSHAction associated with the session.
  321. func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bool) {
  322. srv.mu.Lock()
  323. defer srv.mu.Unlock()
  324. ss, ok = srv.activeSessionByH[sctx.SessionID()]
  325. return
  326. }
  327. // startSession registers ss as an active session.
  328. func (srv *server) startSession(ss *sshSession) {
  329. srv.mu.Lock()
  330. defer srv.mu.Unlock()
  331. if srv.activeSessionByH == nil {
  332. srv.activeSessionByH = make(map[string]*sshSession)
  333. }
  334. if srv.activeSessionBySharedID == nil {
  335. srv.activeSessionBySharedID = make(map[string]*sshSession)
  336. }
  337. if ss.idH == "" {
  338. panic("empty idH")
  339. }
  340. if _, dup := srv.activeSessionByH[ss.idH]; dup {
  341. panic("dup idH")
  342. }
  343. if ss.sharedID == "" {
  344. panic("empty sharedID")
  345. }
  346. if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
  347. panic("dup sharedID")
  348. }
  349. srv.activeSessionByH[ss.idH] = ss
  350. srv.activeSessionBySharedID[ss.sharedID] = ss
  351. }
  352. // endSession unregisters s from the list of active sessions.
  353. func (srv *server) endSession(ss *sshSession) {
  354. srv.mu.Lock()
  355. defer srv.mu.Unlock()
  356. delete(srv.activeSessionByH, ss.idH)
  357. delete(srv.activeSessionBySharedID, ss.sharedID)
  358. }
  // errSessionDone is the error the session context is closed with when run
// returns.
var (
	errSessionDone = errors.New("session is done")
)
  // handleSSHAgentForwarding sets up SSH agent forwarding when the client
// requested it and ss.action allows it. Otherwise it does nothing and returns
// nil.
//
// It starts a Unix socket agent listener and makes the socket owned by lu.
// It sets the socket's directory to mode 0755 so lu can reach it, then
// forwards agent connections between the listener and s in the background.
// On success it assigns ss.agentListener. On any failure after the listener
// has been created, the listener is closed before returning.
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) (retErr error) {
	if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding {
		return nil
	}
	ss.logf("ssh: agent forwarding requested")
	ln, err := ssh.NewAgentListener()
	if err != nil {
		return err
	}
	// Keyed off the named result so that every error return below closes the
	// listener, including those whose err is scoped to an if statement.
	defer func() {
		if retErr != nil && ln != nil {
			ln.Close()
		}
	}()

	uid, err := strconv.ParseUint(lu.Uid, 10, 32)
	if err != nil {
		return err
	}
	gid, err := strconv.ParseUint(lu.Gid, 10, 32)
	if err != nil {
		return err
	}
	socket := ln.Addr().String()
	dir := filepath.Dir(socket)
	// Make sure the socket is accessible by the user.
	if err := os.Chown(socket, int(uid), int(gid)); err != nil {
		return err
	}
	if err := os.Chmod(dir, 0755); err != nil {
		return err
	}
	go ssh.ForwardAgentConnections(ln, s)
	ss.agentListener = ln
	return nil
}
  // run is the entrypoint for a newly accepted SSH session.
//
// It does the following:
//   - registers the session,
//   - verifies the process may act as the local user,
//   - sets up agent forwarding when permitted,
//   - launches the user's process,
//   - copies stdin/stdout/stderr between the process and the SSH session,
//   - exits the session with the process's exit status.
//
// When ss.ctx is done, the session is forcefully terminated. If its Err is an
// SSHTerminationError, its SSHTerminationMessage is sent to the user. If the
// action sets a session duration, ss.ctx is closed with a user-visible
// timeout error once that duration elapses.
func (ss *sshSession) run() {
	srv := ss.srv
	srv.startSession(ss)
	defer srv.endSession(ss)
	defer ss.ctx.CloseWithError(errSessionDone)

	if limit := ss.action.SesssionDuration; limit != 0 {
		timer := time.AfterFunc(limit, func() {
			ss.ctx.CloseWithError(userVisibleError{
				fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SesssionDuration),
				context.DeadlineExceeded,
			})
		})
		defer timer.Stop()
	}

	logf := srv.logf
	lu := ss.localUser
	// A non-root process can only run sessions as the user it already is.
	if euid := os.Geteuid(); euid != 0 && lu.Uid != strconv.Itoa(euid) {
		logf("ssh: can't switch to user %q from process euid %v", lu.Username, euid)
		fmt.Fprintf(ss, "can't switch user\n")
		ss.Exit(1)
		return
	}

	// Take control of the PTY so that we can configure it below.
	// See https://github.com/tailscale/tailscale/issues/4146
	ss.DisablePTYEmulation()

	switch err := ss.handleSSHAgentForwarding(ss, lu); {
	case err != nil:
		logf("ssh: agent forwarding failed: %v", err)
	case ss.agentListener != nil:
		// TODO(maisem/bradfitz): add a way to close all session resources
		defer ss.agentListener.Close()
	}

	if err := ss.launchProcess(ss.ctx); err != nil {
		logf("start failed: %v", err.Error())
		ss.Exit(1)
		return
	}
	go ss.killProcessOnContextDone()

	// TODO: don't log in the success case (applies to all three copies).
	go func() {
		if _, err := io.Copy(ss.stdin, ss); err != nil {
			logf("ssh: stdin copy: %v", err)
		}
		ss.stdin.Close()
	}()
	go func() {
		if _, err := io.Copy(ss, ss.stdout); err != nil {
			logf("ssh: stdout copy: %v", err)
		}
	}()
	// stderr is nil for ptys.
	if ss.stderr != nil {
		go func() {
			if _, err := io.Copy(ss.Stderr(), ss.stderr); err != nil {
				logf("ssh: stderr copy: %v", err)
			}
		}()
	}

	waitErr := ss.cmd.Wait()
	// Consume exitOnce. This either turns killProcessOnContextDone into a
	// no-op, or is itself the no-op because that goroutine already killed
	// the process.
	ss.exitOnce.Do(func() {})

	exitErr, isExitErr := waitErr.(*exec.ExitError)
	switch {
	case waitErr == nil:
		logf("ssh: Wait: ok")
		ss.Exit(0)
	case isExitErr:
		code := exitErr.ProcessState.ExitCode()
		logf("ssh: Wait: code=%v", code)
		ss.Exit(code)
	default:
		logf("ssh: Wait: %v", waitErr)
		ss.Exit(1)
	}
}
  489. type sshConnInfo struct {
  490. // now is the time to consider the present moment for the
  491. // purposes of rule evaluation.
  492. now time.Time
  493. // sshUser is the requested local SSH username ("root", "alice", etc).
  494. sshUser string
  495. // src is the Tailscale IP and port that the connection came from.
  496. src netaddr.IPPort
  497. // dst is the Tailscale IP and port that the connection came for.
  498. dst netaddr.IPPort
  499. // node is srcIP's node.
  500. node *tailcfg.Node
  501. // uprof is node's UserProfile.
  502. uprof *tailcfg.UserProfile
  503. }
  504. func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) {
  505. for _, r := range pol.Rules {
  506. if a, localUser, err := matchRule(r, ci); err == nil {
  507. return a, localUser, true
  508. }
  509. }
  510. return nil, "", false
  511. }
  512. // internal errors for testing; they don't escape to callers or logs.
  513. var (
  514. errNilRule = errors.New("nil rule")
  515. errNilAction = errors.New("nil action")
  516. errRuleExpired = errors.New("rule expired")
  517. errPrincipalMatch = errors.New("principal didn't match")
  518. errUserMatch = errors.New("user didn't match")
  519. )
  520. func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, err error) {
  521. if r == nil {
  522. return nil, "", errNilRule
  523. }
  524. if r.Action == nil {
  525. return nil, "", errNilAction
  526. }
  527. if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) {
  528. return nil, "", errRuleExpired
  529. }
  530. if !matchesPrincipal(r.Principals, ci) {
  531. return nil, "", errPrincipalMatch
  532. }
  533. if !r.Action.Reject || r.SSHUsers != nil {
  534. localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
  535. if localUser == "" {
  536. return nil, "", errUserMatch
  537. }
  538. }
  539. return r.Action, localUser, nil
  540. }
  // mapLocalUser returns the local user that reqSSHUser maps to in ruleSSHUsers.
//
// An exact entry wins, even if its value is empty. Otherwise the "*" wildcard
// entry is used. The result is "" when neither exists; a nil map is
// allowed.
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) {
	mapped, found := ruleSSHUsers[reqSSHUser]
	if !found {
		mapped = ruleSSHUsers["*"]
	}
	return mapped
}
  547. func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
  548. for _, p := range ps {
  549. if p == nil {
  550. continue
  551. }
  552. if p.Any {
  553. return true
  554. }
  555. if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
  556. return true
  557. }
  558. if p.NodeIP != "" {
  559. if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
  560. return true
  561. }
  562. }
  563. if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
  564. return true
  565. }
  566. }
  567. return false
  568. }
  // randBytes returns a new slice of n bytes read from crypto/rand.
// It panics if reading from the random source fails.
func randBytes(n int) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic(err)
	}
	return buf
}