dialer.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. package splithttp
import (
	"context"
	gotls "crypto/tls"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/buf"
	"github.com/xtls/xray-core/common/errors"
	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/common/signal/semaphore"
	"github.com/xtls/xray-core/common/uuid"
	"github.com/xtls/xray-core/transport/internet"
	"github.com/xtls/xray-core/transport/internet/browser_dialer"
	"github.com/xtls/xray-core/transport/internet/reality"
	"github.com/xtls/xray-core/transport/internet/stat"
	"github.com/xtls/xray-core/transport/internet/tls"
	"github.com/xtls/xray-core/transport/pipe"
	"golang.org/x/net/http2"
)
  26. // defines the maximum time an idle TCP session can survive in the tunnel, so
  27. // it should be consistent across HTTP versions and with other transports.
  28. const connIdleTimeout = 300 * time.Second
  29. // consistent with quic-go
  30. const h3KeepalivePeriod = 10 * time.Second
  31. // consistent with chrome
  32. const h2KeepalivePeriod = 45 * time.Second
// dialerConf is the key of globalDialerMap: one cached muxManager exists per
// distinct (destination, stream settings pointer) pair. Both fields are
// embedded, so the struct is comparable by destination value and by the
// identity of the *internet.MemoryStreamConfig pointer.
type dialerConf struct {
	net.Destination
	*internet.MemoryStreamConfig
}
  37. var (
  38. globalDialerMap map[dialerConf]*muxManager
  39. globalDialerAccess sync.Mutex
  40. )
  41. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *muxResource) {
  42. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  43. if browser_dialer.HasBrowserDialer() && realityConfig != nil {
  44. return &BrowserDialerClient{}, nil
  45. }
  46. globalDialerAccess.Lock()
  47. defer globalDialerAccess.Unlock()
  48. if globalDialerMap == nil {
  49. globalDialerMap = make(map[dialerConf]*muxManager)
  50. }
  51. key := dialerConf{dest, streamSettings}
  52. muxManager, found := globalDialerMap[key]
  53. if !found {
  54. transportConfig := streamSettings.ProtocolSettings.(*Config)
  55. var mux Multiplexing
  56. if transportConfig.Xmux != nil {
  57. mux = *transportConfig.Xmux
  58. }
  59. muxManager = NewMuxManager(mux, func() interface{} {
  60. return createHTTPClient(dest, streamSettings)
  61. })
  62. globalDialerMap[key] = muxManager
  63. }
  64. res := muxManager.GetResource(ctx)
  65. return res.Resource.(DialerClient), res
  66. }
  67. func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
  68. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  69. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  70. isH2 := false
  71. isH3 := false
  72. if tlsConfig != nil {
  73. isH2 = !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
  74. isH3 = len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3"
  75. } else if realityConfig != nil {
  76. isH2 = true
  77. isH3 = false
  78. }
  79. if isH3 {
  80. dest.Network = net.Network_UDP
  81. }
  82. var gotlsConfig *gotls.Config
  83. if tlsConfig != nil {
  84. gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
  85. }
  86. transportConfig := streamSettings.ProtocolSettings.(*Config)
  87. dialContext := func(ctxInner context.Context) (net.Conn, error) {
  88. conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
  89. if err != nil {
  90. return nil, err
  91. }
  92. if realityConfig != nil {
  93. return reality.UClient(conn, realityConfig, ctxInner, dest)
  94. }
  95. if gotlsConfig != nil {
  96. if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
  97. conn = tls.UClient(conn, gotlsConfig, fingerprint)
  98. if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
  99. return nil, err
  100. }
  101. } else {
  102. conn = tls.Client(conn, gotlsConfig)
  103. }
  104. }
  105. return conn, nil
  106. }
  107. var transport http.RoundTripper
  108. if isH3 {
  109. quicConfig := &quic.Config{
  110. MaxIdleTimeout: connIdleTimeout,
  111. // these two are defaults of quic-go/http3. the default of quic-go (no
  112. // http3) is different, so it is hardcoded here for clarity.
  113. // https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
  114. MaxIncomingStreams: -1,
  115. KeepAlivePeriod: h3KeepalivePeriod,
  116. }
  117. transport = &http3.RoundTripper{
  118. QUICConfig: quicConfig,
  119. TLSClientConfig: gotlsConfig,
  120. Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
  121. conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  122. if err != nil {
  123. return nil, err
  124. }
  125. var udpConn net.PacketConn
  126. var udpAddr *net.UDPAddr
  127. switch c := conn.(type) {
  128. case *internet.PacketConnWrapper:
  129. var ok bool
  130. udpConn, ok = c.Conn.(*net.UDPConn)
  131. if !ok {
  132. return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
  133. }
  134. udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
  135. if err != nil {
  136. return nil, err
  137. }
  138. case *net.UDPConn:
  139. udpConn = c
  140. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  141. if err != nil {
  142. return nil, err
  143. }
  144. default:
  145. udpConn = &internet.FakePacketConn{c}
  146. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  147. if err != nil {
  148. return nil, err
  149. }
  150. }
  151. return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
  152. },
  153. }
  154. } else if isH2 {
  155. transport = &http2.Transport{
  156. DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
  157. return dialContext(ctxInner)
  158. },
  159. IdleConnTimeout: connIdleTimeout,
  160. ReadIdleTimeout: h2KeepalivePeriod,
  161. }
  162. } else {
  163. httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
  164. return dialContext(ctxInner)
  165. }
  166. transport = &http.Transport{
  167. DialTLSContext: httpDialContext,
  168. DialContext: httpDialContext,
  169. IdleConnTimeout: connIdleTimeout,
  170. // chunked transfer download with keepalives is buggy with
  171. // http.Client and our custom dial context.
  172. DisableKeepAlives: true,
  173. }
  174. }
  175. client := &DefaultDialerClient{
  176. transportConfig: transportConfig,
  177. client: &http.Client{
  178. Transport: transport,
  179. },
  180. isH2: isH2,
  181. isH3: isH3,
  182. uploadRawPool: &sync.Pool{},
  183. dialUploadConn: dialContext,
  184. }
  185. return client
  186. }
  187. func init() {
  188. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  189. }
  190. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  191. errors.LogInfo(ctx, "dialing splithttp to ", dest)
  192. var requestURL url.URL
  193. transportConfiguration := streamSettings.ProtocolSettings.(*Config)
  194. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  195. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  196. scMaxConcurrentPosts := transportConfiguration.GetNormalizedScMaxConcurrentPosts()
  197. scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
  198. scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
  199. if tlsConfig != nil || realityConfig != nil {
  200. requestURL.Scheme = "https"
  201. } else {
  202. requestURL.Scheme = "http"
  203. }
  204. requestURL.Host = transportConfiguration.Host
  205. if requestURL.Host == "" {
  206. requestURL.Host = dest.NetAddr()
  207. }
  208. sessionIdUuid := uuid.New()
  209. requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String()
  210. requestURL.RawQuery = transportConfiguration.GetNormalizedQuery()
  211. httpClient, muxRes := getHTTPClient(ctx, dest, streamSettings)
  212. httpClient2 := httpClient
  213. requestURL2 := requestURL
  214. var muxRes2 *muxResource
  215. if transportConfiguration.DownloadSettings != nil {
  216. globalDialerAccess.Lock()
  217. if streamSettings.DownloadSettings == nil {
  218. streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig)
  219. }
  220. globalDialerAccess.Unlock()
  221. memory2 := streamSettings.DownloadSettings
  222. httpClient2, muxRes2 = getHTTPClient(ctx, *memory2.Destination, memory2) // just panic
  223. if tls.ConfigFromStreamSettings(memory2) != nil || reality.ConfigFromStreamSettings(memory2) != nil {
  224. requestURL2.Scheme = "https"
  225. } else {
  226. requestURL2.Scheme = "http"
  227. }
  228. config2 := memory2.ProtocolSettings.(*Config)
  229. requestURL2.Host = config2.Host
  230. if requestURL2.Host == "" {
  231. requestURL2.Host = memory2.Destination.NetAddr()
  232. }
  233. requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
  234. requestURL2.RawQuery = config2.GetNormalizedQuery()
  235. }
  236. reader, remoteAddr, localAddr, err := httpClient2.OpenDownload(context.WithoutCancel(ctx), requestURL2.String())
  237. if err != nil {
  238. return nil, err
  239. }
  240. if muxRes != nil {
  241. muxRes.OpenRequests.Add(1)
  242. }
  243. if muxRes2 != nil {
  244. muxRes2.OpenRequests.Add(1)
  245. }
  246. closed := false
  247. conn := splitConn{
  248. writer: nil,
  249. reader: reader,
  250. remoteAddr: remoteAddr,
  251. localAddr: localAddr,
  252. onClose: func() {
  253. if closed {
  254. return
  255. }
  256. closed = true
  257. if muxRes != nil {
  258. muxRes.OpenRequests.Add(-1)
  259. }
  260. if muxRes2 != nil {
  261. muxRes2.OpenRequests.Add(-1)
  262. }
  263. },
  264. }
  265. mode := transportConfiguration.Mode
  266. if mode == "auto" {
  267. mode = "packet-up"
  268. if (tlsConfig != nil && len(tlsConfig.NextProtocol) != 1) || realityConfig != nil {
  269. mode = "stream-up"
  270. }
  271. }
  272. errors.LogInfo(ctx, "XHTTP is using mode: "+mode)
  273. if mode == "stream-up" {
  274. conn.writer = httpClient.OpenUpload(ctx, requestURL.String())
  275. return stat.Connection(&conn), nil
  276. }
  277. maxUploadSize := scMaxEachPostBytes.roll()
  278. // WithSizeLimit(0) will still allow single bytes to pass, and a lot of
  279. // code relies on this behavior. Subtract 1 so that together with
  280. // uploadWriter wrapper, exact size limits can be enforced
  281. // uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - 1))
  282. uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - buf.Size))
  283. conn.writer = uploadWriter{
  284. uploadPipeWriter,
  285. maxUploadSize,
  286. }
  287. go func() {
  288. requestsLimiter := semaphore.New(int(scMaxConcurrentPosts.roll()))
  289. var requestCounter int64
  290. lastWrite := time.Now()
  291. // by offloading the uploads into a buffered pipe, multiple conn.Write
  292. // calls get automatically batched together into larger POST requests.
  293. // without batching, bandwidth is extremely limited.
  294. for {
  295. chunk, err := uploadPipeReader.ReadMultiBuffer()
  296. if err != nil {
  297. break
  298. }
  299. <-requestsLimiter.Wait()
  300. seq := requestCounter
  301. requestCounter += 1
  302. go func() {
  303. defer requestsLimiter.Signal()
  304. // this intentionally makes a shallow-copy of the struct so we
  305. // can reassign Path (potentially concurrently)
  306. url := requestURL
  307. url.Path += "/" + strconv.FormatInt(seq, 10)
  308. // reassign query to get different padding
  309. url.RawQuery = transportConfiguration.GetNormalizedQuery()
  310. err := httpClient.SendUploadRequest(
  311. context.WithoutCancel(ctx),
  312. url.String(),
  313. &buf.MultiBufferContainer{MultiBuffer: chunk},
  314. int64(chunk.Len()),
  315. )
  316. if err != nil {
  317. errors.LogInfoInner(ctx, err, "failed to send upload")
  318. uploadPipeReader.Interrupt()
  319. }
  320. }()
  321. if scMinPostsIntervalMs.From > 0 {
  322. roll := time.Duration(scMinPostsIntervalMs.roll()) * time.Millisecond
  323. if time.Since(lastWrite) < roll {
  324. time.Sleep(roll)
  325. }
  326. lastWrite = time.Now()
  327. }
  328. }
  329. }()
  330. return stat.Connection(&conn), nil
  331. }
