Selaa lähdekoodia

XHTTP client: Move `x_padding` into `Referer` header (#4298)

""Breaking"": Update the server side first, then client
rPDmYQ 9 kuukautta sitten
vanhempi
sitoutus
14a6636a41

+ 69 - 11
transport/internet/browser_dialer/dialer.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"context"
 	_ "embed"
 	_ "embed"
 	"encoding/base64"
 	"encoding/base64"
+	"encoding/json"
 	"net/http"
 	"net/http"
 	"time"
 	"time"
 
 
@@ -17,6 +18,12 @@ import (
 //go:embed dialer.html
 //go:embed dialer.html
 var webpage []byte
 var webpage []byte
 
 
+type task struct {
+	Method string `json:"method"`
+	URL    string `json:"url"`
+	Extra  any    `json:"extra,omitempty"`
+}
+
 var conns chan *websocket.Conn
 var conns chan *websocket.Conn
 
 
 var upgrader = &websocket.Upgrader{
 var upgrader = &websocket.Upgrader{
@@ -55,23 +62,69 @@ func HasBrowserDialer() bool {
 	return conns != nil
 	return conns != nil
 }
 }
 
 
+type webSocketExtra struct {
+	Protocol string `json:"protocol,omitempty"`
+}
+
 func DialWS(uri string, ed []byte) (*websocket.Conn, error) {
 func DialWS(uri string, ed []byte) (*websocket.Conn, error) {
-	data := []byte("WS " + uri)
+	task := task{
+		Method: "WS",
+		URL:    uri,
+	}
+
 	if ed != nil {
 	if ed != nil {
-		data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
+		task.Extra = webSocketExtra{
+			Protocol: base64.RawURLEncoding.EncodeToString(ed),
+		}
 	}
 	}
 
 
-	return dialRaw(data)
+	return dialTask(task)
 }
 }
 
 
-func DialGet(uri string) (*websocket.Conn, error) {
-	data := []byte("GET " + uri)
-	return dialRaw(data)
+type httpExtra struct {
+	Referrer string            `json:"referrer,omitempty"`
+	Headers  map[string]string `json:"headers,omitempty"`
 }
 }
 
 
-func DialPost(uri string, payload []byte) error {
-	data := []byte("POST " + uri)
-	conn, err := dialRaw(data)
+func httpExtraFromHeaders(headers http.Header) *httpExtra {
+	if len(headers) == 0 {
+		return nil
+	}
+
+	extra := httpExtra{}
+	if referrer := headers.Get("Referer"); referrer != "" {
+		extra.Referrer = referrer
+		headers.Del("Referer")
+	}
+
+	if len(headers) > 0 {
+		extra.Headers = make(map[string]string)
+		for header := range headers {
+			extra.Headers[header] = headers.Get(header)
+		}
+	}
+
+	return &extra
+}
+
+func DialGet(uri string, headers http.Header) (*websocket.Conn, error) {
+	task := task{
+		Method: "GET",
+		URL:    uri,
+		Extra:  httpExtraFromHeaders(headers),
+	}
+
+	return dialTask(task)
+}
+
+func DialPost(uri string, headers http.Header, payload []byte) error {
+	task := task{
+		Method: "POST",
+		URL:    uri,
+		Extra:  httpExtraFromHeaders(headers),
+	}
+
+	conn, err := dialTask(task)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -90,7 +143,12 @@ func DialPost(uri string, payload []byte) error {
 	return nil
 	return nil
 }
 }
 
 
-func dialRaw(data []byte) (*websocket.Conn, error) {
+func dialTask(task task) (*websocket.Conn, error) {
+	data, err := json.Marshal(task)
+	if err != nil {
+		return nil, err
+	}
+
 	var conn *websocket.Conn
 	var conn *websocket.Conn
 	for {
 	for {
 		conn = <-conns
 		conn = <-conns
@@ -100,7 +158,7 @@ func dialRaw(data []byte) (*websocket.Conn, error) {
 			break
 			break
 		}
 		}
 	}
 	}
-	err := CheckOK(conn)
+	err = CheckOK(conn)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 47 - 23
transport/internet/browser_dialer/dialer.html

@@ -14,10 +14,28 @@
 		let upstreamGetCount = 0;
 		let upstreamGetCount = 0;
 		let upstreamWsCount = 0;
 		let upstreamWsCount = 0;
 		let upstreamPostCount = 0;
 		let upstreamPostCount = 0;
+
+		function prepareRequestInit(extra) {
+			const requestInit = {};
+			if (extra.referrer) {
+				// note: we have to strip the protocol and host part.
+				// Browsers disallow that, and will reset the value to current page if attempted.
+				const referrer = URL.parse(extra.referrer);
+				requestInit.referrer = referrer.pathname + referrer.search + referrer.hash;
+				requestInit.referrerPolicy = "unsafe-url";
+			}
+
+			if (extra.headers) {
+				requestInit.headers = extra.headers;
+			}
+
+			return requestInit;
+		}
+
 		let check = function () {
 		let check = function () {
 			if (clientIdleCount > 0) {
 			if (clientIdleCount > 0) {
 				return;
 				return;
-			};
+			}
 			clientIdleCount += 1;
 			clientIdleCount += 1;
 			console.log("Prepare", url);
 			console.log("Prepare", url);
 			let ws = new WebSocket(url);
 			let ws = new WebSocket(url);
@@ -29,12 +47,12 @@
 			// double-checking that this continues to work
 			// double-checking that this continues to work
 			ws.onmessage = function (event) {
 			ws.onmessage = function (event) {
 				clientIdleCount -= 1;
 				clientIdleCount -= 1;
-				let [method, url, protocol] = event.data.split(" ");
-				switch (method) {
+				let task = JSON.parse(event.data);
+				switch (task.method) {
 					case "WS": {
 					case "WS": {
 						upstreamWsCount += 1;
 						upstreamWsCount += 1;
-						console.log("Dial WS", url, protocol);
-						const wss = new WebSocket(url, protocol);
+						console.log("Dial WS", task.url, task.extra.protocol);
+						const wss = new WebSocket(task.url, task.extra.protocol);
 						wss.binaryType = "arraybuffer";
 						wss.binaryType = "arraybuffer";
 						let opened = false;
 						let opened = false;
 						ws.onmessage = function (event) {
 						ws.onmessage = function (event) {
@@ -60,10 +78,12 @@
 							wss.close()
 							wss.close()
 						};
 						};
 						break;
 						break;
-					};
+					}
 					case "GET": {
 					case "GET": {
 						(async () => {
 						(async () => {
-							console.log("Dial GET", url);
+							const requestInit = prepareRequestInit(task.extra);
+
+							console.log("Dial GET", task.url);
 							ws.send("ok");
 							ws.send("ok");
 							const controller = new AbortController();
 							const controller = new AbortController();
 
 
@@ -83,58 +103,62 @@
 							ws.onclose = (event) => {
 							ws.onclose = (event) => {
 								try {
 								try {
 									reader && reader.cancel();
 									reader && reader.cancel();
-								} catch(e) {};
+								} catch(e) {}
 
 
 								try {
 								try {
 									controller.abort();
 									controller.abort();
-								} catch(e) {};
+								} catch(e) {}
 							};
 							};
 
 
 							try {
 							try {
 								upstreamGetCount += 1;
 								upstreamGetCount += 1;
-								const response = await fetch(url, {signal: controller.signal});
+
+								requestInit.signal = controller.signal;
+								const response = await fetch(task.url, requestInit);
 
 
 								const body = await response.body;
 								const body = await response.body;
 								reader = body.getReader();
 								reader = body.getReader();
 
 
 								while (true) {
 								while (true) {
 									const { done, value } = await reader.read();
 									const { done, value } = await reader.read();
-									ws.send(value);
+									if (value) ws.send(value);  // don't send back "undefined" string when received nothing
 									if (done) break;
 									if (done) break;
-								};
+								}
 							} finally {
 							} finally {
 								upstreamGetCount -= 1;
 								upstreamGetCount -= 1;
 								console.log("Dial GET DONE, remaining: ", upstreamGetCount);
 								console.log("Dial GET DONE, remaining: ", upstreamGetCount);
 								ws.close();
 								ws.close();
-							};
+							}
 						})();
 						})();
 						break;
 						break;
-					};
+					}
 					case "POST": {
 					case "POST": {
 						upstreamPostCount += 1;
 						upstreamPostCount += 1;
-						console.log("Dial POST", url);
+
+						const requestInit = prepareRequestInit(task.extra);
+						requestInit.method = "POST";
+
+						console.log("Dial POST", task.url);
 						ws.send("ok");
 						ws.send("ok");
 						ws.onmessage = async (event) => {
 						ws.onmessage = async (event) => {
 							try {
 							try {
-								const response = await fetch(
-									url,
-									{method: "POST", body: event.data}
-								);
+								requestInit.body = event.data;
+								const response = await fetch(task.url, requestInit);
 								if (response.ok) {
 								if (response.ok) {
 									ws.send("ok");
 									ws.send("ok");
 								} else {
 								} else {
 									console.error("bad status code");
 									console.error("bad status code");
 									ws.send("fail");
 									ws.send("fail");
-								};
+								}
 							} finally {
 							} finally {
 								upstreamPostCount -= 1;
 								upstreamPostCount -= 1;
 								console.log("Dial POST DONE, remaining: ", upstreamPostCount);
 								console.log("Dial POST DONE, remaining: ", upstreamPostCount);
 								ws.close();
 								ws.close();
-							};
+							}
 						};
 						};
 						break;
 						break;
-					};
-				};
+					}
+				}
 
 
 				check();
 				check();
 			};
 			};

+ 8 - 6
transport/internet/splithttp/browser_client.go

@@ -5,13 +5,15 @@ import (
 	"io"
 	"io"
 	gonet "net"
 	gonet "net"
 
 
+	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/transport/internet/browser_dialer"
 	"github.com/xtls/xray-core/transport/internet/browser_dialer"
 	"github.com/xtls/xray-core/transport/internet/websocket"
 	"github.com/xtls/xray-core/transport/internet/websocket"
 )
 )
 
 
-// implements splithttp.DialerClient in terms of browser dialer
-// has no fields because everything is global state :O)
-type BrowserDialerClient struct{}
+// BrowserDialerClient implements splithttp.DialerClient in terms of browser dialer
+type BrowserDialerClient struct {
+	transportConfig *Config
+}
 
 
 func (c *BrowserDialerClient) IsClosed() bool {
 func (c *BrowserDialerClient) IsClosed() bool {
 	panic("not implemented yet")
 	panic("not implemented yet")
@@ -19,10 +21,10 @@ func (c *BrowserDialerClient) IsClosed() bool {
 
 
 func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
 func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
 	if body != nil {
 	if body != nil {
-		panic("not implemented yet")
+		return nil, nil, nil, errors.New("bidirectional streaming for browser dialer not implemented yet")
 	}
 	}
 
 
-	conn, err := browser_dialer.DialGet(url)
+	conn, err := browser_dialer.DialGet(url, c.transportConfig.GetRequestHeader())
 	dummyAddr := &gonet.IPAddr{}
 	dummyAddr := &gonet.IPAddr{}
 	if err != nil {
 	if err != nil {
 		return nil, dummyAddr, dummyAddr, err
 		return nil, dummyAddr, dummyAddr, err
@@ -37,7 +39,7 @@ func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, body i
 		return err
 		return err
 	}
 	}
 
 
-	err = browser_dialer.DialPost(url, bytes)
+	err = browser_dialer.DialPost(url, c.transportConfig.GetRequestHeader(), bytes)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 26 - 6
transport/internet/splithttp/config.go

@@ -4,6 +4,7 @@ import (
 	"crypto/rand"
 	"crypto/rand"
 	"math/big"
 	"math/big"
 	"net/http"
 	"net/http"
+	"net/url"
 	"strings"
 	"strings"
 
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
@@ -11,6 +12,8 @@ import (
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet"
 )
 )
 
 
+const paddingQuery = "x_padding"
+
 func (c *Config) GetNormalizedPath() string {
 func (c *Config) GetNormalizedPath() string {
 	pathAndQuery := strings.SplitN(c.Path, "?", 2)
 	pathAndQuery := strings.SplitN(c.Path, "?", 2)
 	path := pathAndQuery[0]
 	path := pathAndQuery[0]
@@ -39,11 +42,6 @@ func (c *Config) GetNormalizedQuery() string {
 	}
 	}
 	query += "x_version=" + core.Version()
 	query += "x_version=" + core.Version()
 
 
-	paddingLen := c.GetNormalizedXPaddingBytes().rand()
-	if paddingLen > 0 {
-		query += "&x_padding=" + strings.Repeat("0", int(paddingLen))
-	}
-
 	return query
 	return query
 }
 }
 
 
@@ -53,6 +51,28 @@ func (c *Config) GetRequestHeader() http.Header {
 		header.Add(k, v)
 		header.Add(k, v)
 	}
 	}
 
 
+	paddingLen := c.GetNormalizedXPaddingBytes().rand()
+	if paddingLen > 0 {
+		query, err := url.ParseQuery(c.GetNormalizedQuery())
+		if err != nil {
+			query = url.Values{}
+		}
+		// https://www.rfc-editor.org/rfc/rfc7541.html#appendix-B
+		// h2's HPACK Header Compression feature employs a huffman encoding using a static table.
+		// 'X' is assigned an 8 bit code, so HPACK compression won't change actual padding length on the wire.
+		// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2-2
+		// h3's similar QPACK feature uses the same huffman table.
+		query.Set(paddingQuery, strings.Repeat("X", int(paddingLen)))
+
+		referrer := url.URL{
+			Scheme:   "https", // maybe http actually, but this part is not being checked
+			Host:     c.Host,
+			Path:     c.GetNormalizedPath(),
+			RawQuery: query.Encode(),
+		}
+
+		header.Set("Referer", referrer.String())
+	}
 	return header
 	return header
 }
 }
 
 
@@ -63,7 +83,7 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter) {
 	writer.Header().Set("X-Version", core.Version())
 	writer.Header().Set("X-Version", core.Version())
 	paddingLen := c.GetNormalizedXPaddingBytes().rand()
 	paddingLen := c.GetNormalizedXPaddingBytes().rand()
 	if paddingLen > 0 {
 	if paddingLen > 0 {
-		writer.Header().Set("X-Padding", strings.Repeat("0", int(paddingLen)))
+		writer.Header().Set("X-Padding", strings.Repeat("X", int(paddingLen)))
 	}
 	}
 }
 }
 
 

+ 11 - 7
transport/internet/splithttp/dialer.go

@@ -53,8 +53,8 @@ var (
 func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
 func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *XmuxClient) {
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
 
 
-	if browser_dialer.HasBrowserDialer() && realityConfig != nil {
-		return &BrowserDialerClient{}, nil
+	if browser_dialer.HasBrowserDialer() && realityConfig == nil {
+		return &BrowserDialerClient{transportConfig: streamSettings.ProtocolSettings.(*Config)}, nil
 	}
 	}
 
 
 	globalDialerAccess.Lock()
 	globalDialerAccess.Lock()
@@ -367,15 +367,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		},
 		},
 	}
 	}
 
 
