1
0
Эх сурвалжийг харах

XHTTP server: Fix stream-up "single POST problem", Use united httpServerConn instead of recover()

https://github.com/XTLS/Xray-core/issues/4373#issuecomment-2671795675

https://github.com/XTLS/Xray-core/issues/4406#issuecomment-2668041926
RPRX 8 сар өмнө
parent
commit
b786a50aee

+ 44 - 70
transport/internet/splithttp/hub.go

@@ -47,21 +47,6 @@ type httpSession struct {
 	isFullyConnected *done.Instance
 }
 
-func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
-	shouldReap := done.New()
-	go func() {
-		time.Sleep(30 * time.Second)
-		shouldReap.Close()
-	}()
-
-	select {
-	case <-isFullyConnected.Wait():
-		return
-	case <-shouldReap.Wait():
-		h.sessions.Delete(sessionId)
-	}
-}
-
 func (h *requestHandler) upsertSession(sessionId string) *httpSession {
 	// fast path
 	currentSessionAny, ok := h.sessions.Load(sessionId)
@@ -84,7 +69,21 @@ func (h *requestHandler) upsertSession(sessionId string) *httpSession {
 	}
 
 	h.sessions.Store(sessionId, s)
-	go h.maybeReapSession(s.isFullyConnected, sessionId)
+
+	shouldReap := done.New()
+	go func() {
+		time.Sleep(30 * time.Second)
+		shouldReap.Close()
+	}()
+	go func() {
+		select {
+		case <-shouldReap.Wait():
+			h.sessions.Delete(sessionId)
+			s.uploadQueue.Close()
+		case <-s.isFullyConnected.Wait():
+		}
+	}()
+
 	return s
 }
 
@@ -183,12 +182,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 				writer.WriteHeader(http.StatusBadRequest)
 				return
 			}
-			uploadDone := done.New()
+			httpSC := &httpServerConn{
+				Instance:       done.New(),
+				Reader:         request.Body,
+				ResponseWriter: writer,
+			}
 			err = currentSession.uploadQueue.Push(Packet{
-				Reader: &httpRequestBodyReader{
-					requestReader: request.Body,
-					uploadDone:    uploadDone,
-				},
+				Reader: httpSC,
 			})
 			if err != nil {
 				errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)")
@@ -200,25 +200,21 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 				scStreamUpServerSecs := h.config.GetNormalizedScStreamUpServerSecs()
 				if referrer != "" && scStreamUpServerSecs.To > 0 {
 					go func() {
-						defer func() {
-							recover()
-						}()
 						for {
-							_, err := writer.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand())))
+							_, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand())))
 							if err != nil {
 								break
 							}
-							writer.(http.Flusher).Flush()
 							time.Sleep(time.Duration(scStreamUpServerSecs.rand()) * time.Second)
 						}
 					}()
 				}
 				select {
 				case <-request.Context().Done():
-				case <-uploadDone.Wait():
+				case <-httpSC.Wait():
 				}
 			}
-			uploadDone.Close()
+			httpSC.Close()
 			return
 		}
 
@@ -262,11 +258,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 		writer.WriteHeader(http.StatusOK)
 	} else if request.Method == "GET" || sessionId == "" { // stream-down, stream-one
-		responseFlusher, ok := writer.(http.Flusher)
-		if !ok {
-			panic("expected http.ResponseWriter to be an http.Flusher")
-		}
-
 		if sessionId != "" {
 			// after GET is done, the connection is finished. disable automatic
 			// session reaping, and handle it in defer
@@ -287,20 +278,18 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		}
 
 		writer.WriteHeader(http.StatusOK)
+		writer.(http.Flusher).Flush()
 
-		responseFlusher.Flush()
-
-		downloadDone := done.New()
-
+		httpSC := &httpServerConn{
+			Instance:       done.New(),
+			Reader:         request.Body,
+			ResponseWriter: writer,
+		}
 		conn := splitConn{
-			writer: &httpResponseBodyWriter{
-				responseWriter:  writer,
-				downloadDone:    downloadDone,
-				responseFlusher: responseFlusher,
-			},
-			reader:     request.Body,
-			localAddr:  h.localAddr,
+			writer:     httpSC,
+			reader:     httpSC,
 			remoteAddr: remoteAddr,
+			localAddr:  h.localAddr,
 		}
 		if sessionId != "" { // if not stream-one
 			conn.reader = currentSession.uploadQueue
@@ -311,7 +300,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
 		select {
 		case <-request.Context().Done():
-		case <-downloadDone.Wait():
+		case <-httpSC.Wait():
 		}
 
 		conn.Close()
@@ -321,45 +310,30 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 	}
 }
 
