Răsfoiți Sursa

XHTTP client: Add decideHTTPVersion() and more logs

https://github.com/XTLS/Xray-core/pull/4150#issuecomment-2537981368
RPRX 10 luni în urmă
părinte
comite
7463561856

+ 2 - 3
transport/internet/splithttp/client.go

@@ -39,8 +39,7 @@ type DialerClient interface {
 type DefaultDialerClient struct {
 	transportConfig *Config
 	client          *http.Client
-	isH2            bool
-	isH3            bool
+	httpVersion     string
 	// pool of net.Conn, created using dialUploadConn
 	uploadRawPool  *sync.Pool
 	dialUploadConn func(ctxInner context.Context) (net.Conn, error)
@@ -172,7 +171,7 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,
 	req.ContentLength = contentLength
 	req.Header = c.transportConfig.GetRequestHeader()
 
-	if c.isH2 || c.isH3 {
+	if c.httpVersion != "1.1" {
 		resp, err := c.client.Do(req)
 		if err != nil {
 			return err

+ 57 - 42
transport/internet/splithttp/dialer.go

@@ -3,6 +3,7 @@ package splithttp
 import (
 	"context"
 	gotls "crypto/tls"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptrace"
@@ -83,23 +84,32 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 	return res.Resource.(DialerClient), res
 }
 
+func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) string {
+	if realityConfig != nil {
+		return "2"
+	}
+	if tlsConfig == nil {
+		return "1.1"
+	}
+	if len(tlsConfig.NextProtocol) != 1 {
+		return "2"
+	}
+	if tlsConfig.NextProtocol[0] == "http/1.1" {
+		return "1.1"
+	}
+	if tlsConfig.NextProtocol[0] == "h3" {
+		return "3"
+	}
+	return "2"
+}
+
 func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
 	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 
-	isH2 := false
-	isH3 := false
-
-	if tlsConfig != nil {
-		isH2 = !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
-		isH3 = len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3"
-	} else if realityConfig != nil {
-		isH2 = true
-		isH3 = false
-	}
-
-	if isH3 {
-		dest.Network = net.Network_UDP
+	httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
+	if httpVersion == "3" {
+		dest.Network = net.Network_UDP // better to keep this line
 	}
 
 	var gotlsConfig *gotls.Config
@@ -138,7 +148,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 
 	var transport http.RoundTripper
 
-	if isH3 {
+	if httpVersion == "3" {
 		if keepAlivePeriod == 0 {
 			keepAlivePeriod = quicgoH3KeepAlivePeriod
 		}
@@ -194,7 +204,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 				return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
 			},
 		}
-	} else if isH2 {
+	} else if httpVersion == "2" {
 		if keepAlivePeriod == 0 {
 			keepAlivePeriod = chromeH2KeepAlivePeriod
 		}
@@ -228,8 +238,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
 		client: &http.Client{
 			Transport: transport,
 		},
-		isH2:           isH2,
-		isH3:           isH3,
+		httpVersion:    httpVersion,
 		uploadRawPool:  &sync.Pool{},
 		dialUploadConn: dialContext,
 	}
@@ -242,16 +251,16 @@ func init() {
 }
 
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
-	errors.LogInfo(ctx, "dialing splithttp to ", dest)
-
-	var requestURL url.URL
-
-	transportConfiguration := streamSettings.ProtocolSettings.(*Config)
 	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 
-	scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
-	scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
+	httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
+	if httpVersion == "3" {
+		dest.Network = net.Network_UDP
+	}
+
+	transportConfiguration := streamSettings.ProtocolSettings.(*Config)
+	var requestURL url.URL
 
 	if tlsConfig != nil || realityConfig != nil {
 		requestURL.Scheme = "https"
@@ -275,8 +284,21 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	httpClient, muxRes := getHTTPClient(ctx, dest, streamSettings)
 
-	httpClient2 := httpClient
+	mode := transportConfiguration.Mode
+	if mode == "" || mode == "auto" {
+		mode = "packet-up"
+		if httpVersion == "2" {
+			mode = "stream-up"
+		}
+		if realityConfig != nil && transportConfiguration.DownloadSettings == nil {
+			mode = "stream-one"
+		}
+	}
+
+	errors.LogInfo(ctx, fmt.Sprintf("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", dest, mode, httpVersion, requestURL.Host))
+
 	requestURL2 := requestURL
+	httpClient2 := httpClient
 	var muxRes2 *muxResource
 	if transportConfiguration.DownloadSettings != nil {
 		globalDialerAccess.Lock()
@@ -286,9 +308,12 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		globalDialerAccess.Unlock()
 		memory2 := streamSettings.DownloadSettings
 		dest2 := *memory2.Destination // just panic
-		httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2)
 		tlsConfig2 := tls.ConfigFromStreamSettings(memory2)
 		realityConfig2 := reality.ConfigFromStreamSettings(memory2)
+		httpVersion2 := decideHTTPVersion(tlsConfig2, realityConfig2)
+		if httpVersion2 == "3" {
+			dest2.Network = net.Network_UDP
+		}
 		if tlsConfig2 != nil || realityConfig2 != nil {
 			requestURL2.Scheme = "https"
 		} else {
@@ -307,20 +332,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		}
 		requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
 		requestURL2.RawQuery = config2.GetNormalizedQuery()
+		httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2)
+		errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host))
 	}
 
-	mode := transportConfiguration.Mode
-	if mode == "" || mode == "auto" {
-		mode = "packet-up"
-		if (tlsConfig != nil && (len(tlsConfig.NextProtocol) != 1 || tlsConfig.NextProtocol[0] == "h2")) || realityConfig != nil {
-			mode = "stream-up"
-		}
-		if realityConfig != nil && transportConfiguration.DownloadSettings == nil {
-			mode = "stream-one"
-		}
-	}
-	errors.LogInfo(ctx, "XHTTP is using mode: "+mode)
-
 	var writer io.WriteCloser
 	var reader io.ReadCloser
 	var remoteAddr, localAddr net.Addr
@@ -373,6 +388,9 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		return stat.Connection(&conn), nil
 	}
 
