ソースを参照

Upgrade SplitHTTP Transport (#3462)

* move to paths instead of querystrings

* permit early data on serverside

* early data for the client, fix context cancellation
mmmray 1 年間 前
コミット
8fe976d7ee

+ 77 - 58
transport/internet/splithttp/dialer.go

@@ -16,6 +16,7 @@ import (
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/common/signal/semaphore"
 	"github.com/xtls/xray-core/common/uuid"
 	"github.com/xtls/xray-core/transport/internet"
@@ -44,18 +45,6 @@ var (
 	globalDialerAccess sync.Mutex
 )
 
-func destroyHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) {
-	globalDialerAccess.Lock()
-	defer globalDialerAccess.Unlock()
-
-	if globalDialerMap == nil {
-		globalDialerMap = make(map[dialerConf]reusedClient)
-	}
-
-	delete(globalDialerMap, dialerConf{dest, streamSettings})
-
-}
-
 func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) reusedClient {
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
@@ -77,7 +66,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 
 	dialContext := func(ctxInner context.Context) (net.Conn, error) {
-		conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
+		conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
 		if err != nil {
 			return nil, err
 		}
@@ -85,7 +74,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
 		if gotlsConfig != nil {
 			if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
 				conn = tls.UClient(conn, gotlsConfig, fingerprint)
-				if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
+				if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
 					return nil, err
 				}
 			} else {
@@ -171,49 +160,73 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	var remoteAddr gonet.Addr
 	var localAddr gonet.Addr
+	// this is done when the TCP/UDP connection to the server was established,
+	// and we can unblock the Dial function and print correct net addresses in
+	// logs
+	gotConn := done.New()
 
-	trace := &httptrace.ClientTrace{
-		GotConn: func(connInfo httptrace.GotConnInfo) {
-			remoteAddr = connInfo.Conn.RemoteAddr()
-			localAddr = connInfo.Conn.LocalAddr()
-		},
-	}
+	var downResponse io.ReadCloser
+	gotDownResponse := done.New()
 
 	sessionIdUuid := uuid.New()
 	sessionId := sessionIdUuid.String()
 
-	req, err := http.NewRequestWithContext(
-		httptrace.WithClientTrace(ctx, trace),
-		"GET",
-		requestURL.String()+"?session="+sessionId,
-		nil,
-	)
-	if err != nil {
-		return nil, err
-	}
+	go func() {
+		trace := &httptrace.ClientTrace{
+			GotConn: func(connInfo httptrace.GotConnInfo) {
+				remoteAddr = connInfo.Conn.RemoteAddr()
+				localAddr = connInfo.Conn.LocalAddr()
+				gotConn.Close()
+			},
+		}
 
-	req.Header = transportConfiguration.GetRequestHeader()
-
-	downResponse, err := httpClient.download.Do(req)
-	if err != nil {
-		// workaround for various connection pool related issues, mostly around
-		// HTTP/1.1. if the http client ever fails to send a request, we simply
-		// delete it entirely.
-		// in HTTP/1.1, it was observed that pool connections would immediately
-		// fail with "context canceled" if the previous http response body was
-		// not explicitly BOTH drained and closed. at the same time, sometimes
-		// the draining itself takes forever and causes more problems.
-		// see also https://github.com/golang/go/issues/60240
-		destroyHTTPClient(ctx, dest, streamSettings)
-		return nil, newError("failed to send download http request, destroying client").Base(err)
-	}
+		// in case we hit an error, we want to unblock this part
+		defer gotConn.Close()
 
-	if downResponse.StatusCode != 200 {
-		downResponse.Body.Close()
-		return nil, newError("invalid status code on download:", downResponse.Status)
-	}
+		req, err := http.NewRequestWithContext(
+			httptrace.WithClientTrace(context.WithoutCancel(ctx), trace),
+			"GET",
+			requestURL.String()+sessionId,
+			nil,
+		)
+		if err != nil {
+			newError("failed to construct download http request").Base(err).WriteToLog()
+			gotDownResponse.Close()
+			return
+		}
+
+		req.Header = transportConfiguration.GetRequestHeader()
+
+		response, err := httpClient.download.Do(req)
+		gotConn.Close()
+		if err != nil {
+			newError("failed to send download http request").Base(err).WriteToLog()
+			gotDownResponse.Close()
+			return
+		}
+
+		if response.StatusCode != 200 {
+			response.Body.Close()
+			newError("invalid status code on download:", response.Status).WriteToLog()
+			gotDownResponse.Close()
+			return
+		}
+
+		// skip "ok" response
+		trashHeader := []byte{0, 0}
+		_, err = io.ReadFull(response.Body, trashHeader)
+		if err != nil {
+			response.Body.Close()
+			newError("failed to read initial response").Base(err).WriteToLog()
+			gotDownResponse.Close()
+			return
+		}
 
-	uploadUrl := requestURL.String() + "?session=" + sessionId + "&seq="
+		downResponse = response.Body
+		gotDownResponse.Close()
+	}()
+
+	uploadUrl := requestURL.String() + sessionId + "/"
 
 	uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize))
 
