ssh.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux
  5. // +build linux
  6. package netstack
  7. import (
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "log"
  13. "net"
  14. "os"
  15. "os/exec"
  16. "syscall"
  17. "unsafe"
  18. "github.com/creack/pty"
  19. "github.com/gliderlabs/ssh"
  20. gossh "golang.org/x/crypto/ssh"
  21. "inet.af/netaddr"
  22. "tailscale.com/envknob"
  23. "tailscale.com/net/tsaddr"
  24. )
  25. func init() {
  26. sshDemo = sshDemoImpl
  27. }
  28. func sshDemoImpl(ns *Impl, c net.Conn) error {
  29. hostKey, err := ioutil.ReadFile("/etc/ssh/ssh_host_ed25519_key")
  30. if err != nil {
  31. return err
  32. }
  33. signer, err := gossh.ParsePrivateKey(hostKey)
  34. if err != nil {
  35. return err
  36. }
  37. srv := &ssh.Server{
  38. Handler: ns.handleSSH,
  39. RequestHandlers: map[string]ssh.RequestHandler{},
  40. SubsystemHandlers: map[string]ssh.SubsystemHandler{},
  41. ChannelHandlers: map[string]ssh.ChannelHandler{},
  42. }
  43. for k, v := range ssh.DefaultRequestHandlers {
  44. srv.RequestHandlers[k] = v
  45. }
  46. for k, v := range ssh.DefaultChannelHandlers {
  47. srv.ChannelHandlers[k] = v
  48. }
  49. for k, v := range ssh.DefaultSubsystemHandlers {
  50. srv.SubsystemHandlers[k] = v
  51. }
  52. srv.AddHostKey(signer)
  53. srv.HandleConn(c)
  54. return nil
  55. }
  56. func (ns *Impl) handleSSH(s ssh.Session) {
  57. lb := ns.lb
  58. user := s.User()
  59. addr := s.RemoteAddr()
  60. log.Printf("Handling SSH from %v for user %v", addr, user)
  61. ta, ok := addr.(*net.TCPAddr)
  62. if !ok {
  63. log.Printf("tsshd: rejecting non-TCP addr %T %v", addr, addr)
  64. s.Exit(1)
  65. return
  66. }
  67. tanetaddr, ok := netaddr.FromStdIP(ta.IP)
  68. if !ok {
  69. log.Printf("tsshd: rejecting unparseable addr %v", ta.IP)
  70. s.Exit(1)
  71. return
  72. }
  73. if !tsaddr.IsTailscaleIP(tanetaddr) {
  74. log.Printf("tsshd: rejecting non-Tailscale addr %v", ta.IP)
  75. s.Exit(1)
  76. return
  77. }
  78. ptyReq, winCh, isPty := s.Pty()
  79. if !isPty {
  80. fmt.Fprintf(s, "TODO scp etc")
  81. s.Exit(1)
  82. return
  83. }
  84. srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
  85. node, uprof, ok := lb.WhoIs(srcIPP)
  86. if !ok {
  87. fmt.Fprintf(s, "Hello, %v. I don't know who you are.\n", srcIPP)
  88. s.Exit(0)
  89. return
  90. }
  91. allow := envknob.String("TS_SSH_ALLOW_LOGIN")
  92. if allow == "" || uprof.LoginName != allow {
  93. log.Printf("ssh: access denied for %q (only allowing %q)", uprof.LoginName, allow)
  94. jnode, _ := json.Marshal(node)
  95. jprof, _ := json.Marshal(uprof)
  96. fmt.Fprintf(s, "Access denied.\n\nYou are node: %s\n\nYour profile: %s\n\nYou wanted %+v\n", jnode, jprof, ptyReq)
  97. s.Exit(1)
  98. return
  99. }
  100. cmd := exec.Command("/bin/bash")
  101. cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
  102. f, err := pty.Start(cmd)
  103. if err != nil {
  104. log.Printf("running shell: %v", err)
  105. s.Exit(1)
  106. return
  107. }
  108. defer f.Close()
  109. go func() {
  110. for win := range winCh {
  111. setWinsize(f, win.Width, win.Height)
  112. }
  113. }()
  114. go func() {
  115. io.Copy(f, s) // stdin
  116. }()
  117. io.Copy(s, f) // stdout
  118. cmd.Process.Kill()
  119. if err := cmd.Wait(); err != nil {
  120. s.Exit(1)
  121. }
  122. s.Exit(0)
  123. return
  124. }
  125. func setWinsize(f *os.File, w, h int) {
  126. syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
  127. uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
  128. }