| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- package splithttp
- import (
- "bytes"
- "context"
- "io"
- gonet "net"
- "net/http"
- "net/http/httptrace"
- "sync"
- "github.com/xtls/xray-core/common"
- "github.com/xtls/xray-core/common/errors"
- "github.com/xtls/xray-core/common/net"
- "github.com/xtls/xray-core/common/signal/done"
- )
- // interface to abstract between use of browser dialer, vs net/http
- type DialerClient interface {
- // (ctx, baseURL, payload) -> err
- // baseURL already contains sessionId and seq
- SendUploadRequest(context.Context, string, io.ReadWriteCloser, int64) error
- // (ctx, baseURL) -> (downloadReader, remoteAddr, localAddr)
- // baseURL already contains sessionId
- OpenDownload(context.Context, string) (io.ReadCloser, net.Addr, net.Addr, error)
- }
- // implements splithttp.DialerClient in terms of direct network connections
- type DefaultDialerClient struct {
- transportConfig *Config
- download *http.Client
- upload *http.Client
- isH2 bool
- isH3 bool
- // pool of net.Conn, created using dialUploadConn
- uploadRawPool *sync.Pool
- dialUploadConn func(ctxInner context.Context) (net.Conn, error)
- }
- func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
- var remoteAddr gonet.Addr
- var localAddr gonet.Addr
- // this is done when the TCP/UDP connection to the server was established,
- // and we can unblock the Dial function and print correct net addresses in
- // logs
- gotConn := done.New()
- var downResponse io.ReadCloser
- gotDownResponse := done.New()
- go func() {
- trace := &httptrace.ClientTrace{
- GotConn: func(connInfo httptrace.GotConnInfo) {
- remoteAddr = connInfo.Conn.RemoteAddr()
- localAddr = connInfo.Conn.LocalAddr()
- gotConn.Close()
- },
- }
- // in case we hit an error, we want to unblock this part
- defer gotConn.Close()
- req, err := http.NewRequestWithContext(
- httptrace.WithClientTrace(ctx, trace),
- "GET",
- baseURL,
- nil,
- )
- if err != nil {
- errors.LogInfoInner(ctx, err, "failed to construct download http request")
- gotDownResponse.Close()
- return
- }
- req.Header = c.transportConfig.GetRequestHeader()
- response, err := c.download.Do(req)
- gotConn.Close()
- if err != nil {
- errors.LogInfoInner(ctx, err, "failed to send download http request")
- gotDownResponse.Close()
- return
- }
- if response.StatusCode != 200 {
- response.Body.Close()
- errors.LogInfo(ctx, "invalid status code on download:", response.Status)
- gotDownResponse.Close()
- return
- }
- downResponse = response.Body
- gotDownResponse.Close()
- }()
- // we want to block Dial until we know the remote address of the server,
- // for logging purposes
- <-gotConn.Wait()
- lazyDownload := &LazyReader{
- CreateReader: func() (io.ReadCloser, error) {
- <-gotDownResponse.Wait()
- if downResponse == nil {
- return nil, errors.New("downResponse failed")
- }
- return downResponse, nil
- },
- }
- return lazyDownload, remoteAddr, localAddr, nil
- }
- func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
- req, err := http.NewRequest("POST", url, payload)
- req.ContentLength = contentLength
- if err != nil {
- return err
- }
- req.Header = c.transportConfig.GetRequestHeader()
- if c.isH2 || c.isH3 {
- resp, err := c.upload.Do(req)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return errors.New("bad status code:", resp.Status)
- }
- } else {
- // stringify the entire HTTP/1.1 request so it can be
- // safely retried. if instead req.Write is called multiple
- // times, the body is already drained after the first
- // request
- requestBytes := new(bytes.Buffer)
- common.Must(req.Write(requestBytes))
- var uploadConn any
- for {
- uploadConn = c.uploadRawPool.Get()
- newConnection := uploadConn == nil
- if newConnection {
- uploadConn, err = c.dialUploadConn(context.WithoutCancel(ctx))
- if err != nil {
- return err
- }
- }
- _, err = uploadConn.(net.Conn).Write(requestBytes.Bytes())
- // if the write failed, we try another connection from
- // the pool, until the write on a new connection fails.
- // failed writes to a pooled connection are normal when
- // the connection has been closed in the meantime.
- if err == nil {
- break
- } else if newConnection {
- return err
- }
- }
- c.uploadRawPool.Put(uploadConn)
- }
- return nil
- }
|