dialer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. package splithttp
  2. import (
  3. "context"
  4. gotls "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptrace"
  9. "net/url"
  10. "strconv"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "github.com/quic-go/quic-go"
  15. "github.com/quic-go/quic-go/http3"
  16. "github.com/xtls/xray-core/common"
  17. "github.com/xtls/xray-core/common/buf"
  18. "github.com/xtls/xray-core/common/errors"
  19. "github.com/xtls/xray-core/common/net"
  20. "github.com/xtls/xray-core/common/signal/done"
  21. "github.com/xtls/xray-core/common/uuid"
  22. "github.com/xtls/xray-core/transport/internet"
  23. "github.com/xtls/xray-core/transport/internet/browser_dialer"
  24. "github.com/xtls/xray-core/transport/internet/reality"
  25. "github.com/xtls/xray-core/transport/internet/stat"
  26. "github.com/xtls/xray-core/transport/internet/tls"
  27. "github.com/xtls/xray-core/transport/pipe"
  28. "golang.org/x/net/http2"
  29. )
  // Timing constants used by the client transports that createHTTPClient builds.
const (
	// connIdleTimeout defines the maximum time an idle TCP session can
	// survive in the tunnel, so it should be consistent across HTTP versions
	// and with other transports.
	connIdleTimeout = 300 * time.Second

	// quicgoH3KeepAlivePeriod is the default HTTP/3 keep-alive period,
	// chosen to be consistent with quic-go.
	quicgoH3KeepAlivePeriod = 10 * time.Second

	// chromeH2KeepAlivePeriod is the default HTTP/2 keep-alive (read idle)
	// period, chosen to be consistent with Chrome.
	chromeH2KeepAlivePeriod = 45 * time.Second
)
  // dialerConf is the key type of globalDialerMap: getHTTPClient keeps one
  // XmuxManager per (destination, stream settings) pair. The stream settings
  // are embedded as a pointer, so keys compare that pointer, not the
  // settings' contents. The field order matters because getHTTPClient builds
  // keys with an unkeyed composite literal.
  37. type dialerConf struct {
  38. net.Destination
  39. *internet.MemoryStreamConfig
  40. }
  41. var (
  42. globalDialerMap map[dialerConf]*XmuxManager
  43. globalDialerAccess sync.Mutex
  44. )
  45. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
  46. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  47. if browser_dialer.HasBrowserDialer() && realityConfig != nil {
  48. return &BrowserDialerClient{}, nil
  49. }
  50. globalDialerAccess.Lock()
  51. defer globalDialerAccess.Unlock()
  52. if globalDialerMap == nil {
  53. globalDialerMap = make(map[dialerConf]*XmuxManager)
  54. }
  55. key := dialerConf{dest, streamSettings}
  56. xmuxManager, found := globalDialerMap[key]
  57. if !found {
  58. transportConfig := streamSettings.ProtocolSettings.(*Config)
  59. var xmuxConfig XmuxConfig
  60. if transportConfig.Xmux != nil {
  61. xmuxConfig = *transportConfig.Xmux
  62. }
  63. xmuxManager = NewXmuxManager(xmuxConfig, func() XmuxConn {
  64. return createHTTPClient(dest, streamSettings)
  65. })
  66. globalDialerMap[key] = xmuxManager
  67. }
  68. xmuxClient := xmuxManager.GetXmuxClient(ctx)
  69. return xmuxClient.XmuxConn.(DialerClient), xmuxClient
  70. }
  71. func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) string {
  72. if realityConfig != nil {
  73. return "2"
  74. }
  75. if tlsConfig == nil {
  76. return "1.1"
  77. }
  78. if len(tlsConfig.NextProtocol) != 1 {
  79. return "2"
  80. }
  81. if tlsConfig.NextProtocol[0] == "http/1.1" {
  82. return "1.1"
  83. }
  84. if tlsConfig.NextProtocol[0] == "h3" {
  85. return "3"
  86. }
  87. return "2"
  88. }
  89. func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
  90. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  91. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  92. httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
  93. if httpVersion == "3" {
  94. dest.Network = net.Network_UDP // better to keep this line
  95. }
  96. var gotlsConfig *gotls.Config
  97. if tlsConfig != nil {
  98. gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
  99. }
  100. transportConfig := streamSettings.ProtocolSettings.(*Config)
  101. dialContext := func(ctxInner context.Context) (net.Conn, error) {
  102. conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
  103. if err != nil {
  104. return nil, err
  105. }
  106. if realityConfig != nil {
  107. return reality.UClient(conn, realityConfig, ctxInner, dest)
  108. }
  109. if gotlsConfig != nil {
  110. if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
  111. conn = tls.UClient(conn, gotlsConfig, fingerprint)
  112. if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
  113. return nil, err
  114. }
  115. } else {
  116. conn = tls.Client(conn, gotlsConfig)
  117. }
  118. }
  119. return conn, nil
  120. }
  121. var keepAlivePeriod time.Duration
  122. if streamSettings.ProtocolSettings.(*Config).Xmux != nil {
  123. keepAlivePeriod = time.Duration(streamSettings.ProtocolSettings.(*Config).Xmux.HKeepAlivePeriod) * time.Second
  124. }
  125. var transport http.RoundTripper
  126. if httpVersion == "3" {
  127. if keepAlivePeriod == 0 {
  128. keepAlivePeriod = quicgoH3KeepAlivePeriod
  129. }
  130. if keepAlivePeriod < 0 {
  131. keepAlivePeriod = 0
  132. }
  133. quicConfig := &quic.Config{
  134. MaxIdleTimeout: connIdleTimeout,
  135. // these two are defaults of quic-go/http3. the default of quic-go (no
  136. // http3) is different, so it is hardcoded here for clarity.
  137. // https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
  138. MaxIncomingStreams: -1,
  139. KeepAlivePeriod: keepAlivePeriod,
  140. }
  141. transport = &http3.RoundTripper{
  142. QUICConfig: quicConfig,
  143. TLSClientConfig: gotlsConfig,
  144. Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
  145. conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  146. if err != nil {
  147. return nil, err
  148. }
  149. var udpConn net.PacketConn
  150. var udpAddr *net.UDPAddr
  151. switch c := conn.(type) {
  152. case *internet.PacketConnWrapper:
  153. var ok bool
  154. udpConn, ok = c.Conn.(*net.UDPConn)
  155. if !ok {
  156. return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
  157. }
  158. udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
  159. if err != nil {
  160. return nil, err
  161. }
  162. case *net.UDPConn:
  163. udpConn = c
  164. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  165. if err != nil {
  166. return nil, err
  167. }
  168. default:
  169. udpConn = &internet.FakePacketConn{c}
  170. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  171. if err != nil {
  172. return nil, err
  173. }
  174. }
  175. return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
  176. },
  177. }
  178. } else if httpVersion == "2" {
  179. if keepAlivePeriod == 0 {
  180. keepAlivePeriod = chromeH2KeepAlivePeriod
  181. }
  182. if keepAlivePeriod < 0 {
  183. keepAlivePeriod = 0
  184. }
  185. transport = &http2.Transport{
  186. DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
  187. return dialContext(ctxInner)
  188. },
  189. IdleConnTimeout: connIdleTimeout,
  190. ReadIdleTimeout: keepAlivePeriod,
  191. }
  192. } else {
  193. httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
  194. return dialContext(ctxInner)
  195. }
  196. transport = &http.Transport{
  197. DialTLSContext: httpDialContext,
  198. DialContext: httpDialContext,
  199. IdleConnTimeout: connIdleTimeout,
  200. // chunked transfer download with KeepAlives is buggy with
  201. // http.Client and our custom dial context.
  202. DisableKeepAlives: true,
  203. }
  204. }
  205. client := &DefaultDialerClient{
  206. transportConfig: transportConfig,
  207. client: &http.Client{
  208. Transport: transport,
  209. },
  210. httpVersion: httpVersion,
  211. uploadRawPool: &sync.Pool{},
  212. dialUploadConn: dialContext,
  213. }
  214. return client
  215. }
  216. func init() {
  217. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  218. }
  219. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  220. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  221. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  222. httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
  223. if httpVersion == "3" {
  224. dest.Network = net.Network_UDP
  225. }
  226. transportConfiguration := streamSettings.ProtocolSettings.(*Config)
  227. var requestURL url.URL
  228. if tlsConfig != nil || realityConfig != nil {
  229. requestURL.Scheme = "https"
  230. } else {
  231. requestURL.Scheme = "http"
  232. }
  233. requestURL.Host = transportConfiguration.Host
  234. if requestURL.Host == "" && tlsConfig != nil {
  235. requestURL.Host = tlsConfig.ServerName
  236. }
  237. if requestURL.Host == "" && realityConfig != nil {
  238. requestURL.Host = realityConfig.ServerName
  239. }
  240. if requestURL.Host == "" {
  241. requestURL.Host = dest.Address.String()
  242. }
  243. sessionIdUuid := uuid.New()
  244. requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String()
  245. requestURL.RawQuery = transportConfiguration.GetNormalizedQuery()
  246. httpClient, xmuxClient := getHTTPClient(ctx, dest, streamSettings)
  247. mode := transportConfiguration.Mode
  248. if mode == "" || mode == "auto" {
  249. mode = "packet-up"
  250. if httpVersion == "2" {
  251. mode = "stream-up"
  252. }
  253. if realityConfig != nil && transportConfiguration.DownloadSettings == nil {
  254. mode = "stream-one"
  255. }
  256. }
  257. errors.LogInfo(ctx, fmt.Sprintf("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", dest, mode, httpVersion, requestURL.Host))
  258. requestURL2 := requestURL
  259. httpClient2 := httpClient
  260. xmuxClient2 := xmuxClient
  261. if transportConfiguration.DownloadSettings != nil {
  262. globalDialerAccess.Lock()
  263. if streamSettings.DownloadSettings == nil {
  264. streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig)
  265. if streamSettings.DownloadSettings.SocketSettings == nil {
  266. streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings
  267. }
  268. }
  269. globalDialerAccess.Unlock()
  270. memory2 := streamSettings.DownloadSettings
  271. dest2 := *memory2.Destination // just panic
  272. tlsConfig2 := tls.ConfigFromStreamSettings(memory2)
  273. realityConfig2 := reality.ConfigFromStreamSettings(memory2)
  274. httpVersion2 := decideHTTPVersion(tlsConfig2, realityConfig2)
  275. if httpVersion2 == "3" {
  276. dest2.Network = net.Network_UDP
  277. }
  278. if tlsConfig2 != nil || realityConfig2 != nil {
  279. requestURL2.Scheme = "https"
  280. } else {
  281. requestURL2.Scheme = "http"
  282. }
  283. config2 := memory2.ProtocolSettings.(*Config)
  284. requestURL2.Host = config2.Host
  285. if requestURL2.Host == "" && tlsConfig2 != nil {
  286. requestURL2.Host = tlsConfig2.ServerName
  287. }
  288. if requestURL2.Host == "" && realityConfig2 != nil {
  289. requestURL2.Host = realityConfig2.ServerName
  290. }
  291. if requestURL2.Host == "" {
  292. requestURL2.Host = dest2.Address.String()
  293. }
  294. requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
  295. requestURL2.RawQuery = config2.GetNormalizedQuery()
  296. httpClient2, xmuxClient2 = getHTTPClient(ctx, dest2, memory2)
  297. errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host))
  298. }
  299. var writer io.WriteCloser
  300. var reader io.ReadCloser
  301. var remoteAddr, localAddr net.Addr
  302. var err error
  303. if mode == "stream-one" {
  304. requestURL.Path = transportConfiguration.GetNormalizedPath()
  305. if xmuxClient != nil {
  306. xmuxClient.LeftRequests.Add(-1)
  307. }
  308. writer, reader = httpClient.Open(context.WithoutCancel(ctx), requestURL.String())
  309. remoteAddr = &net.TCPAddr{}
  310. localAddr = &net.TCPAddr{}
  311. } else {
  312. if xmuxClient2 != nil {
  313. xmuxClient2.LeftRequests.Add(-1)
  314. }
  315. reader, remoteAddr, localAddr, err = httpClient2.OpenDownload(context.WithoutCancel(ctx), requestURL2.String())
  316. if err != nil {
  317. return nil, err
  318. }
  319. }
  320. if xmuxClient != nil {
  321. xmuxClient.OpenUsage.Add(1)
  322. }
  323. if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
  324. xmuxClient2.OpenUsage.Add(1)
  325. }
  326. var closed atomic.Int32
  327. conn := splitConn{
  328. writer: writer,
  329. reader: reader,
  330. remoteAddr: remoteAddr,
  331. localAddr: localAddr,
  332. onClose: func() {
  333. if closed.Add(1) > 1 {
  334. return
  335. }
  336. if xmuxClient != nil {
  337. xmuxClient.OpenUsage.Add(-1)
  338. }
  339. if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
  340. xmuxClient2.OpenUsage.Add(-1)
  341. }
  342. },
  343. }
  344. if mode == "stream-one" {
  345. if xmuxClient != nil {
  346. xmuxClient.LeftRequests.Add(-1)
  347. }
  348. return stat.Connection(&conn), nil
  349. }
  350. if mode == "stream-up" {
  351. if xmuxClient != nil {
  352. xmuxClient.LeftRequests.Add(-1)
  353. }
  354. conn.writer = httpClient.OpenUpload(ctx, requestURL.String())
  355. return stat.Connection(&conn), nil
  356. }
  357. scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
  358. scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
  359. if scMaxEachPostBytes.From <= buf.Size {
  360. panic("`scMaxEachPostBytes` should be bigger than " + strconv.Itoa(buf.Size))
  361. }
  362. maxUploadSize := scMaxEachPostBytes.rand()
  363. // WithSizeLimit(0) will still allow single bytes to pass, and a lot of
  364. // code relies on this behavior. Subtract 1 so that together with
  365. // uploadWriter wrapper, exact size limits can be enforced
  366. // uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - 1))
  367. uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - buf.Size))
  368. conn.writer = uploadWriter{
  369. uploadPipeWriter,
  370. maxUploadSize,
  371. }
  372. go func() {
  373. var seq int64
  374. var lastWrite time.Time
  375. for {
  376. wroteRequest := done.New()
  377. ctx := httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
  378. WroteRequest: func(httptrace.WroteRequestInfo) {
  379. wroteRequest.Close()
  380. },
  381. })
  382. // this intentionally makes a shallow-copy of the struct so we
  383. // can reassign Path (potentially concurrently)
  384. url := requestURL
  385. url.Path += "/" + strconv.FormatInt(seq, 10)
  386. // reassign query to get different padding
  387. url.RawQuery = transportConfiguration.GetNormalizedQuery()
  388. seq += 1
  389. if scMinPostsIntervalMs.From > 0 {
  390. time.Sleep(time.Duration(scMinPostsIntervalMs.rand())*time.Millisecond - time.Since(lastWrite))
  391. }
  392. // by offloading the uploads into a buffered pipe, multiple conn.Write
  393. // calls get automatically batched together into larger POST requests.
  394. // without batching, bandwidth is extremely limited.
  395. chunk, err := uploadPipeReader.ReadMultiBuffer()
  396. if err != nil {
  397. break
  398. }
  399. lastWrite = time.Now()
  400. if xmuxClient != nil && xmuxClient.LeftRequests.Add(-1) <= 0 {
  401. httpClient, xmuxClient = getHTTPClient(ctx, dest, streamSettings)
  402. }
  403. go func() {
  404. err := httpClient.SendUploadRequest(
  405. context.WithoutCancel(ctx),
  406. url.String(),
  407. &buf.MultiBufferContainer{MultiBuffer: chunk},
  408. int64(chunk.Len()),
  409. )
  410. wroteRequest.Close()
  411. if err != nil {
  412. errors.LogInfoInner(ctx, err, "failed to send upload")
  413. uploadPipeReader.Interrupt()
  414. }
  415. }()
  416. if _, ok := httpClient.(*DefaultDialerClient); ok {
  417. <-wroteRequest.Wait()
  418. }
  419. }
  420. }()
  421. return stat.Connection(&conn), nil
  422. }
  423. // uploadWriter is the writer Dial installs on the connection in packet-up
  // mode. It feeds the upload pipe, which Dial's upload loop drains into
  // individual upload requests.
  424. //
  425. // The MultiBuffer pipe accepts any single WriteMultiBuffer call even if that
  426. // single MultiBuffer exceeds the size limit, and then starts blocking on the
  427. // next WriteMultiBuffer call. This means that ReadMultiBuffer can return more
  428. // bytes than the size limit. Write only ever hands the pipe buffers obtained
  429. // from buf.New, and Dial sets the pipe limit to maxUploadSize - buf.Size, so
  // the overshoot of a single read should be bounded by one buffer (assuming a
  // buf.New buffer holds at most buf.Size bytes; TODO confirm).
  430. type uploadWriter struct {
  431. *pipe.Writer // write side of the packet-up upload pipe
  432. maxLen int32 // the maxUploadSize Dial picked; Write does not currently read it
  433. }
  434. func (w uploadWriter) Write(b []byte) (int, error) {
  435. /*
  436. capacity := int(w.maxLen - w.Len())
  437. if capacity > 0 && capacity < len(b) {
  438. b = b[:capacity]
  439. }
  440. */
  441. buffer := buf.New()
  442. n, err := buffer.Write(b)
  443. if err != nil {
  444. return 0, err
  445. }
  446. err = w.WriteMultiBuffer([]*buf.Buffer{buffer})
  447. if err != nil {
  448. return 0, err
  449. }
  450. return n, nil
  451. }