outbound.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package ssh
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/base64"
  6. "math/rand"
  7. "net"
  8. "os"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/adapter/outbound"
  15. "github.com/sagernet/sing-box/common/dialer"
  16. C "github.com/sagernet/sing-box/constant"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing/common"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. "github.com/sagernet/sing/common/logger"
  22. M "github.com/sagernet/sing/common/metadata"
  23. N "github.com/sagernet/sing/common/network"
  24. "golang.org/x/crypto/ssh"
  25. )
  26. func RegisterOutbound(registry *outbound.Registry) {
  27. outbound.Register[option.SSHOutboundOptions](registry, C.TypeSSH, NewOutbound)
  28. }
  29. var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
  30. type Outbound struct {
  31. outbound.Adapter
  32. ctx context.Context
  33. logger logger.ContextLogger
  34. dialer N.Dialer
  35. serverAddr M.Socksaddr
  36. user string
  37. hostKey []ssh.PublicKey
  38. hostKeyAlgorithms []string
  39. clientVersion string
  40. authMethod []ssh.AuthMethod
  41. clientAccess sync.Mutex
  42. clientConn net.Conn
  43. client *ssh.Client
  44. }
  45. func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) {
  46. outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
  47. if err != nil {
  48. return nil, err
  49. }
  50. outbound := &Outbound{
  51. Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
  52. ctx: ctx,
  53. logger: logger,
  54. dialer: outboundDialer,
  55. serverAddr: options.ServerOptions.Build(),
  56. user: options.User,
  57. hostKeyAlgorithms: options.HostKeyAlgorithms,
  58. clientVersion: options.ClientVersion,
  59. }
  60. if outbound.serverAddr.Port == 0 {
  61. outbound.serverAddr.Port = 22
  62. }
  63. if outbound.user == "" {
  64. outbound.user = "root"
  65. }
  66. if outbound.clientVersion == "" {
  67. outbound.clientVersion = randomVersion()
  68. }
  69. if options.Password != "" {
  70. outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
  71. }
  72. if len(options.PrivateKey) > 0 || options.PrivateKeyPath != "" {
  73. var privateKey []byte
  74. if len(options.PrivateKey) > 0 {
  75. privateKey = []byte(strings.Join(options.PrivateKey, "\n"))
  76. } else {
  77. var err error
  78. privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
  79. if err != nil {
  80. return nil, E.Cause(err, "read private key")
  81. }
  82. }
  83. var signer ssh.Signer
  84. var err error
  85. if options.PrivateKeyPassphrase == "" {
  86. signer, err = ssh.ParsePrivateKey(privateKey)
  87. } else {
  88. signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
  89. }
  90. if err != nil {
  91. return nil, E.Cause(err, "parse private key")
  92. }
  93. outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
  94. }
  95. if len(options.HostKey) > 0 {
  96. for _, hostKey := range options.HostKey {
  97. key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
  98. if err != nil {
  99. return nil, E.New("parse host key ", key)
  100. }
  101. outbound.hostKey = append(outbound.hostKey, key)
  102. }
  103. }
  104. return outbound, nil
  105. }
  // randomVersion returns a plausible OpenSSH client identification string.
// It is either "SSH-2.0-OpenSSH_7.N" with N in [0,9] or "SSH-2.0-OpenSSH_8.N"
// with N in [0,8], chosen with equal probability from math/rand.
func randomVersion() string {
	major, minorCount := "7", 10
	if rand.Intn(2) != 0 {
		major, minorCount = "8", 9
	}
	minor := strconv.Itoa(rand.Intn(minorCount))
	return "SSH-2.0-OpenSSH_" + major + "." + minor
}
  115. func (s *Outbound) connect() (*ssh.Client, error) {
  116. if s.client != nil {
  117. return s.client, nil
  118. }
  119. s.clientAccess.Lock()
  120. defer s.clientAccess.Unlock()
  121. if s.client != nil {
  122. return s.client, nil
  123. }
  124. conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
  125. if err != nil {
  126. return nil, err
  127. }
  128. config := &ssh.ClientConfig{
  129. User: s.user,
  130. Auth: s.authMethod,
  131. ClientVersion: s.clientVersion,
  132. HostKeyAlgorithms: s.hostKeyAlgorithms,
  133. HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
  134. if len(s.hostKey) == 0 {
  135. return nil
  136. }
  137. serverKey := key.Marshal()
  138. for _, hostKey := range s.hostKey {
  139. if bytes.Equal(serverKey, hostKey.Marshal()) {
  140. return nil
  141. }
  142. }
  143. return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
  144. },
  145. }
  146. clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
  147. if err != nil {
  148. conn.Close()
  149. return nil, E.Cause(err, "connect to ssh server")
  150. }
  151. client := ssh.NewClient(clientConn, chans, reqs)
  152. s.clientConn = conn
  153. s.client = client
  154. go func() {
  155. client.Wait()
  156. conn.Close()
  157. s.clientAccess.Lock()
  158. s.client = nil
  159. s.clientConn = nil
  160. s.clientAccess.Unlock()
  161. }()
  162. return client, nil
  163. }
  164. func (s *Outbound) InterfaceUpdated() {
  165. common.Close(s.clientConn)
  166. }
  167. func (s *Outbound) Close() error {
  168. return common.Close(s.clientConn)
  169. }
  170. func (s *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  171. client, err := s.connect()
  172. if err != nil {
  173. return nil, err
  174. }
  175. conn, err := client.Dial(network, destination.String())
  176. if err != nil {
  177. return nil, err
  178. }
  179. return &chanConnWrapper{Conn: conn}, nil
  180. }
  181. func (s *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  182. return nil, os.ErrInvalid
  183. }
  // chanConnWrapper wraps a connection returned by the SSH client. Every
// deadline setter returns os.ErrInvalid instead of delegating to the embedded
// net.Conn. All other net.Conn methods are promoted from the embedded value.
type chanConnWrapper struct {
	net.Conn
}

// SetDeadline is unsupported; the argument is ignored.
func (c *chanConnWrapper) SetDeadline(time.Time) error { return os.ErrInvalid }

// SetReadDeadline is unsupported; the argument is ignored.
func (c *chanConnWrapper) SetReadDeadline(time.Time) error { return os.ErrInvalid }

// SetWriteDeadline is unsupported; the argument is ignored.
func (c *chanConnWrapper) SetWriteDeadline(time.Time) error { return os.ErrInvalid }