// uploadWriter is a wrapper around pipe.Writer intended to ensure the upload
// size limit is honored.
//
// The MultiBuffer pipe accepts any single WriteMultiBuffer call even if that
// single MultiBuffer exceeds the size limit, and then starts blocking on the
// next WriteMultiBuffer call. This means that ReadMultiBuffer can return more
// bytes than the size limit. We work around this by splitting a potentially
// too large write up into multiple smaller ones, so that no single
// WriteMultiBuffer call carries more than one buffer.
type uploadWriter struct {
	// Writer is the write end of the upload pipe.
	*pipe.Writer
	// maxLen is the maximum size of one upload as rolled in Dial.
	// NOTE(review): the visible Write implementation does not consult it.
	maxLen int32
}
  343. func (w uploadWriter) Write(b []byte) (int, error) {
  344. /*
  345. capacity := int(w.maxLen - w.Len())
  346. if capacity > 0 && capacity < len(b) {
  347. b = b[:capacity]
  348. }
  349. */
  350. buffer := buf.New()
  351. n, err := buffer.Write(b)
  352. if err != nil {
  353. return 0, err
  354. }
  355. err = w.WriteMultiBuffer([]*buf.Buffer{buffer})
  356. if err != nil {
  357. return 0, err
  358. }
  359. return n, nil
  360. }