1
0

ssh.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package outbound
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/base64"
  6. "math/rand"
  7. "net"
  8. "os"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/dialer"
  14. C "github.com/sagernet/sing-box/constant"
  15. "github.com/sagernet/sing-box/log"
  16. "github.com/sagernet/sing-box/option"
  17. "github.com/sagernet/sing/common"
  18. E "github.com/sagernet/sing/common/exceptions"
  19. M "github.com/sagernet/sing/common/metadata"
  20. N "github.com/sagernet/sing/common/network"
  21. "golang.org/x/crypto/ssh"
  22. )
  23. var (
  24. _ adapter.Outbound = (*SSH)(nil)
  25. _ adapter.InterfaceUpdateListener = (*SSH)(nil)
  26. )
  27. type SSH struct {
  28. myOutboundAdapter
  29. ctx context.Context
  30. dialer N.Dialer
  31. serverAddr M.Socksaddr
  32. user string
  33. hostKey []ssh.PublicKey
  34. hostKeyAlgorithms []string
  35. clientVersion string
  36. authMethod []ssh.AuthMethod
  37. clientAccess sync.Mutex
  38. clientConn net.Conn
  39. client *ssh.Client
  40. }
  41. func NewSSH(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (*SSH, error) {
  42. outboundDialer, err := dialer.New(router, options.DialerOptions)
  43. if err != nil {
  44. return nil, err
  45. }
  46. outbound := &SSH{
  47. myOutboundAdapter: myOutboundAdapter{
  48. protocol: C.TypeSSH,
  49. network: []string{N.NetworkTCP},
  50. router: router,
  51. logger: logger,
  52. tag: tag,
  53. dependencies: withDialerDependency(options.DialerOptions),
  54. },
  55. ctx: ctx,
  56. dialer: outboundDialer,
  57. serverAddr: options.ServerOptions.Build(),
  58. user: options.User,
  59. hostKeyAlgorithms: options.HostKeyAlgorithms,
  60. clientVersion: options.ClientVersion,
  61. }
  62. if outbound.serverAddr.Port == 0 {
  63. outbound.serverAddr.Port = 22
  64. }
  65. if outbound.user == "" {
  66. outbound.user = "root"
  67. }
  68. if outbound.clientVersion == "" {
  69. outbound.clientVersion = randomVersion()
  70. }
  71. if options.Password != "" {
  72. outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
  73. }
  74. if len(options.PrivateKey) > 0 || options.PrivateKeyPath != "" {
  75. var privateKey []byte
  76. if len(options.PrivateKey) > 0 {
  77. privateKey = []byte(strings.Join(options.PrivateKey, "\n"))
  78. } else {
  79. var err error
  80. privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
  81. if err != nil {
  82. return nil, E.Cause(err, "read private key")
  83. }
  84. }
  85. var signer ssh.Signer
  86. var err error
  87. if options.PrivateKeyPassphrase == "" {
  88. signer, err = ssh.ParsePrivateKey(privateKey)
  89. } else {
  90. signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
  91. }
  92. if err != nil {
  93. return nil, E.Cause(err, "parse private key")
  94. }
  95. outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
  96. }
  97. if len(options.HostKey) > 0 {
  98. for _, hostKey := range options.HostKey {
  99. key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
  100. if err != nil {
  101. return nil, E.New("parse host key ", key)
  102. }
  103. outbound.hostKey = append(outbound.hostKey, key)
  104. }
  105. }
  106. return outbound, nil
  107. }
  108. func randomVersion() string {
  109. version := "SSH-2.0-OpenSSH_"
  110. if rand.Intn(2) == 0 {
  111. version += "7." + strconv.Itoa(rand.Intn(10))
  112. } else {
  113. version += "8." + strconv.Itoa(rand.Intn(9))
  114. }
  115. return version
  116. }
  117. func (s *SSH) connect() (*ssh.Client, error) {
  118. if s.client != nil {
  119. return s.client, nil
  120. }
  121. s.clientAccess.Lock()
  122. defer s.clientAccess.Unlock()
  123. if s.client != nil {
  124. return s.client, nil
  125. }
  126. conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
  127. if err != nil {
  128. return nil, err
  129. }
  130. config := &ssh.ClientConfig{
  131. User: s.user,
  132. Auth: s.authMethod,
  133. ClientVersion: s.clientVersion,
  134. HostKeyAlgorithms: s.hostKeyAlgorithms,
  135. HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
  136. if len(s.hostKey) == 0 {
  137. return nil
  138. }
  139. serverKey := key.Marshal()
  140. for _, hostKey := range s.hostKey {
  141. if bytes.Equal(serverKey, hostKey.Marshal()) {
  142. return nil
  143. }
  144. }
  145. return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
  146. },
  147. }
  148. clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
  149. if err != nil {
  150. conn.Close()
  151. return nil, E.Cause(err, "connect to ssh server")
  152. }
  153. client := ssh.NewClient(clientConn, chans, reqs)
  154. s.clientConn = conn
  155. s.client = client
  156. go func() {
  157. client.Wait()
  158. conn.Close()
  159. s.clientAccess.Lock()
  160. s.client = nil
  161. s.clientConn = nil
  162. s.clientAccess.Unlock()
  163. }()
  164. return client, nil
  165. }
  166. func (s *SSH) InterfaceUpdated() {
  167. common.Close(s.clientConn)
  168. return
  169. }
  170. func (s *SSH) Close() error {
  171. return common.Close(s.clientConn)
  172. }
  173. func (s *SSH) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  174. client, err := s.connect()
  175. if err != nil {
  176. return nil, err
  177. }
  178. return client.Dial(network, destination.String())
  179. }
  180. func (s *SSH) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  181. return nil, os.ErrInvalid
  182. }