Browse Source

SplitHTTP: Fix connection leaks and crashes (#3710)

mmmray 1 year ago
parent
commit
83eef6bc1f

+ 32 - 9
transport/internet/splithttp/client.go

@@ -49,6 +49,8 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
 	var downResponse io.ReadCloser
 	gotDownResponse := done.New()
 
+	ctx, ctxCancel := context.WithCancel(ctx)
+
 	go func() {
 		trace := &httptrace.ClientTrace{
 			GotConn: func(connInfo httptrace.GotConnInfo) {
@@ -61,8 +63,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
 		// in case we hit an error, we want to unblock this part
 		defer gotConn.Close()
 
+		ctx = httptrace.WithClientTrace(ctx, trace)
+
 		req, err := http.NewRequestWithContext(
-			httptrace.WithClientTrace(ctx, trace),
+			ctx,
 			"GET",
 			baseURL,
 			nil,
@@ -94,16 +98,17 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
 		gotDownResponse.Close()
 	}()
 
-	if c.isH3 {
-		gotConn.Close()
+	if !c.isH3 {
+		// in quic-go, sometimes gotConn is never closed for the lifetime of
+		// the entire connection, and the download locks up
+		// https://github.com/quic-go/quic-go/issues/3342
+		// for other HTTP versions, we want to block Dial until we know the
+		// remote address of the server, for logging purposes
+		<-gotConn.Wait()
 	}
 
-	// we want to block Dial until we know the remote address of the server,
-	// for logging purposes
-	<-gotConn.Wait()
-
 	lazyDownload := &LazyReader{
-		CreateReader: func() (io.ReadCloser, error) {
+		CreateReader: func() (io.Reader, error) {
 			<-gotDownResponse.Wait()
 			if downResponse == nil {
 				return nil, errors.New("downResponse failed")
@@ -112,7 +117,15 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
 		},
 	}
 
-	return lazyDownload, remoteAddr, localAddr, nil
+	// workaround for https://github.com/quic-go/quic-go/issues/2143 --
+	// always cancel request context so that Close cancels any Read.
+	// Should then match the behavior of http2 and http1.
+	reader := downloadBody{
+		lazyDownload,
+		ctxCancel,
+	}
+
+	return reader, remoteAddr, localAddr, nil
 }
 
 func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
@@ -172,3 +185,13 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,
 
 	return nil
 }
+
+type downloadBody struct {
+	io.Reader
+	cancel context.CancelFunc
+}
+
+func (c downloadBody) Close() error {
+	c.cancel()
+	return nil
+}

+ 2 - 32
transport/internet/splithttp/dialer.go

@@ -1,10 +1,8 @@
 package splithttp
 
 import (
-	"bytes"
 	"context"
 	gotls "crypto/tls"
-	"io"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -292,35 +290,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		return nil, err
 	}
 
-	lazyDownload := &LazyReader{
-		CreateReader: func() (io.ReadCloser, error) {
-			// skip "ok" response
-			trashHeader := []byte{0, 0}
-			_, err := io.ReadFull(lazyRawDownload, trashHeader)
-			if err != nil {
-				return nil, errors.New("failed to read initial response").Base(err)
-			}
-
-			if bytes.Equal(trashHeader, []byte("ok")) {
-				return lazyRawDownload, nil
-			}
-
-			// we read some garbage byte that may not have been "ok" at
-			// all. return a reader that replays what we have read so far
-			reader := io.MultiReader(
-				bytes.NewReader(trashHeader),
-				lazyRawDownload,
-			)
-			readCloser := struct {
-				io.Reader
-				io.Closer
-			}{
-				Reader: reader,
-				Closer: lazyRawDownload,
-			}
-			return readCloser, nil
-		},
-	}
+	reader := &stripOkReader{ReadCloser: lazyRawDownload}
 
 	writer := uploadWriter{
 		uploadPipeWriter,
@@ -329,7 +299,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	conn := splitConn{
 		writer:     writer,
-		reader:     lazyDownload,
+		reader:     reader,
 		remoteAddr: remoteAddr,
 		localAddr:  localAddr,
 	}

+ 5 - 1
transport/internet/splithttp/hub.go

@@ -222,8 +222,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		h.ln.addConn(stat.Connection(&conn))
 
 		// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
-		<-downloadDone.Wait()
+		select {
+		case <-request.Context().Done():
+		case <-downloadDone.Wait():
+		}
 
+		conn.Close()
 	} else {
 		writer.WriteHeader(http.StatusMethodNotAllowed)
 	}

+ 7 - 19
transport/internet/splithttp/lazy_reader.go

@@ -3,18 +3,20 @@ package splithttp
 import (
 	"io"
 	"sync"
-
-	"github.com/xtls/xray-core/common/errors"
 )
 
+// Close is intentionally not supported by LazyReader because it's not clear
+// how CreateReader should be aborted in case of Close. It's best to wrap
+// LazyReader in another struct that handles Close correctly, or better, stop
+// using LazyReader entirely.
 type LazyReader struct {
 	readerSync   sync.Mutex
-	CreateReader func() (io.ReadCloser, error)
-	reader       io.ReadCloser
+	CreateReader func() (io.Reader, error)
+	reader       io.Reader
 	readerError  error
 }
 
-func (r *LazyReader) getReader() (io.ReadCloser, error) {
+func (r *LazyReader) getReader() (io.Reader, error) {
 	r.readerSync.Lock()
 	defer r.readerSync.Unlock()
 	if r.reader != nil {
@@ -43,17 +45,3 @@ func (r *LazyReader) Read(b []byte) (int, error) {
 	n, err := reader.Read(b)
 	return n, err
 }
-
-func (r *LazyReader) Close() error {
-	r.readerSync.Lock()
-	defer r.readerSync.Unlock()
-
-	var err error
-	if r.reader != nil {
-		err = r.reader.Close()
-		r.reader = nil
-		r.readerError = errors.New("closed reader")
-	}
-
-	return err
-}

+ 11 - 2
transport/internet/splithttp/splithttp_test.go

@@ -248,6 +248,8 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 			NextProtocol:  []string{"h3"},
 		},
 	}
+
+	serverClosed := false
 	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
 		go func() {
 			defer conn.Close()
@@ -258,10 +260,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 			for {
 				b.Clear()
 				if _, err := b.ReadFrom(conn); err != nil {
-					return
+					break
 				}
 				common.Must2(conn.Write(b.Bytes()))
 			}
+
+			serverClosed = true
 		}()
 	})
 	common.Must(err)
@@ -271,7 +275,6 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 
 	conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
 	common.Must(err)
-	defer conn.Close()
 
 	const N = 1024
 	b1 := make([]byte, N)
@@ -294,6 +297,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 		t.Error(r)
 	}
 
+	conn.Close()
+	time.Sleep(100 * time.Millisecond)
+	if !serverClosed {
+		t.Error("server did not get closed")
+	}
+
 	end := time.Now()
 	if !end.Before(start.Add(time.Second * 5)) {
 		t.Error("end: ", end, " start: ", start)

+ 48 - 0
transport/internet/splithttp/strip_ok_reader.go

@@ -0,0 +1,48 @@
+package splithttp
+
+import (
+	"bytes"
+	"io"
+
+	"github.com/xtls/xray-core/common/errors"
+)
+
+// in older versions of splithttp, the server would respond with `ok` to flush
+// out HTTP response headers early. Response headers and a 200 OK were required
+// to initiate the connection. Later versions of splithttp dropped this
+// requirement, and in xray 1.8.24 the server stopped sending "ok" if it sees
+// x_padding. For compatibility, we need to remove "ok" from the underlying
+// reader if it exists, and otherwise forward the stream as-is.
+type stripOkReader struct {
+	io.ReadCloser
+	firstDone  bool
+	prefixRead []byte
+}
+
+func (r *stripOkReader) Read(b []byte) (int, error) {
+	if !r.firstDone {
+		r.firstDone = true
+
+		// skip "ok" response
+		prefixRead := []byte{0, 0}
+		_, err := io.ReadFull(r.ReadCloser, prefixRead)
+		if err != nil {
+			return 0, errors.New("failed to read initial response").Base(err)
+		}
+
+		if !bytes.Equal(prefixRead, []byte("ok")) {
+			// we read some garbage byte that may not have been "ok" at
+			// all. return a reader that replays what we have read so far
+			r.prefixRead = prefixRead
+		}
+	}
+
+	if len(r.prefixRead) > 0 {
+		n := copy(b, r.prefixRead)
+		r.prefixRead = r.prefixRead[n:]
+		return n, nil
+	}
+
+	n, err := r.ReadCloser.Read(b)
+	return n, err
+}

+ 4 - 2
transport/internet/splithttp/upload_queue.go

@@ -51,8 +51,10 @@ func (h *uploadQueue) Close() error {
 	h.writeCloseMutex.Lock()
 	defer h.writeCloseMutex.Unlock()
 
-	h.closed = true
-	close(h.pushedPackets)
+	if !h.closed {
+		h.closed = true
+		close(h.pushedPackets)
+	}
 	return nil
 }