@@ -266,7 +279,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 					for i := 0; i < 5; i++ {
 						uploadConn = httpClient.uploadRawPool.Get()
 						if uploadConn == nil {
-							uploadConn, err = httpClient.dialUploadConn(ctx)
+							uploadConn, err = httpClient.dialUploadConn(context.WithoutCancel(ctx))
 							if err != nil {
 								newError("failed to connect upload").Base(err).WriteToLog()
 								uploadPipeReader.Interrupt()
@@ -293,21 +306,27 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		}
 	}()
 
-	// skip "ok" response
-	trashHeader := []byte{0, 0}
-	_, err = io.ReadFull(downResponse.Body, trashHeader)
-	if err != nil {
-		downResponse.Body.Close()
-		return nil, newError("failed to read initial response")
-	}
+	// we want to block Dial until we know the remote address of the server,
+	// for logging purposes
+	<-gotConn.Wait()
 
 	// necessary in order to send larger chunks in upload
 	bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter)
 	bufferedUploadPipeWriter.SetBuffered(false)
 
+	lazyDownload := &LazyReader{
+		CreateReader: func() (io.ReadCloser, error) {
+			<-gotDownResponse.Wait()
+			if downResponse == nil {
+				return nil, newError("downResponse failed")
+			}
+			return downResponse, nil
+		},
+	}
+
 	conn := splitConn{
 		writer:     bufferedUploadPipeWriter,
-		reader:     downResponse.Body,
+		reader:     lazyDownload,
 		remoteAddr: remoteAddr,
 		localAddr:  localAddr,
 	}

+ 59 - 15
transport/internet/splithttp/hub.go

@@ -7,6 +7,7 @@ import (
 	gonet "net"
 	"net/http"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -28,20 +29,65 @@ type requestHandler struct {
 	localAddr gonet.TCPAddr
 }
 
+type httpSession struct {
+	uploadQueue *UploadQueue
+	// for as long as the GET request is not opened by the client, this will be
+	// open ("undone"), and the session may be expired within a certain TTL.
+	// after the client connects, this becomes "done" and the session lives as
+	// long as the GET request.
+	isFullyConnected *done.Instance
+}
+
+func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
+	shouldReap := done.New()
+	go func() {
+		time.Sleep(30 * time.Second)
+		shouldReap.Close()
+	}()
+
+	select {
+	case <-isFullyConnected.Wait():
+		return
+	case <-shouldReap.Wait():
+		h.sessions.Delete(sessionId)
+	}
+}
+
+func (h *requestHandler) upsertSession(sessionId string) *httpSession {
+	currentSessionAny, ok := h.sessions.Load(sessionId)
+	if ok {
+		return currentSessionAny.(*httpSession)
+	}
+
+	s := &httpSession{
+		uploadQueue:      NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())),
+		isFullyConnected: done.New(),
+	}
+
+	h.sessions.Store(sessionId, s)
+	go h.maybeReapSession(s.isFullyConnected, sessionId)
+	return s
+}
+
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	if len(h.host) > 0 && request.Host != h.host {
 		newError("failed to validate host, request:", request.Host, ", config:", h.host).WriteToLog()
 		writer.WriteHeader(http.StatusNotFound)
 		return
 	}
