dialer.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package splithttp
  2. import (
  3. "context"
  4. gotls "crypto/tls"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "sync"
  10. "time"
  11. "github.com/quic-go/quic-go"
  12. "github.com/quic-go/quic-go/http3"
  13. "github.com/xtls/xray-core/common"
  14. "github.com/xtls/xray-core/common/buf"
  15. "github.com/xtls/xray-core/common/errors"
  16. "github.com/xtls/xray-core/common/net"
  17. "github.com/xtls/xray-core/common/signal/semaphore"
  18. "github.com/xtls/xray-core/common/uuid"
  19. "github.com/xtls/xray-core/transport/internet"
  20. "github.com/xtls/xray-core/transport/internet/browser_dialer"
  21. "github.com/xtls/xray-core/transport/internet/stat"
  22. "github.com/xtls/xray-core/transport/internet/tls"
  23. "github.com/xtls/xray-core/transport/pipe"
  24. "golang.org/x/net/http2"
  25. )
  26. type dialerConf struct {
  27. net.Destination
  28. *internet.MemoryStreamConfig
  29. }
  30. var (
  31. globalDialerMap map[dialerConf]DialerClient
  32. globalDialerAccess sync.Mutex
  33. )
  34. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
  35. if browser_dialer.HasBrowserDialer() {
  36. return &BrowserDialerClient{}
  37. }
  38. globalDialerAccess.Lock()
  39. defer globalDialerAccess.Unlock()
  40. if globalDialerMap == nil {
  41. globalDialerMap = make(map[dialerConf]DialerClient)
  42. }
  43. if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
  44. return client
  45. }
  46. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  47. isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
  48. isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3")
  49. var gotlsConfig *gotls.Config
  50. if tlsConfig != nil {
  51. gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
  52. }
  53. dialContext := func(ctxInner context.Context) (net.Conn, error) {
  54. conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
  55. if err != nil {
  56. return nil, err
  57. }
  58. if gotlsConfig != nil {
  59. if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
  60. conn = tls.UClient(conn, gotlsConfig, fingerprint)
  61. if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
  62. return nil, err
  63. }
  64. } else {
  65. conn = tls.Client(conn, gotlsConfig)
  66. }
  67. }
  68. return conn, nil
  69. }
  70. var downloadTransport http.RoundTripper
  71. var uploadTransport http.RoundTripper
  72. if isH3 {
  73. dest.Network = net.Network_UDP
  74. quicConfig := &quic.Config{
  75. HandshakeIdleTimeout: 10 * time.Second,
  76. MaxIdleTimeout: 90 * time.Second,
  77. KeepAlivePeriod: 3 * time.Second,
  78. Allow0RTT: true,
  79. }
  80. roundTripper := &http3.RoundTripper{
  81. TLSClientConfig: gotlsConfig,
  82. QUICConfig: quicConfig,
  83. Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
  84. conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  85. if err != nil {
  86. return nil, err
  87. }
  88. udpAddr, err := net.ResolveUDPAddr("udp", conn.RemoteAddr().String())
  89. if err != nil {
  90. return nil, err
  91. }
  92. return quic.DialEarly(ctx, conn.(*internet.PacketConnWrapper).Conn.(*net.UDPConn), udpAddr, tlsCfg, cfg)
  93. },
  94. }
  95. downloadTransport = roundTripper
  96. uploadTransport = roundTripper
  97. } else if isH2 {
  98. downloadTransport = &http2.Transport{
  99. DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
  100. return dialContext(ctxInner)
  101. },
  102. IdleConnTimeout: 90 * time.Second,
  103. }
  104. uploadTransport = downloadTransport
  105. } else {
  106. httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
  107. return dialContext(ctxInner)
  108. }
  109. downloadTransport = &http.Transport{
  110. DialTLSContext: httpDialContext,
  111. DialContext: httpDialContext,
  112. IdleConnTimeout: 90 * time.Second,
  113. // chunked transfer download with keepalives is buggy with
  114. // http.Client and our custom dial context.
  115. DisableKeepAlives: true,
  116. }
  117. // we use uploadRawPool for that
  118. uploadTransport = nil
  119. }
  120. client := &DefaultDialerClient{
  121. transportConfig: streamSettings.ProtocolSettings.(*Config),
  122. download: &http.Client{
  123. Transport: downloadTransport,
  124. },
  125. upload: &http.Client{
  126. Transport: uploadTransport,
  127. },
  128. isH2: isH2,
  129. isH3: isH3,
  130. uploadRawPool: &sync.Pool{},
  131. dialUploadConn: dialContext,
  132. }
  133. globalDialerMap[dialerConf{dest, streamSettings}] = client
  134. return client
  135. }
  136. func init() {
  137. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  138. }
  139. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  140. errors.LogInfo(ctx, "dialing splithttp to ", dest)
  141. var requestURL url.URL
  142. transportConfiguration := streamSettings.ProtocolSettings.(*Config)
  143. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  144. maxConcurrentUploads := transportConfiguration.GetNormalizedMaxConcurrentUploads()
  145. maxUploadSize := transportConfiguration.GetNormalizedMaxUploadSize()
  146. if tlsConfig != nil {
  147. requestURL.Scheme = "https"
  148. } else {
  149. requestURL.Scheme = "http"
  150. }
  151. requestURL.Host = transportConfiguration.Host
  152. if requestURL.Host == "" {
  153. requestURL.Host = dest.NetAddr()
  154. }
  155. requestURL.Path = transportConfiguration.GetNormalizedPath()
  156. httpClient := getHTTPClient(ctx, dest, streamSettings)
  157. sessionIdUuid := uuid.New()
  158. sessionId := sessionIdUuid.String()
  159. baseURL := requestURL.String() + sessionId
  160. uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize))
  161. go func() {
  162. requestsLimiter := semaphore.New(int(maxConcurrentUploads))
  163. var requestCounter int64
  164. // by offloading the uploads into a buffered pipe, multiple conn.Write
  165. // calls get automatically batched together into larger POST requests.
  166. // without batching, bandwidth is extremely limited.
  167. for {
  168. chunk, err := uploadPipeReader.ReadMultiBuffer()
  169. if err != nil {
  170. break
  171. }
  172. <-requestsLimiter.Wait()
  173. seq := requestCounter
  174. requestCounter += 1
  175. go func() {
  176. defer requestsLimiter.Signal()
  177. err := httpClient.SendUploadRequest(
  178. context.WithoutCancel(ctx),
  179. baseURL+"/"+strconv.FormatInt(seq, 10),
  180. &buf.MultiBufferContainer{MultiBuffer: chunk},
  181. int64(chunk.Len()),
  182. )
  183. if err != nil {
  184. errors.LogInfoInner(ctx, err, "failed to send upload")
  185. uploadPipeReader.Interrupt()
  186. }
  187. }()
  188. }
  189. }()
  190. lazyRawDownload, remoteAddr, localAddr, err := httpClient.OpenDownload(context.WithoutCancel(ctx), baseURL)
  191. if err != nil {
  192. return nil, err
  193. }
  194. lazyDownload := &LazyReader{
  195. CreateReader: func() (io.ReadCloser, error) {
  196. // skip "ooooooooook" response
  197. trashHeader := []byte{0}
  198. for {
  199. _, err := io.ReadFull(lazyRawDownload, trashHeader)
  200. if err != nil {
  201. return nil, errors.New("failed to read initial response").Base(err)
  202. }
  203. if trashHeader[0] == 'k' {
  204. break
  205. }
  206. }
  207. return lazyRawDownload, nil
  208. },
  209. }
  210. // necessary in order to send larger chunks in upload
  211. bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter)
  212. bufferedUploadPipeWriter.SetBuffered(false)
  213. conn := splitConn{
  214. writer: bufferedUploadPipeWriter,
  215. reader: lazyDownload,
  216. remoteAddr: remoteAddr,
  217. localAddr: localAddr,
  218. }
  219. return stat.Connection(&conn), nil
  220. }