server.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. package ssh
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "time"
  9. gossh "golang.org/x/crypto/ssh"
  10. )
  11. // ErrServerClosed is returned by the Server's Serve, ListenAndServe,
  12. // and ListenAndServeTLS methods after a call to Shutdown or Close.
  13. var ErrServerClosed = errors.New("ssh: Server closed")
  14. type SubsystemHandler func(s Session)
  15. var DefaultSubsystemHandlers = map[string]SubsystemHandler{}
  16. type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
  17. var DefaultRequestHandlers = map[string]RequestHandler{}
  18. type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
  19. var DefaultChannelHandlers = map[string]ChannelHandler{
  20. "session": DefaultSessionHandler,
  21. }
  22. // Server defines parameters for running an SSH server. The zero value for
  23. // Server is a valid configuration. When both PasswordHandler and
  24. // PublicKeyHandler are nil, no client authentication is performed.
  25. type Server struct {
  26. Addr string // TCP address to listen on, ":22" if empty
  27. Handler Handler // handler to invoke, ssh.DefaultHandler if nil
  28. HostSigners []Signer // private keys for the host key, must have at least one
  29. Version string // server version to be sent before the initial handshake
  30. KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
  31. PasswordHandler PasswordHandler // password authentication handler
  32. PublicKeyHandler PublicKeyHandler // public key authentication handler
  33. NoClientAuthHandler NoClientAuthHandler // no client authentication handler
  34. PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
  35. ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
  36. LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
  37. ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
  38. ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
  39. SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
  40. ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures
  41. IdleTimeout time.Duration // connection timeout when no activity, none if empty
  42. MaxTimeout time.Duration // absolute connection timeout, none if empty
  43. // ChannelHandlers allow overriding the built-in session handlers or provide
  44. // extensions to the protocol, such as tcpip forwarding. By default only the
  45. // "session" handler is enabled.
  46. ChannelHandlers map[string]ChannelHandler
  47. // RequestHandlers allow overriding the server-level request handlers or
  48. // provide extensions to the protocol, such as tcpip forwarding. By default
  49. // no handlers are enabled.
  50. RequestHandlers map[string]RequestHandler
  51. // SubsystemHandlers are handlers which are similar to the usual SSH command
  52. // handlers, but handle named subsystems.
  53. SubsystemHandlers map[string]SubsystemHandler
  54. listenerWg sync.WaitGroup
  55. mu sync.RWMutex
  56. listeners map[net.Listener]struct{}
  57. conns map[*gossh.ServerConn]struct{}
  58. connWg sync.WaitGroup
  59. doneChan chan struct{}
  60. }
  61. func (srv *Server) ensureHostSigner() error {
  62. srv.mu.Lock()
  63. defer srv.mu.Unlock()
  64. if len(srv.HostSigners) == 0 {
  65. signer, err := generateSigner()
  66. if err != nil {
  67. return err
  68. }
  69. srv.HostSigners = append(srv.HostSigners, signer)
  70. }
  71. return nil
  72. }
  73. func (srv *Server) ensureHandlers() {
  74. srv.mu.Lock()
  75. defer srv.mu.Unlock()
  76. if srv.RequestHandlers == nil {
  77. srv.RequestHandlers = map[string]RequestHandler{}
  78. for k, v := range DefaultRequestHandlers {
  79. srv.RequestHandlers[k] = v
  80. }
  81. }
  82. if srv.ChannelHandlers == nil {
  83. srv.ChannelHandlers = map[string]ChannelHandler{}
  84. for k, v := range DefaultChannelHandlers {
  85. srv.ChannelHandlers[k] = v
  86. }
  87. }
  88. if srv.SubsystemHandlers == nil {
  89. srv.SubsystemHandlers = map[string]SubsystemHandler{}
  90. for k, v := range DefaultSubsystemHandlers {
  91. srv.SubsystemHandlers[k] = v
  92. }
  93. }
  94. }
  95. func (srv *Server) config(ctx Context) *gossh.ServerConfig {
  96. srv.mu.RLock()
  97. defer srv.mu.RUnlock()
  98. var config *gossh.ServerConfig
  99. if srv.ServerConfigCallback == nil {
  100. config = &gossh.ServerConfig{}
  101. } else {
  102. config = srv.ServerConfigCallback(ctx)
  103. }
  104. for _, signer := range srv.HostSigners {
  105. config.AddHostKey(signer)
  106. }
  107. if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil {
  108. config.NoClientAuth = true
  109. }
  110. if srv.Version != "" {
  111. config.ServerVersion = "SSH-2.0-" + srv.Version
  112. }
  113. if srv.PasswordHandler != nil {
  114. config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
  115. applyConnMetadata(ctx, conn)
  116. if ok := srv.PasswordHandler(ctx, string(password)); !ok {
  117. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  118. }
  119. return ctx.Permissions().Permissions, nil
  120. }
  121. }
  122. if srv.PublicKeyHandler != nil {
  123. config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
  124. applyConnMetadata(ctx, conn)
  125. if err := srv.PublicKeyHandler(ctx, key); err != nil {
  126. return ctx.Permissions().Permissions, err
  127. }
  128. ctx.SetValue(ContextKeyPublicKey, key)
  129. return ctx.Permissions().Permissions, nil
  130. }
  131. }
  132. if srv.KeyboardInteractiveHandler != nil {
  133. config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
  134. applyConnMetadata(ctx, conn)
  135. if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok {
  136. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  137. }
  138. return ctx.Permissions().Permissions, nil
  139. }
  140. }
  141. if srv.NoClientAuthHandler != nil {
  142. config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) {
  143. applyConnMetadata(ctx, conn)
  144. if err := srv.NoClientAuthHandler(ctx); err != nil {
  145. return ctx.Permissions().Permissions, err
  146. }
  147. return ctx.Permissions().Permissions, nil
  148. }
  149. }
  150. return config
  151. }
  152. // Handle sets the Handler for the server.
  153. func (srv *Server) Handle(fn Handler) {
  154. srv.mu.Lock()
  155. defer srv.mu.Unlock()
  156. srv.Handler = fn
  157. }
  158. // Close immediately closes all active listeners and all active
  159. // connections.
  160. //
  161. // Close returns any error returned from closing the Server's
  162. // underlying Listener(s).
  163. func (srv *Server) Close() error {
  164. srv.mu.Lock()
  165. defer srv.mu.Unlock()
  166. srv.closeDoneChanLocked()
  167. err := srv.closeListenersLocked()
  168. for c := range srv.conns {
  169. c.Close()
  170. delete(srv.conns, c)
  171. }
  172. return err
  173. }
  174. // Shutdown gracefully shuts down the server without interrupting any
  175. // active connections. Shutdown works by first closing all open
  176. // listeners, and then waiting indefinitely for connections to close.
  177. // If the provided context expires before the shutdown is complete,
  178. // then the context's error is returned.
  179. func (srv *Server) Shutdown(ctx context.Context) error {
  180. srv.mu.Lock()
  181. lnerr := srv.closeListenersLocked()
  182. srv.closeDoneChanLocked()
  183. srv.mu.Unlock()
  184. finished := make(chan struct{}, 1)
  185. go func() {
  186. srv.listenerWg.Wait()
  187. srv.connWg.Wait()
  188. finished <- struct{}{}
  189. }()
  190. select {
  191. case <-ctx.Done():
  192. return ctx.Err()
  193. case <-finished:
  194. return lnerr
  195. }
  196. }
  197. // Serve accepts incoming connections on the Listener l, creating a new
  198. // connection goroutine for each. The connection goroutines read requests and then
  199. // calls srv.Handler to handle sessions.
  200. //
  201. // Serve always returns a non-nil error.
  202. func (srv *Server) Serve(l net.Listener) error {
  203. srv.ensureHandlers()
  204. defer l.Close()
  205. if err := srv.ensureHostSigner(); err != nil {
  206. return err
  207. }
  208. if srv.Handler == nil {
  209. srv.Handler = DefaultHandler
  210. }
  211. var tempDelay time.Duration
  212. srv.trackListener(l, true)
  213. defer srv.trackListener(l, false)
  214. for {
  215. conn, e := l.Accept()
  216. if e != nil {
  217. select {
  218. case <-srv.getDoneChan():
  219. return ErrServerClosed
  220. default:
  221. }
  222. if ne, ok := e.(net.Error); ok && ne.Temporary() {
  223. if tempDelay == 0 {
  224. tempDelay = 5 * time.Millisecond
  225. } else {
  226. tempDelay *= 2
  227. }
  228. if max := 1 * time.Second; tempDelay > max {
  229. tempDelay = max
  230. }
  231. time.Sleep(tempDelay)
  232. continue
  233. }
  234. return e
  235. }
  236. go srv.HandleConn(conn)
  237. }
  238. }
  239. func (srv *Server) HandleConn(newConn net.Conn) {
  240. ctx, cancel := newContext(srv)
  241. if srv.ConnCallback != nil {
  242. cbConn := srv.ConnCallback(ctx, newConn)
  243. if cbConn == nil {
  244. newConn.Close()
  245. return
  246. }
  247. newConn = cbConn
  248. }
  249. conn := &serverConn{
  250. Conn: newConn,
  251. idleTimeout: srv.IdleTimeout,
  252. closeCanceler: cancel,
  253. }
  254. if srv.MaxTimeout > 0 {
  255. conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
  256. }
  257. defer conn.Close()
  258. sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
  259. if err != nil {
  260. if srv.ConnectionFailedCallback != nil {
  261. srv.ConnectionFailedCallback(conn, err)
  262. }
  263. return
  264. }
  265. srv.trackConn(sshConn, true)
  266. defer srv.trackConn(sshConn, false)
  267. ctx.SetValue(ContextKeyConn, sshConn)
  268. applyConnMetadata(ctx, sshConn)
  269. //go gossh.DiscardRequests(reqs)
  270. go srv.handleRequests(ctx, reqs)
  271. for ch := range chans {
  272. handler := srv.ChannelHandlers[ch.ChannelType()]
  273. if handler == nil {
  274. handler = srv.ChannelHandlers["default"]
  275. }
  276. if handler == nil {
  277. ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
  278. continue
  279. }
  280. go handler(srv, sshConn, ch, ctx)
  281. }
  282. }
  283. func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
  284. for req := range in {
  285. handler := srv.RequestHandlers[req.Type]
  286. if handler == nil {
  287. handler = srv.RequestHandlers["default"]
  288. }
  289. if handler == nil {
  290. req.Reply(false, nil)
  291. continue
  292. }
  293. /*reqCtx, cancel := context.WithCancel(ctx)
  294. defer cancel() */
  295. ret, payload := handler(ctx, srv, req)
  296. req.Reply(ret, payload)
  297. }
  298. }
  299. // ListenAndServe listens on the TCP network address srv.Addr and then calls
  300. // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
  301. // ListenAndServe always returns a non-nil error.
  302. func (srv *Server) ListenAndServe() error {
  303. addr := srv.Addr
  304. if addr == "" {
  305. addr = ":22"
  306. }
  307. ln, err := net.Listen("tcp", addr)
  308. if err != nil {
  309. return err
  310. }
  311. return srv.Serve(ln)
  312. }
  313. // AddHostKey adds a private key as a host key. If an existing host key exists
  314. // with the same algorithm, it is overwritten. Each server config must have at
  315. // least one host key.
  316. func (srv *Server) AddHostKey(key Signer) {
  317. srv.mu.Lock()
  318. defer srv.mu.Unlock()
  319. // these are later added via AddHostKey on ServerConfig, which performs the
  320. // check for one of every algorithm.
  321. // This check is based on the AddHostKey method from the x/crypto/ssh
  322. // library. This allows us to only keep one active key for each type on a
  323. // server at once. So, if you're dynamically updating keys at runtime, this
  324. // list will not keep growing.
  325. for i, k := range srv.HostSigners {
  326. if k.PublicKey().Type() == key.PublicKey().Type() {
  327. srv.HostSigners[i] = key
  328. return
  329. }
  330. }
  331. srv.HostSigners = append(srv.HostSigners, key)
  332. }
  333. // SetOption runs a functional option against the server.
  334. func (srv *Server) SetOption(option Option) error {
  335. // NOTE: there is a potential race here for any option that doesn't call an
  336. // internal method. We can't actually lock here because if something calls
  337. // (as an example) AddHostKey, it will deadlock.
  338. //srv.mu.Lock()
  339. //defer srv.mu.Unlock()
  340. return option(srv)
  341. }
  342. func (srv *Server) getDoneChan() <-chan struct{} {
  343. srv.mu.Lock()
  344. defer srv.mu.Unlock()
  345. return srv.getDoneChanLocked()
  346. }
  347. func (srv *Server) getDoneChanLocked() chan struct{} {
  348. if srv.doneChan == nil {
  349. srv.doneChan = make(chan struct{})
  350. }
  351. return srv.doneChan
  352. }
  353. func (srv *Server) closeDoneChanLocked() {
  354. ch := srv.getDoneChanLocked()
  355. select {
  356. case <-ch:
  357. // Already closed. Don't close again.
  358. default:
  359. // Safe to close here. We're the only closer, guarded
  360. // by srv.mu.
  361. close(ch)
  362. }
  363. }
  364. func (srv *Server) closeListenersLocked() error {
  365. var err error
  366. for ln := range srv.listeners {
  367. if cerr := ln.Close(); cerr != nil && err == nil {
  368. err = cerr
  369. }
  370. delete(srv.listeners, ln)
  371. }
  372. return err
  373. }
  374. func (srv *Server) trackListener(ln net.Listener, add bool) {
  375. srv.mu.Lock()
  376. defer srv.mu.Unlock()
  377. if srv.listeners == nil {
  378. srv.listeners = make(map[net.Listener]struct{})
  379. }
  380. if add {
  381. // If the *Server is being reused after a previous
  382. // Close or Shutdown, reset its doneChan:
  383. if len(srv.listeners) == 0 && len(srv.conns) == 0 {
  384. srv.doneChan = nil
  385. }
  386. srv.listeners[ln] = struct{}{}
  387. srv.listenerWg.Add(1)
  388. } else {
  389. delete(srv.listeners, ln)
  390. srv.listenerWg.Done()
  391. }
  392. }
  393. func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
  394. srv.mu.Lock()
  395. defer srv.mu.Unlock()
  396. if srv.conns == nil {
  397. srv.conns = make(map[*gossh.ServerConn]struct{})
  398. }
  399. if add {
  400. srv.conns[c] = struct{}{}
  401. srv.connWg.Add(1)
  402. } else {
  403. delete(srv.conns, c)
  404. srv.connWg.Done()
  405. }
  406. }