-type httpRequestBodyReader struct {
-	requestReader io.ReadCloser
-	uploadDone    *done.Instance
-}
-
-func (c *httpRequestBodyReader) Read(b []byte) (int, error) {
-	return c.requestReader.Read(b)
-}
-
-func (c *httpRequestBodyReader) Close() error {
-	defer c.uploadDone.Close()
-	return c.requestReader.Close()
-}
-
-type httpResponseBodyWriter struct {
+type httpServerConn struct {
 	sync.Mutex
-	responseWriter  http.ResponseWriter
-	responseFlusher http.Flusher
-	downloadDone    *done.Instance
+	*done.Instance
+	io.Reader // no need to Close request.Body
+	http.ResponseWriter
 }
 
-func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
+func (c *httpServerConn) Write(b []byte) (int, error) {
 	c.Lock()
 	defer c.Unlock()
-	if c.downloadDone.Done() {
+	if c.Done() {
 		return 0, io.ErrClosedPipe
 	}
-	n, err := c.responseWriter.Write(b)
+	n, err := c.ResponseWriter.Write(b)
 	if err == nil {
-		c.responseFlusher.Flush()
+		c.ResponseWriter.(http.Flusher).Flush()
 	}
 	return n, err
 }
 
-func (c *httpResponseBodyWriter) Close() error {
+func (c *httpServerConn) Close() error {
 	c.Lock()
 	defer c.Unlock()
-	c.downloadDone.Close()
-	return nil
+	return c.Instance.Close()
 }
 
 type Listener struct {

+ 19 - 11
transport/internet/splithttp/upload_queue.go

@@ -20,6 +20,7 @@ type Packet struct {
 
 type uploadQueue struct {
 	reader          io.ReadCloser
+	nomore          bool
 	pushedPackets   chan Packet
 	writeCloseMutex sync.Mutex
 	heap            uploadHeap
@@ -42,19 +43,15 @@ func (h *uploadQueue) Push(p Packet) error {
 	h.writeCloseMutex.Lock()
 	defer h.writeCloseMutex.Unlock()
 
-	runtime.Gosched()
-	if h.reader != nil && p.Reader != nil {
-		p.Reader.Close()
-		return errors.New("h.reader already exists")
-	}
-
 	if h.closed {
-		if p.Reader != nil {
-			p.Reader.Close()
-		}
 		return errors.New("packet queue closed")
 	}
-
+	if h.nomore {
+		return errors.New("h.reader already exists")
+	}
+	if p.Reader != nil {
+		h.nomore = true
+	}
 	h.pushedPackets <- p
 	return nil
 }
@@ -65,9 +62,20 @@ func (h *uploadQueue) Close() error {
 
 	if !h.closed {
 		h.closed = true
+		runtime.Gosched() // hope Read() gets the packet
+	f:
+		for {
+			select {
+			case p := <-h.pushedPackets:
+				if p.Reader != nil {
+					h.reader = p.Reader
+				}
+			default:
+				break f
+			}
+		}
 		close(h.pushedPackets)
 	}
-	runtime.Gosched()
 	if h.reader != nil {
 		return h.reader.Close()
 	}