dialer.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package http
  2. import (
  3. "context"
  4. gotls "crypto/tls"
  5. "net/http"
  6. "net/url"
  7. "sync"
  8. "time"
  9. "github.com/xtls/xray-core/transport/internet/stat"
  10. "github.com/xtls/xray-core/common"
  11. "github.com/xtls/xray-core/common/buf"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/net/cnc"
  14. "github.com/xtls/xray-core/common/session"
  15. "github.com/xtls/xray-core/transport/internet"
  16. "github.com/xtls/xray-core/transport/internet/tls"
  17. "github.com/xtls/xray-core/transport/pipe"
  18. "golang.org/x/net/http2"
  19. )
  20. type dialerConf struct {
  21. net.Destination
  22. *internet.MemoryStreamConfig
  23. }
  24. var (
  25. globalDialerMap map[dialerConf]*http.Client
  26. globalDialerAccess sync.Mutex
  27. )
  28. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
  29. globalDialerAccess.Lock()
  30. defer globalDialerAccess.Unlock()
  31. if globalDialerMap == nil {
  32. globalDialerMap = make(map[dialerConf]*http.Client)
  33. }
  34. httpSettings := streamSettings.ProtocolSettings.(*Config)
  35. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  36. if tlsConfig == nil {
  37. return nil, newError("TLS must be enabled for http transport.").AtWarning()
  38. }
  39. sockopt := streamSettings.SocketSettings
  40. if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
  41. return client, nil
  42. }
  43. transport := &http2.Transport{
  44. DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
  45. rawHost, rawPort, err := net.SplitHostPort(addr)
  46. if err != nil {
  47. return nil, err
  48. }
  49. if len(rawPort) == 0 {
  50. rawPort = "443"
  51. }
  52. port, err := net.PortFromString(rawPort)
  53. if err != nil {
  54. return nil, err
  55. }
  56. address := net.ParseAddress(rawHost)
  57. dctx := context.Background()
  58. dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
  59. dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
  60. pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
  61. if err != nil {
  62. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  63. return nil, err
  64. }
  65. cn := gotls.Client(pconn, tlsConfig)
  66. if err := cn.Handshake(); err != nil {
  67. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  68. return nil, err
  69. }
  70. if !tlsConfig.InsecureSkipVerify {
  71. if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
  72. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  73. return nil, err
  74. }
  75. }
  76. state := cn.ConnectionState()
  77. if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
  78. return nil, newError("http2: unexpected ALPN protocol " + p + "; want q" + http2.NextProtoTLS).AtError()
  79. }
  80. if !state.NegotiatedProtocolIsMutual {
  81. return nil, newError("http2: could not negotiate protocol mutually").AtError()
  82. }
  83. return cn, nil
  84. },
  85. TLSClientConfig: tlsConfig.GetTLSConfig(tls.WithDestination(dest)),
  86. }
  87. if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
  88. transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
  89. transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
  90. }
  91. client := &http.Client{
  92. Transport: transport,
  93. }
  94. globalDialerMap[dialerConf{dest, streamSettings}] = client
  95. return client, nil
  96. }
  97. // Dial dials a new TCP connection to the given destination.
  98. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  99. httpSettings := streamSettings.ProtocolSettings.(*Config)
  100. client, err := getHTTPClient(ctx, dest, streamSettings)
  101. if err != nil {
  102. return nil, err
  103. }
  104. opts := pipe.OptionsFromContext(ctx)
  105. preader, pwriter := pipe.New(opts...)
  106. breader := &buf.BufferedReader{Reader: preader}
  107. request := &http.Request{
  108. Method: "PUT",
  109. Host: httpSettings.getRandomHost(),
  110. Body: breader,
  111. URL: &url.URL{
  112. Scheme: "https",
  113. Host: dest.NetAddr(),
  114. Path: httpSettings.getNormalizedPath(),
  115. },
  116. Proto: "HTTP/2",
  117. ProtoMajor: 2,
  118. ProtoMinor: 0,
  119. Header: make(http.Header),
  120. }
  121. // Disable any compression method from server.
  122. request.Header.Set("Accept-Encoding", "identity")
  123. response, err := client.Do(request)
  124. if err != nil {
  125. return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
  126. }
  127. if response.StatusCode != 200 {
  128. return nil, newError("unexpected status", response.StatusCode).AtWarning()
  129. }
  130. bwriter := buf.NewBufferedWriter(pwriter)
  131. common.Must(bwriter.SetBuffered(false))
  132. return cnc.NewConnection(
  133. cnc.ConnectionOutput(response.Body),
  134. cnc.ConnectionInput(bwriter),
  135. cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
  136. ), nil
  137. }
  138. func init() {
  139. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  140. }