dialer.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package splithttp
  2. import (
  3. "context"
  4. gotls "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptrace"
  9. "net/url"
  10. "strconv"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "github.com/quic-go/quic-go"
  15. "github.com/quic-go/quic-go/http3"
  16. "github.com/xtls/xray-core/common"
  17. "github.com/xtls/xray-core/common/buf"
  18. "github.com/xtls/xray-core/common/errors"
  19. "github.com/xtls/xray-core/common/net"
  20. "github.com/xtls/xray-core/common/signal/done"
  21. "github.com/xtls/xray-core/common/uuid"
  22. "github.com/xtls/xray-core/transport/internet"
  23. "github.com/xtls/xray-core/transport/internet/browser_dialer"
  24. "github.com/xtls/xray-core/transport/internet/reality"
  25. "github.com/xtls/xray-core/transport/internet/stat"
  26. "github.com/xtls/xray-core/transport/internet/tls"
  27. "github.com/xtls/xray-core/transport/pipe"
  28. "golang.org/x/net/http2"
  29. )
  30. type dialerConf struct {
  31. net.Destination
  32. *internet.MemoryStreamConfig
  33. }
  34. var (
  35. globalDialerMap map[dialerConf]*XmuxManager
  36. globalDialerAccess sync.Mutex
  37. )
  38. func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
  39. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  40. if browser_dialer.HasBrowserDialer() && realityConfig == nil {
  41. return &BrowserDialerClient{transportConfig: streamSettings.ProtocolSettings.(*Config)}, nil
  42. }
  43. globalDialerAccess.Lock()
  44. defer globalDialerAccess.Unlock()
  45. if globalDialerMap == nil {
  46. globalDialerMap = make(map[dialerConf]*XmuxManager)
  47. }
  48. key := dialerConf{dest, streamSettings}
  49. xmuxManager, found := globalDialerMap[key]
  50. if !found {
  51. transportConfig := streamSettings.ProtocolSettings.(*Config)
  52. var xmuxConfig XmuxConfig
  53. if transportConfig.Xmux != nil {
  54. xmuxConfig = *transportConfig.Xmux
  55. }
  56. xmuxManager = NewXmuxManager(xmuxConfig, func() XmuxConn {
  57. return createHTTPClient(dest, streamSettings)
  58. })
  59. globalDialerMap[key] = xmuxManager
  60. }
  61. xmuxClient := xmuxManager.GetXmuxClient(ctx)
  62. return xmuxClient.XmuxConn.(DialerClient), xmuxClient
  63. }
  64. func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) string {
  65. if realityConfig != nil {
  66. return "2"
  67. }
  68. if tlsConfig == nil {
  69. return "1.1"
  70. }
  71. if len(tlsConfig.NextProtocol) != 1 {
  72. return "2"
  73. }
  74. if tlsConfig.NextProtocol[0] == "http/1.1" {
  75. return "1.1"
  76. }
  77. if tlsConfig.NextProtocol[0] == "h3" {
  78. return "3"
  79. }
  80. return "2"
  81. }
  82. func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
  83. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  84. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  85. httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
  86. if httpVersion == "3" {
  87. dest.Network = net.Network_UDP // better to keep this line
  88. }
  89. var gotlsConfig *gotls.Config
  90. if tlsConfig != nil {
  91. gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
  92. }
  93. transportConfig := streamSettings.ProtocolSettings.(*Config)
  94. dialContext := func(ctxInner context.Context) (net.Conn, error) {
  95. conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
  96. if err != nil {
  97. return nil, err
  98. }
  99. if realityConfig != nil {
  100. return reality.UClient(conn, realityConfig, ctxInner, dest)
  101. }
  102. if gotlsConfig != nil {
  103. if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
  104. conn = tls.UClient(conn, gotlsConfig, fingerprint)
  105. if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
  106. return nil, err
  107. }
  108. } else {
  109. conn = tls.Client(conn, gotlsConfig)
  110. }
  111. }
  112. return conn, nil
  113. }
  114. var keepAlivePeriod time.Duration
  115. if streamSettings.ProtocolSettings.(*Config).Xmux != nil {
  116. keepAlivePeriod = time.Duration(streamSettings.ProtocolSettings.(*Config).Xmux.HKeepAlivePeriod) * time.Second
  117. }
  118. var transport http.RoundTripper
  119. if httpVersion == "3" {
  120. if keepAlivePeriod == 0 {
  121. keepAlivePeriod = net.QuicgoH3KeepAlivePeriod
  122. }
  123. if keepAlivePeriod < 0 {
  124. keepAlivePeriod = 0
  125. }
  126. quicConfig := &quic.Config{
  127. MaxIdleTimeout: net.ConnIdleTimeout,
  128. // these two are defaults of quic-go/http3. the default of quic-go (no
  129. // http3) is different, so it is hardcoded here for clarity.
  130. // https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
  131. MaxIncomingStreams: -1,
  132. KeepAlivePeriod: keepAlivePeriod,
  133. }
  134. transport = &http3.Transport{
  135. QUICConfig: quicConfig,
  136. TLSClientConfig: gotlsConfig,
  137. Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (*quic.Conn, error) {
  138. conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  139. if err != nil {
  140. return nil, err
  141. }
  142. var udpConn net.PacketConn
  143. var udpAddr *net.UDPAddr
  144. switch c := conn.(type) {
  145. case *internet.PacketConnWrapper:
  146. var ok bool
  147. udpConn, ok = c.Conn.(*net.UDPConn)
  148. if !ok {
  149. return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
  150. }
  151. udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
  152. if err != nil {
  153. return nil, err
  154. }
  155. case *net.UDPConn:
  156. udpConn = c
  157. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  158. if err != nil {
  159. return nil, err
  160. }
  161. default:
  162. udpConn = &internet.FakePacketConn{Conn: c}
  163. udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
  164. if err != nil {
  165. return nil, err
  166. }
  167. }
  168. return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
  169. },
  170. }
  171. } else if httpVersion == "2" {
  172. if keepAlivePeriod == 0 {
  173. keepAlivePeriod = net.ChromeH2KeepAlivePeriod
  174. }
  175. if keepAlivePeriod < 0 {
  176. keepAlivePeriod = 0
  177. }
  178. transport = &http2.Transport{
  179. DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
  180. return dialContext(ctxInner)
  181. },
  182. IdleConnTimeout: net.ConnIdleTimeout,
  183. ReadIdleTimeout: keepAlivePeriod,
  184. }
  185. } else {
  186. httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
  187. return dialContext(ctxInner)
  188. }
  189. transport = &http.Transport{
  190. DialTLSContext: httpDialContext,
  191. DialContext: httpDialContext,
  192. IdleConnTimeout: net.ConnIdleTimeout,
  193. // chunked transfer download with KeepAlives is buggy with
  194. // http.Client and our custom dial context.
  195. DisableKeepAlives: true,
  196. }
  197. }
  198. client := &DefaultDialerClient{
  199. transportConfig: transportConfig,
  200. client: &http.Client{
  201. Transport: transport,
  202. },
  203. httpVersion: httpVersion,
  204. uploadRawPool: &sync.Pool{},
  205. dialUploadConn: dialContext,
  206. }
  207. return client
  208. }
  209. func init() {
  210. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  211. }
  212. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
  213. tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
  214. realityConfig := reality.ConfigFromStreamSettings(streamSettings)
  215. httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
  216. if httpVersion == "3" {
  217. dest.Network = net.Network_UDP
  218. }
  219. transportConfiguration := streamSettings.ProtocolSettings.(*Config)
  220. var requestURL url.URL
  221. if tlsConfig != nil || realityConfig != nil {
  222. requestURL.Scheme = "https"
  223. } else {
  224. requestURL.Scheme = "http"
  225. }
  226. requestURL.Host = transportConfiguration.Host
  227. if requestURL.Host == "" && tlsConfig != nil {
  228. requestURL.Host = tlsConfig.ServerName
  229. }
  230. if requestURL.Host == "" && realityConfig != nil {
  231. requestURL.Host = realityConfig.ServerName
  232. }
  233. if requestURL.Host == "" {
  234. requestURL.Host = dest.Address.String()
  235. }
  236. sessionIdUuid := uuid.New()
  237. requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String()
  238. requestURL.RawQuery = transportConfiguration.GetNormalizedQuery()
  239. httpClient, xmuxClient := getHTTPClient(ctx, dest, streamSettings)
  240. mode := transportConfiguration.Mode
  241. if mode == "" || mode == "auto" {
  242. mode = "packet-up"
  243. if realityConfig != nil {
  244. mode = "stream-one"
  245. if transportConfiguration.DownloadSettings != nil {
  246. mode = "stream-up"
  247. }
  248. }
  249. }
  250. errors.LogInfo(ctx, fmt.Sprintf("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", dest, mode, httpVersion, requestURL.Host))
  251. requestURL2 := requestURL
  252. httpClient2 := httpClient
  253. xmuxClient2 := xmuxClient
  254. if transportConfiguration.DownloadSettings != nil {
  255. globalDialerAccess.Lock()
  256. if streamSettings.DownloadSettings == nil {
  257. streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig)
  258. if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate {
  259. streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings
  260. }
  261. }
  262. globalDialerAccess.Unlock()
  263. memory2 := streamSettings.DownloadSettings
  264. dest2 := *memory2.Destination // just panic
  265. tlsConfig2 := tls.ConfigFromStreamSettings(memory2)
  266. realityConfig2 := reality.ConfigFromStreamSettings(memory2)
  267. httpVersion2 := decideHTTPVersion(tlsConfig2, realityConfig2)
  268. if httpVersion2 == "3" {
  269. dest2.Network = net.Network_UDP
  270. }
  271. if tlsConfig2 != nil || realityConfig2 != nil {
  272. requestURL2.Scheme = "https"
  273. } else {
  274. requestURL2.Scheme = "http"
  275. }
  276. config2 := memory2.ProtocolSettings.(*Config)
  277. requestURL2.Host = config2.Host
  278. if requestURL2.Host == "" && tlsConfig2 != nil {
  279. requestURL2.Host = tlsConfig2.ServerName
  280. }
  281. if requestURL2.Host == "" && realityConfig2 != nil {
  282. requestURL2.Host = realityConfig2.ServerName
  283. }
  284. if requestURL2.Host == "" {
  285. requestURL2.Host = dest2.Address.String()
  286. }
  287. requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
  288. requestURL2.RawQuery = config2.GetNormalizedQuery()
  289. httpClient2, xmuxClient2 = getHTTPClient(ctx, dest2, memory2)
  290. errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host))
  291. }
  292. if xmuxClient != nil {
  293. xmuxClient.OpenUsage.Add(1)
  294. }
  295. if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
  296. xmuxClient2.OpenUsage.Add(1)
  297. }
  298. var closed atomic.Int32
  299. reader, writer := io.Pipe()
  300. conn := splitConn{
  301. writer: writer,
  302. onClose: func() {
  303. if closed.Add(1) > 1 {
  304. return
  305. }
  306. if xmuxClient != nil {
  307. xmuxClient.OpenUsage.Add(-1)
  308. }
  309. if xmuxClient2 != nil && xmuxClient2 != xmuxClient {
  310. xmuxClient2.OpenUsage.Add(-1)
  311. }
  312. },
  313. }
  314. var err error
  315. if mode == "stream-one" {
  316. requestURL.Path = transportConfiguration.GetNormalizedPath()
  317. if xmuxClient != nil {
  318. xmuxClient.LeftRequests.Add(-1)
  319. }
  320. conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
  321. if err != nil { // browser dialer only
  322. return nil, err
  323. }
  324. return stat.Connection(&conn), nil
  325. } else { // stream-down
  326. if xmuxClient2 != nil {
  327. xmuxClient2.LeftRequests.Add(-1)
  328. }
  329. conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(ctx, requestURL2.String(), nil, false)
  330. if err != nil { // browser dialer only
  331. return nil, err
  332. }
  333. }
  334. if mode == "stream-up" {
  335. if xmuxClient != nil {
  336. xmuxClient.LeftRequests.Add(-1)
  337. }
  338. _, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true)
  339. if err != nil { // browser dialer only
  340. return nil, err
  341. }
  342. return stat.Connection(&conn), nil
  343. }
  344. scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
  345. scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
  346. if scMaxEachPostBytes.From <= buf.Size {
  347. panic("`scMaxEachPostBytes` should be bigger than " + strconv.Itoa(buf.Size))
  348. }
  349. maxUploadSize := scMaxEachPostBytes.rand()
  350. // WithSizeLimit(0) will still allow single bytes to pass, and a lot of
  351. // code relies on this behavior. Subtract 1 so that together with
  352. // uploadWriter wrapper, exact size limits can be enforced
  353. // uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - 1))
  354. uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - buf.Size))
  355. conn.writer = uploadWriter{
  356. uploadPipeWriter,
  357. maxUploadSize,
  358. }
  359. go func() {
  360. var seq int64
  361. var lastWrite time.Time
  362. for {
  363. wroteRequest := done.New()
  364. ctx := httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
  365. WroteRequest: func(httptrace.WroteRequestInfo) {
  366. wroteRequest.Close()
  367. },
  368. })
  369. // this intentionally makes a shallow-copy of the struct so we
  370. // can reassign Path (potentially concurrently)
  371. url := requestURL
  372. url.Path += "/" + strconv.FormatInt(seq, 10)
  373. seq += 1
  374. if scMinPostsIntervalMs.From > 0 {
  375. time.Sleep(time.Duration(scMinPostsIntervalMs.rand())*time.Millisecond - time.Since(lastWrite))
  376. }
  377. // by offloading the uploads into a buffered pipe, multiple conn.Write
  378. // calls get automatically batched together into larger POST requests.
  379. // without batching, bandwidth is extremely limited.
  380. chunk, err := uploadPipeReader.ReadMultiBuffer()
  381. if err != nil {
  382. break
  383. }
  384. lastWrite = time.Now()
  385. if xmuxClient != nil && (xmuxClient.LeftRequests.Add(-1) <= 0 ||
  386. (xmuxClient.UnreusableAt != time.Time{} && lastWrite.After(xmuxClient.UnreusableAt))) {
  387. httpClient, xmuxClient = getHTTPClient(ctx, dest, streamSettings)
  388. }
  389. go func() {
  390. err := httpClient.PostPacket(
  391. ctx,
  392. url.String(),
  393. &buf.MultiBufferContainer{MultiBuffer: chunk},
  394. int64(chunk.Len()),
  395. )
  396. wroteRequest.Close()
  397. if err != nil {
  398. errors.LogInfoInner(ctx, err, "failed to send upload")
  399. uploadPipeReader.Interrupt()
  400. }
  401. }()
  402. if _, ok := httpClient.(*DefaultDialerClient); ok {
  403. <-wroteRequest.Wait()
  404. }
  405. }
  406. }()
  407. return stat.Connection(&conn), nil
  408. }
  409. // A wrapper around pipe that ensures the size limit is exactly honored.
  410. //
  411. // The MultiBuffer pipe accepts any single WriteMultiBuffer call even if that
  412. // single MultiBuffer exceeds the size limit, and then starts blocking on the
  413. // next WriteMultiBuffer call. This means that ReadMultiBuffer can return more
  414. // bytes than the size limit. We work around this by splitting a potentially
  415. // too large write up into multiple.
  416. type uploadWriter struct {
  417. *pipe.Writer
  418. maxLen int32
  419. }
  420. func (w uploadWriter) Write(b []byte) (int, error) {
  421. /*
  422. capacity := int(w.maxLen - w.Len())
  423. if capacity > 0 && capacity < len(b) {
  424. b = b[:capacity]
  425. }
  426. */
  427. buffer := buf.New()
  428. n, err := buffer.Write(b)
  429. if err != nil {
  430. return 0, err
  431. }
  432. err = w.WriteMultiBuffer([]*buf.Buffer{buffer})
  433. if err != nil {
  434. return 0, err
  435. }
  436. return n, nil
  437. }