outbound.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package ssh
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/base64"
  6. "math/rand"
  7. "net"
  8. "os"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/adapter/outbound"
  14. "github.com/sagernet/sing-box/common/dialer"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/log"
  17. "github.com/sagernet/sing-box/option"
  18. "github.com/sagernet/sing/common"
  19. E "github.com/sagernet/sing/common/exceptions"
  20. "github.com/sagernet/sing/common/logger"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. "golang.org/x/crypto/ssh"
  24. )
  25. func RegisterOutbound(registry *outbound.Registry) {
  26. outbound.Register[option.SSHOutboundOptions](registry, C.TypeSSH, NewOutbound)
  27. }
  28. var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
  29. type Outbound struct {
  30. outbound.Adapter
  31. ctx context.Context
  32. logger logger.ContextLogger
  33. dialer N.Dialer
  34. serverAddr M.Socksaddr
  35. user string
  36. hostKey []ssh.PublicKey
  37. hostKeyAlgorithms []string
  38. clientVersion string
  39. authMethod []ssh.AuthMethod
  40. clientAccess sync.Mutex
  41. clientConn net.Conn
  42. client *ssh.Client
  43. }
  44. func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) {
  45. outboundDialer, err := dialer.New(ctx, options.DialerOptions)
  46. if err != nil {
  47. return nil, err
  48. }
  49. outbound := &Outbound{
  50. Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions),
  51. ctx: ctx,
  52. logger: logger,
  53. dialer: outboundDialer,
  54. serverAddr: options.ServerOptions.Build(),
  55. user: options.User,
  56. hostKeyAlgorithms: options.HostKeyAlgorithms,
  57. clientVersion: options.ClientVersion,
  58. }
  59. if outbound.serverAddr.Port == 0 {
  60. outbound.serverAddr.Port = 22
  61. }
  62. if outbound.user == "" {
  63. outbound.user = "root"
  64. }
  65. if outbound.clientVersion == "" {
  66. outbound.clientVersion = randomVersion()
  67. }
  68. if options.Password != "" {
  69. outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
  70. }
  71. if len(options.PrivateKey) > 0 || options.PrivateKeyPath != "" {
  72. var privateKey []byte
  73. if len(options.PrivateKey) > 0 {
  74. privateKey = []byte(strings.Join(options.PrivateKey, "\n"))
  75. } else {
  76. var err error
  77. privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
  78. if err != nil {
  79. return nil, E.Cause(err, "read private key")
  80. }
  81. }
  82. var signer ssh.Signer
  83. var err error
  84. if options.PrivateKeyPassphrase == "" {
  85. signer, err = ssh.ParsePrivateKey(privateKey)
  86. } else {
  87. signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
  88. }
  89. if err != nil {
  90. return nil, E.Cause(err, "parse private key")
  91. }
  92. outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
  93. }
  94. if len(options.HostKey) > 0 {
  95. for _, hostKey := range options.HostKey {
  96. key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
  97. if err != nil {
  98. return nil, E.New("parse host key ", key)
  99. }
  100. outbound.hostKey = append(outbound.hostKey, key)
  101. }
  102. }
  103. return outbound, nil
  104. }
  // randomVersion returns a client version string that mimics OpenSSH.
  // It picks 7.x (x in 0..9) or 8.x (x in 0..8) with equal probability.
  func randomVersion() string {
  	const prefix = "SSH-2.0-OpenSSH_"
  	major, minorChoices := "7.", 10
  	if rand.Intn(2) != 0 {
  		major, minorChoices = "8.", 9
  	}
  	return prefix + major + strconv.Itoa(rand.Intn(minorChoices))
  }
  114. func (s *Outbound) connect() (*ssh.Client, error) {
  115. if s.client != nil {
  116. return s.client, nil
  117. }
  118. s.clientAccess.Lock()
  119. defer s.clientAccess.Unlock()
  120. if s.client != nil {
  121. return s.client, nil
  122. }
  123. conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
  124. if err != nil {
  125. return nil, err
  126. }
  127. config := &ssh.ClientConfig{
  128. User: s.user,
  129. Auth: s.authMethod,
  130. ClientVersion: s.clientVersion,
  131. HostKeyAlgorithms: s.hostKeyAlgorithms,
  132. HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
  133. if len(s.hostKey) == 0 {
  134. return nil
  135. }
  136. serverKey := key.Marshal()
  137. for _, hostKey := range s.hostKey {
  138. if bytes.Equal(serverKey, hostKey.Marshal()) {
  139. return nil
  140. }
  141. }
  142. return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
  143. },
  144. }
  145. clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
  146. if err != nil {
  147. conn.Close()
  148. return nil, E.Cause(err, "connect to ssh server")
  149. }
  150. client := ssh.NewClient(clientConn, chans, reqs)
  151. s.clientConn = conn
  152. s.client = client
  153. go func() {
  154. client.Wait()
  155. conn.Close()
  156. s.clientAccess.Lock()
  157. s.client = nil
  158. s.clientConn = nil
  159. s.clientAccess.Unlock()
  160. }()
  161. return client, nil
  162. }
  163. func (s *Outbound) InterfaceUpdated() {
  164. common.Close(s.clientConn)
  165. return
  166. }
  167. func (s *Outbound) Close() error {
  168. return common.Close(s.clientConn)
  169. }
  170. func (s *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  171. client, err := s.connect()
  172. if err != nil {
  173. return nil, err
  174. }
  175. return client.Dial(network, destination.String())
  176. }
  177. func (s *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  178. return nil, os.ErrInvalid
  179. }