dialer.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package http
  2. import (
  3. "context"
  4. gotls "crypto/tls"
  5. "net/http"
  6. "net/url"
  7. "sync"
  8. "time"
  9. "github.com/xtls/xray-core/common"
  10. "github.com/xtls/xray-core/common/buf"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/net/cnc"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/transport/internet"
  15. "github.com/xtls/xray-core/transport/internet/stat"
  16. "github.com/xtls/xray-core/transport/internet/tls"
  17. "github.com/xtls/xray-core/transport/pipe"
  18. "golang.org/x/net/http2"
  19. )
  20. type dialerConf struct {
  21. net.Destination
  22. *internet.MemoryStreamConfig
  23. }
  24. var (
  25. globalDialerMap map[dialerConf]*http.Client
  26. globalDialerAccess sync.Mutex
  27. )
  28. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
  29. globalDialerAccess.Lock()
  30. defer globalDialerAccess.Unlock()
  31. if globalDialerMap == nil {
  32. globalDialerMap = make(map[dialerConf]*http.Client)
  33. }
  34. httpSettings := streamSettings.ProtocolSettings.(*Config)
  35. tlsConfigs := tls.ConfigFromStreamSettings(streamSettings)
  36. if tlsConfigs == nil {
  37. return nil, newError("TLS must be enabled for http transport.").AtWarning()
  38. }
  39. sockopt := streamSettings.SocketSettings
  40. if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
  41. return client, nil
  42. }
  43. transport := &http2.Transport{
  44. DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
  45. rawHost, rawPort, err := net.SplitHostPort(addr)
  46. if err != nil {
  47. return nil, err
  48. }
  49. if len(rawPort) == 0 {
  50. rawPort = "443"
  51. }
  52. port, err := net.PortFromString(rawPort)
  53. if err != nil {
  54. return nil, err
  55. }
  56. address := net.ParseAddress(rawHost)
  57. dctx := context.Background()
  58. dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
  59. dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
  60. pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
  61. if err != nil {
  62. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  63. return nil, err
  64. }
  65. var cn tls.Interface
  66. if fingerprint, ok := tls.Fingerprints[tlsConfigs.Fingerprint]; ok {
  67. cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
  68. } else {
  69. cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
  70. }
  71. if err := cn.Handshake(); err != nil {
  72. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  73. return nil, err
  74. }
  75. if !tlsConfig.InsecureSkipVerify {
  76. if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
  77. newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
  78. return nil, err
  79. }
  80. }
  81. negotiatedProtocol, negotiatedProtocolIsMutual := cn.NegotiatedProtocol()
  82. if negotiatedProtocol != http2.NextProtoTLS {
  83. return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
  84. }
  85. if !negotiatedProtocolIsMutual {
  86. return nil, newError("http2: could not negotiate protocol mutually").AtError()
  87. }
  88. return cn, nil
  89. },
  90. TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
  91. }
  92. if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
  93. transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
  94. transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
  95. }
  96. client := &http.Client{
  97. Transport: transport,
  98. }
  99. globalDialerMap[dialerConf{dest, streamSettings}] = client
  100. return client, nil
  101. }
  102. // Dial dials a new TCP connection to the given destination.
  103. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  104. httpSettings := streamSettings.ProtocolSettings.(*Config)
  105. client, err := getHTTPClient(ctx, dest, streamSettings)
  106. if err != nil {
  107. return nil, err
  108. }
  109. opts := pipe.OptionsFromContext(ctx)
  110. preader, pwriter := pipe.New(opts...)
  111. breader := &buf.BufferedReader{Reader: preader}
  112. httpMethod := "PUT"
  113. if httpSettings.Method != "" {
  114. httpMethod = httpSettings.Method
  115. }
  116. httpHeaders := make(http.Header)
  117. for _, httpHeader := range httpSettings.Header {
  118. for _, httpHeaderValue := range httpHeader.Value {
  119. httpHeaders.Set(httpHeader.Name, httpHeaderValue)
  120. }
  121. }
  122. request := &http.Request{
  123. Method: httpMethod,
  124. Host: httpSettings.getRandomHost(),
  125. Body: breader,
  126. URL: &url.URL{
  127. Scheme: "https",
  128. Host: dest.NetAddr(),
  129. Path: httpSettings.getNormalizedPath(),
  130. },
  131. Proto: "HTTP/2",
  132. ProtoMajor: 2,
  133. ProtoMinor: 0,
  134. Header: httpHeaders,
  135. }
  136. // Disable any compression method from server.
  137. request.Header.Set("Accept-Encoding", "identity")
  138. response, err := client.Do(request)
  139. if err != nil {
  140. return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
  141. }
  142. if response.StatusCode != 200 {
  143. return nil, newError("unexpected status", response.StatusCode).AtWarning()
  144. }
  145. bwriter := buf.NewBufferedWriter(pwriter)
  146. common.Must(bwriter.SetBuffered(false))
  147. return cnc.NewConnection(
  148. cnc.ConnectionOutput(response.Body),
  149. cnc.ConnectionInput(bwriter),
  150. cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
  151. ), nil
  152. }
  153. func init() {
  154. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  155. }