+	scMaxEachPostBytes := transportConfiguration.GetNormalizedScMaxEachPostBytes()
+	scMinPostsIntervalMs := transportConfiguration.GetNormalizedScMinPostsIntervalMs()
+
 	maxUploadSize := scMaxEachPostBytes.roll()
 	// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
 	// code relies on this behavior. Subtract 1 so that together with
@@ -408,10 +426,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 			seq += 1
 
 			if scMinPostsIntervalMs.From > 0 {
-				sleep := time.Duration(scMinPostsIntervalMs.roll())*time.Millisecond - time.Since(lastWrite)
-				if sleep > 0 {
-					time.Sleep(sleep)
-				}
+				time.Sleep(time.Duration(scMinPostsIntervalMs.roll())*time.Millisecond - time.Since(lastWrite))
 			}
 
 			// by offloading the uploads into a buffered pipe, multiple conn.Write

+ 9 - 9
transport/internet/splithttp/hub.go

@@ -333,30 +333,30 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			Net:  "unix",
 		}, streamSettings.SocketSettings)
 		if err != nil {
-			return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
+			return nil, errors.New("failed to listen UNIX domain socket for XHTTP on ", address).Base(err)
 		}
-		errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
+		errors.LogInfo(ctx, "listening UNIX domain socket for XHTTP on ", address)
 	} else if l.isH3 { // quic
 		Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
 			IP:   address.IP(),
 			Port: int(port),
 		}, streamSettings.SocketSettings)
 		if err != nil {
-			return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
+			return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
 		}
 		h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
 		if err != nil {
-			return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
+			return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
 		}
 		l.h3listener = h3listener
-		errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
+		errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
 
 		l.h3server = &http3.Server{
 			Handler: handler,
 		}
 		go func() {
 			if err := l.h3server.ServeListener(l.h3listener); err != nil {
-				errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
+				errors.LogWarningInner(ctx, err, "failed to serve HTTP/3 for XHTTP/3")
 			}
 		}()
 	} else { // tcp
@@ -369,9 +369,9 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			Port: int(port),
 		}, streamSettings.SocketSettings)
 		if err != nil {
-			return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
+			return nil, errors.New("failed to listen TCP for XHTTP on ", address, ":", port).Base(err)
 		}
-		errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
+		errors.LogInfo(ctx, "listening TCP for XHTTP on ", address, ":", port)
 	}
 
 	// tcp/unix (h1/h2)
@@ -397,7 +397,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 
 		go func() {
 			if err := l.server.Serve(l.listener); err != nil {
-				errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
+				errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
 			}
 		}()
 	}

+ 1 - 1
transport/internet/splithttp/upload_queue.go

@@ -52,7 +52,7 @@ func (h *uploadQueue) Push(p Packet) error {
 		if p.Reader != nil {
 			p.Reader.Close()
 		}
-		return errors.New("splithttp packet queue closed")
+		return errors.New("packet queue closed")
 	}
 
 	h.pushedPackets <- p