-	if request.URL.Path != h.path {
+
+	if !strings.HasPrefix(request.URL.Path, h.path) {
 		newError("failed to validate path, request:", request.URL.Path, ", config:", h.path).WriteToLog()
 		writer.WriteHeader(http.StatusNotFound)
 		return
 	}
 
-	queryString := request.URL.Query()
-	sessionId := queryString.Get("session")
+	sessionId := ""
+	subpath := strings.Split(request.URL.Path[len(h.path):], "/")
+	if len(subpath) > 0 {
+		sessionId = subpath[0]
+	}
+
 	if sessionId == "" {
 		newError("no sessionid on request:", request.URL.Path).WriteToLog()
 		writer.WriteHeader(http.StatusBadRequest)
@@ -60,15 +106,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		}
 	}
 
+	currentSession := h.upsertSession(sessionId)
+
 	if request.Method == "POST" {
-		uploadQueue, ok := h.sessions.Load(sessionId)
-		if !ok {
-			newError("sessionid does not exist").WriteToLog()
-			writer.WriteHeader(http.StatusBadRequest)
-			return
+		seq := ""
+		if len(subpath) > 1 {
+			seq = subpath[1]
 		}
 
-		seq := queryString.Get("seq")
 		if seq == "" {
 			newError("no seq on request:", request.URL.Path).WriteToLog()
 			writer.WriteHeader(http.StatusBadRequest)
@@ -89,7 +134,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 			return
 		}
 
-		err = uploadQueue.(*UploadQueue).Push(Packet{
+		err = currentSession.uploadQueue.Push(Packet{
 			Payload: payload,
 			Seq:     seqInt,
 		})
@@ -107,10 +152,9 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 			panic("expected http.ResponseWriter to be an http.Flusher")
 		}
 
-		uploadQueue := NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads()))
-
-		h.sessions.Store(sessionId, uploadQueue)
-		// the connection is finished, clean up map
+		// after GET is done, the connection is finished. disable automatic
+		// session reaping, and handle it in defer
+		currentSession.isFullyConnected.Close()
 		defer h.sessions.Delete(sessionId)
 
 		// magic header instructs nginx + apache to not buffer response body
@@ -130,7 +174,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 				downloadDone:    downloadDone,
 				responseFlusher: responseFlusher,
 			},
-			reader:     uploadQueue,
+			reader:     currentSession.uploadQueue,
 			remoteAddr: remoteAddr,
 		}
 

+ 57 - 0
transport/internet/splithttp/lazy_reader.go

@@ -0,0 +1,57 @@
+package splithttp
+
+import (
+	"io"
+	"sync"
+)
+
+type LazyReader struct {
+	readerSync   sync.Mutex
+	CreateReader func() (io.ReadCloser, error)
+	reader       io.ReadCloser
+	readerError  error
+}
+
+func (r *LazyReader) getReader() (io.ReadCloser, error) {
+	r.readerSync.Lock()
+	defer r.readerSync.Unlock()
+	if r.reader != nil {
+		return r.reader, nil
+	}
+
+	if r.readerError != nil {
+		return nil, r.readerError
+	}
+
+	reader, err := r.CreateReader()
+	if err != nil {
+		r.readerError = err
+		return nil, err
+	}
+
+	r.reader = reader
+	return reader, nil
+}
+
+func (r *LazyReader) Read(b []byte) (int, error) {
+	reader, err := r.getReader()
+	if err != nil {
+		return 0, err
+	}
+	n, err := reader.Read(b)
+	return n, err
+}
+
+func (r *LazyReader) Close() error {
+	r.readerSync.Lock()
+	defer r.readerSync.Unlock()
+
+	var err error
+	if r.reader != nil {
+		err = r.reader.Close()
+		r.reader = nil
+		r.readerError = newError("closed reader")
+	}
+
+	return err
+}