| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- package splithttp
- import (
- "bytes"
- "context"
- "fmt"
- "io"
- "net/http"
- "net/http/httptrace"
- "sync"
- "github.com/xtls/xray-core/common"
- "github.com/xtls/xray-core/common/errors"
- "github.com/xtls/xray-core/common/net"
- "github.com/xtls/xray-core/common/signal/done"
- )
// DialerClient abstracts over how split-HTTP requests are issued, so callers
// can work with either the browser dialer or a net/http based client.
type DialerClient interface {
	// IsClosed reports whether the client considers itself closed.
	IsClosed() bool

	// OpenStream opens a streaming request against url. body is the request
	// body (may be nil) and uploadOnly marks a stream whose response is not
	// handed back to the caller. It returns the response reader followed by
	// two connection addresses (remote, then local, in DefaultDialerClient).
	OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, net.Addr, net.Addr, error)

	// PostPacket sends body to url as a single request; contentLength is the
	// length of body in bytes.
	PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error
}
- // implements splithttp.DialerClient in terms of direct network connections
- type DefaultDialerClient struct {
- transportConfig *Config
- client *http.Client
- closed bool
- httpVersion string
- // pool of net.Conn, created using dialUploadConn
- uploadRawPool *sync.Pool
- dialUploadConn func(ctxInner context.Context) (net.Conn, error)
- }
- func (c *DefaultDialerClient) IsClosed() bool {
- return c.closed
- }
- func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (wrc io.ReadCloser, remoteAddr, localAddr net.Addr, err error) {
- // this is done when the TCP/UDP connection to the server was established,
- // and we can unblock the Dial function and print correct net addresses in
- // logs
- gotConn := done.New()
- ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
- GotConn: func(connInfo httptrace.GotConnInfo) {
- remoteAddr = connInfo.Conn.RemoteAddr()
- localAddr = connInfo.Conn.LocalAddr()
- gotConn.Close()
- },
- })
- method := "GET" // stream-down
- if body != nil {
- method = "POST" // stream-up/one
- }
- req, _ := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body)
- req.Header = c.transportConfig.GetRequestHeader(url)
- if method == "POST" && !c.transportConfig.NoGRPCHeader {
- req.Header.Set("Content-Type", "application/grpc")
- }
- wrc = &WaitReadCloser{Wait: make(chan struct{})}
- go func() {
- resp, err := c.client.Do(req)
- if err != nil {
- if !uploadOnly { // stream-down is enough
- c.closed = true
- errors.LogInfoInner(ctx, err, "failed to "+method+" "+url)
- }
- gotConn.Close()
- wrc.Close()
- return
- }
- if resp.StatusCode != 200 && !uploadOnly {
- errors.LogInfo(ctx, "unexpected status ", resp.StatusCode)
- }
- if resp.StatusCode != 200 || uploadOnly { // stream-up
- io.Copy(io.Discard, resp.Body)
- resp.Body.Close() // if it is called immediately, the upload will be interrupted also
- wrc.Close()
- return
- }
- wrc.(*WaitReadCloser).Set(resp.Body)
- }()
- <-gotConn.Wait()
- return
- }
- func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error {
- req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), "POST", url, body)
- if err != nil {
- return err
- }
- req.ContentLength = contentLength
- req.Header = c.transportConfig.GetRequestHeader(url)
- if c.httpVersion != "1.1" {
- resp, err := c.client.Do(req)
- if err != nil {
- c.closed = true
- return err
- }
- io.Copy(io.Discard, resp.Body)
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return errors.New("bad status code:", resp.Status)
- }
- } else {
- // stringify the entire HTTP/1.1 request so it can be
- // safely retried. if instead req.Write is called multiple
- // times, the body is already drained after the first
- // request
- requestBuff := new(bytes.Buffer)
- common.Must(req.Write(requestBuff))
- var uploadConn any
- var h1UploadConn *H1Conn
- for {
- uploadConn = c.uploadRawPool.Get()
- newConnection := uploadConn == nil
- if newConnection {
- newConn, err := c.dialUploadConn(context.WithoutCancel(ctx))
- if err != nil {
- return err
- }
- h1UploadConn = NewH1Conn(newConn)
- uploadConn = h1UploadConn
- } else {
- h1UploadConn = uploadConn.(*H1Conn)
- // TODO: Replace 0 here with a config value later
- // Or add some other condition for optimization purposes
- if h1UploadConn.UnreadedResponsesCount > 0 {
- resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req)
- if err != nil {
- c.closed = true
- return fmt.Errorf("error while reading response: %s", err.Error())
- }
- io.Copy(io.Discard, resp.Body)
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode)
- }
- }
- }
- _, err := h1UploadConn.Write(requestBuff.Bytes())
- // if the write failed, we try another connection from
- // the pool, until the write on a new connection fails.
- // failed writes to a pooled connection are normal when
- // the connection has been closed in the meantime.
- if err == nil {
- break
- } else if newConnection {
- return err
- }
- }
- c.uploadRawPool.Put(uploadConn)
- }
- return nil
- }
// WaitReadCloser is an io.ReadCloser whose underlying stream is supplied
// after construction. Wait is closed once the stream has been provided via
// Set, or once Close has been called without one.
//
// NOTE(review): the embedded ReadCloser field is read by Read and Close
// without any synchronization other than the Wait channel — confirm callers
// do not depend on stronger guarantees.
type WaitReadCloser struct {
	Wait chan struct{}
	io.ReadCloser
}

// Set installs rc as the underlying stream and closes Wait to release any
// blocked Read. If closing Wait panics (for instance because Close already
// closed it), the panic is swallowed and rc is closed instead.
func (w *WaitReadCloser) Set(rc io.ReadCloser) {
	w.ReadCloser = rc
	defer func() {
		if r := recover(); r != nil {
			rc.Close()
		}
	}()
	close(w.Wait)
}

// Read reads from the underlying stream. While no stream is installed it
// blocks on Wait; if Wait is released and there is still no stream, it
// returns io.ErrClosedPipe.
func (w *WaitReadCloser) Read(b []byte) (int, error) {
	if w.ReadCloser != nil {
		return w.ReadCloser.Read(b)
	}
	<-w.Wait
	if w.ReadCloser == nil {
		return 0, io.ErrClosedPipe
	}
	return w.ReadCloser.Read(b)
}

// Close closes the underlying stream when one is installed and returns its
// result. Otherwise it closes Wait and returns nil; if closing Wait panics,
// the panic is swallowed and any stream installed in the meantime is closed.
func (w *WaitReadCloser) Close() error {
	if inner := w.ReadCloser; inner != nil {
		return inner.Close()
	}
	defer func() {
		if recover() == nil {
			return
		}
		if w.ReadCloser != nil {
			w.ReadCloser.Close()
		}
	}()
	close(w.Wait)
	return nil
}
|