+	var err error
 	if mode == "stream-one" {
 	if mode == "stream-one" {
 		requestURL.Path = transportConfiguration.GetNormalizedPath()
 		requestURL.Path = transportConfiguration.GetNormalizedPath()
 		if xmuxClient != nil {
 		if xmuxClient != nil {
 			xmuxClient.LeftRequests.Add(-1)
 			xmuxClient.LeftRequests.Add(-1)
 		}
 		}
-		conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
+		conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
+		if err != nil { // browser dialer only
+			return nil, err
+		}
 		return stat.Connection(&conn), nil
 		return stat.Connection(&conn), nil
 	} else { // stream-down
 	} else { // stream-down
-		var err error
 		if xmuxClient2 != nil {
 		if xmuxClient2 != nil {
 			xmuxClient2.LeftRequests.Add(-1)
 			xmuxClient2.LeftRequests.Add(-1)
 		}
 		}
@@ -388,7 +391,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		if xmuxClient != nil {
 		if xmuxClient != nil {
 			xmuxClient.LeftRequests.Add(-1)
 			xmuxClient.LeftRequests.Add(-1)
 		}
 		}
-		httpClient.OpenStream(ctx, requestURL.String(), reader, true)
+		_, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true)
+		if err != nil { // browser dialer only
+			return nil, err
+		}
 		return stat.Connection(&conn), nil
 		return stat.Connection(&conn), nil
 	}
 	}
 
 
