session.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package sshd
  2. import (
  3. "fmt"
  4. "sort"
  5. "strings"
  6. "github.com/anmitsu/go-shlex"
  7. "github.com/armon/go-radix"
  8. "github.com/sirupsen/logrus"
  9. "golang.org/x/crypto/ssh"
  10. "golang.org/x/term"
  11. )
  12. type session struct {
  13. l *logrus.Entry
  14. c *ssh.ServerConn
  15. term *term.Terminal
  16. commands *radix.Tree
  17. exitChan chan bool
  18. }
  19. func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
  20. s := &session{
  21. commands: radix.NewFromMap(commands.ToMap()),
  22. l: l,
  23. c: conn,
  24. exitChan: make(chan bool),
  25. }
  26. s.commands.Insert("logout", &Command{
  27. Name: "logout",
  28. ShortDescription: "Ends the current session",
  29. Callback: func(a any, args []string, w StringWriter) error {
  30. s.Close()
  31. return nil
  32. },
  33. })
  34. go s.handleChannels(chans)
  35. return s
  36. }
  37. func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
  38. for newChannel := range chans {
  39. if newChannel.ChannelType() != "session" {
  40. s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
  41. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  42. continue
  43. }
  44. channel, requests, err := newChannel.Accept()
  45. if err != nil {
  46. s.l.WithError(err).Warn("could not accept channel")
  47. continue
  48. }
  49. go s.handleRequests(requests, channel)
  50. }
  51. }
  52. func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
  53. defer s.Close()
  54. for req := range in {
  55. var err error
  56. switch req.Type {
  57. case "shell":
  58. if s.term == nil {
  59. s.term = s.createTerm(channel)
  60. err = req.Reply(true, nil)
  61. } else {
  62. err = req.Reply(false, nil)
  63. }
  64. case "pty-req":
  65. err = req.Reply(true, nil)
  66. case "window-change":
  67. err = req.Reply(true, nil)
  68. case "exec":
  69. var payload = struct{ Value string }{}
  70. cErr := ssh.Unmarshal(req.Payload, &payload)
  71. if cErr != nil {
  72. req.Reply(false, nil)
  73. return
  74. }
  75. req.Reply(true, nil)
  76. s.dispatchCommand(payload.Value, &stringWriter{channel})
  77. status := struct{ Status uint32 }{uint32(0)}
  78. channel.SendRequest("exit-status", false, ssh.Marshal(status))
  79. channel.Close()
  80. return
  81. default:
  82. s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
  83. err = req.Reply(false, nil)
  84. }
  85. if err != nil {
  86. s.l.WithError(err).Info("Error handling ssh session requests")
  87. s.Close()
  88. return
  89. }
  90. }
  91. }
  92. func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
  93. term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
  94. term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
  95. // key 9 is tab
  96. if key == 9 {
  97. cmds := matchCommand(s.commands, line)
  98. if len(cmds) == 1 {
  99. return cmds[0] + " ", len(cmds[0]) + 1, true
  100. }
  101. sort.Strings(cmds)
  102. term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
  103. }
  104. return "", 0, false
  105. }
  106. go s.handleInput(channel)
  107. return term
  108. }
  109. func (s *session) handleInput(channel ssh.Channel) {
  110. defer s.Close()
  111. w := &stringWriter{w: s.term}
  112. for {
  113. line, err := s.term.ReadLine()
  114. if err != nil {
  115. break
  116. }
  117. s.dispatchCommand(line, w)
  118. }
  119. }
  120. func (s *session) dispatchCommand(line string, w StringWriter) {
  121. args, err := shlex.Split(line, true)
  122. if err != nil {
  123. return
  124. }
  125. if len(args) == 0 {
  126. dumpCommands(s.commands, w)
  127. return
  128. }
  129. c, err := lookupCommand(s.commands, args[0])
  130. if err != nil {
  131. return
  132. }
  133. if c == nil {
  134. err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
  135. _ = err
  136. dumpCommands(s.commands, w)
  137. return
  138. }
  139. if checkHelpArgs(args) {
  140. s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
  141. return
  142. }
  143. _ = execCommand(c, args[1:], w)
  144. return
  145. }
  146. func (s *session) Close() {
  147. s.c.Close()
  148. s.exitChan <- true
  149. }