client.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package splithttp
  2. import (
  3. "bytes"
  4. "context"
  5. "io"
  6. gonet "net"
  7. "net/http"
  8. "net/http/httptrace"
  9. "sync"
  10. "github.com/xtls/xray-core/common"
  11. "github.com/xtls/xray-core/common/errors"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/signal/done"
  14. )
  // DialerClient abstracts over the mechanism used to issue splithttp's HTTP
// requests, so that the browser dialer and a plain net/http client can be
// used interchangeably.
type DialerClient interface {
	// SendUploadRequest sends a single upload request carrying payload.
	// url already contains the session ID and the sequence number.
	// contentLength is the expected payload length in bytes (DefaultDialerClient
	// uses it as the request's Content-Length).
	SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error

	// OpenDownload opens the download stream for baseURL, which already
	// contains the session ID. It returns the stream reader followed by the
	// remote and the local address of the underlying connection.
	OpenDownload(ctx context.Context, baseURL string) (io.ReadCloser, net.Addr, net.Addr, error)
}
  24. // implements splithttp.DialerClient in terms of direct network connections
  25. type DefaultDialerClient struct {
  26. transportConfig *Config
  27. download *http.Client
  28. upload *http.Client
  29. isH2 bool
  30. isH3 bool
  31. // pool of net.Conn, created using dialUploadConn
  32. uploadRawPool *sync.Pool
  33. dialUploadConn func(ctxInner context.Context) (net.Conn, error)
  34. }
  35. func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
  36. var remoteAddr gonet.Addr
  37. var localAddr gonet.Addr
  38. // this is done when the TCP/UDP connection to the server was established,
  39. // and we can unblock the Dial function and print correct net addresses in
  40. // logs
  41. gotConn := done.New()
  42. var downResponse io.ReadCloser
  43. gotDownResponse := done.New()
  44. go func() {
  45. trace := &httptrace.ClientTrace{
  46. GotConn: func(connInfo httptrace.GotConnInfo) {
  47. remoteAddr = connInfo.Conn.RemoteAddr()
  48. localAddr = connInfo.Conn.LocalAddr()
  49. gotConn.Close()
  50. },
  51. }
  52. // in case we hit an error, we want to unblock this part
  53. defer gotConn.Close()
  54. req, err := http.NewRequestWithContext(
  55. httptrace.WithClientTrace(ctx, trace),
  56. "GET",
  57. baseURL,
  58. nil,
  59. )
  60. if err != nil {
  61. errors.LogInfoInner(ctx, err, "failed to construct download http request")
  62. gotDownResponse.Close()
  63. return
  64. }
  65. req.Header = c.transportConfig.GetRequestHeader()
  66. response, err := c.download.Do(req)
  67. gotConn.Close()
  68. if err != nil {
  69. errors.LogInfoInner(ctx, err, "failed to send download http request")
  70. gotDownResponse.Close()
  71. return
  72. }
  73. if response.StatusCode != 200 {
  74. response.Body.Close()
  75. errors.LogInfo(ctx, "invalid status code on download:", response.Status)
  76. gotDownResponse.Close()
  77. return
  78. }
  79. downResponse = response.Body
  80. gotDownResponse.Close()
  81. }()
  82. // we want to block Dial until we know the remote address of the server,
  83. // for logging purposes
  84. <-gotConn.Wait()
  85. lazyDownload := &LazyReader{
  86. CreateReader: func() (io.ReadCloser, error) {
  87. <-gotDownResponse.Wait()
  88. if downResponse == nil {
  89. return nil, errors.New("downResponse failed")
  90. }
  91. return downResponse, nil
  92. },
  93. }
  94. return lazyDownload, remoteAddr, localAddr, nil
  95. }
  96. func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
  97. req, err := http.NewRequest("POST", url, payload)
  98. req.ContentLength = contentLength
  99. if err != nil {
  100. return err
  101. }
  102. req.Header = c.transportConfig.GetRequestHeader()
  103. if c.isH2 || c.isH3 {
  104. resp, err := c.upload.Do(req)
  105. if err != nil {
  106. return err
  107. }
  108. defer resp.Body.Close()
  109. if resp.StatusCode != 200 {
  110. return errors.New("bad status code:", resp.Status)
  111. }
  112. } else {
  113. // stringify the entire HTTP/1.1 request so it can be
  114. // safely retried. if instead req.Write is called multiple
  115. // times, the body is already drained after the first
  116. // request
  117. requestBytes := new(bytes.Buffer)
  118. common.Must(req.Write(requestBytes))
  119. var uploadConn any
  120. for {
  121. uploadConn = c.uploadRawPool.Get()
  122. newConnection := uploadConn == nil
  123. if newConnection {
  124. uploadConn, err = c.dialUploadConn(context.WithoutCancel(ctx))
  125. if err != nil {
  126. return err
  127. }
  128. }
  129. _, err = uploadConn.(net.Conn).Write(requestBytes.Bytes())
  130. // if the write failed, we try another connection from
  131. // the pool, until the write on a new connection fails.
  132. // failed writes to a pooled connection are normal when
  133. // the connection has been closed in the meantime.
  134. if err == nil {
  135. break
  136. } else if newConnection {
  137. return err
  138. }
  139. }
  140. c.uploadRawPool.Put(uploadConn)
  141. }
  142. return nil
  143. }