@@ -428,8 +434,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 			// can reassign Path (potentially concurrently)
 			// can reassign Path (potentially concurrently)
 			url := requestURL
 			url := requestURL
 			url.Path += "/" + strconv.FormatInt(seq, 10)
 			url.Path += "/" + strconv.FormatInt(seq, 10)
-			// reassign query to get different padding
-			url.RawQuery = transportConfiguration.GetNormalizedQuery()
 
 
 			seq += 1
 			seq += 1
 
 

+ 20 - 5
transport/internet/splithttp/hub.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"io"
 	gonet "net"
 	gonet "net"
 	"net/http"
 	"net/http"
+	"net/url"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -110,9 +111,23 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 	}
 	}
 
 
 	validRange := h.config.GetNormalizedXPaddingBytes()
 	validRange := h.config.GetNormalizedXPaddingBytes()
-	x_padding := int32(len(request.URL.Query().Get("x_padding")))
-	if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) {
-		errors.LogInfo(context.Background(), "invalid x_padding length:", x_padding)
+	paddingLength := -1
+
+	if referrerPadding := request.Header.Get("Referer"); referrerPadding != "" {
+		// Browser dialer cannot control the host part of referrer header, so only check the query
+		if referrerURL, err := url.Parse(referrerPadding); err == nil {
+			if query := referrerURL.Query(); query.Has(paddingQuery) {
+				paddingLength = len(query.Get(paddingQuery))
+			}
+		}
+	}
+
+	if paddingLength == -1 {
+		paddingLength = len(request.URL.Query().Get(paddingQuery))
+	}
+
+	if validRange.To > 0 && (int32(paddingLength) < validRange.From || int32(paddingLength) > validRange.To) {
+		errors.LogInfo(context.Background(), "invalid x_padding length:", int32(paddingLength))
 		writer.WriteHeader(http.StatusBadRequest)
 		writer.WriteHeader(http.StatusBadRequest)
 		return
 		return
 	}
 	}
@@ -185,10 +200,10 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 			return
 			return
 		}
 		}
 
 
-		payload, err := io.ReadAll(request.Body)
+		payload, err := io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
 
 
 		if len(payload) > scMaxEachPostBytes {
 		if len(payload) > scMaxEachPostBytes {
-			errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.")
+			errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
 			writer.WriteHeader(http.StatusRequestEntityTooLarge)
 			writer.WriteHeader(http.StatusRequestEntityTooLarge)
 			return
 			return
 		}
 		}