瀏覽代碼

SplitHTTP: Read and validate HTTP/1.1 responses (#3797)

Dmitry Anderson 1 年之前
父節點
當前提交
a931507dd6

+ 3 - 0
.gitignore

@@ -28,3 +28,6 @@ errorgen
 *.dat
 .vscode
 /build_assets
+
+# Output from dlv test
+**/debug.*

+ 22 - 5
transport/internet/splithttp/client.go

@@ -3,6 +3,7 @@ package splithttp
 import (
 	"bytes"
 	"context"
+	"fmt"
 	"io"
 	gonet "net"
 	"net/http"
@@ -152,23 +153,39 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,
 		// safely retried. if instead req.Write is called multiple
 		// times, the body is already drained after the first
 		// request
-		requestBytes := new(bytes.Buffer)
-		common.Must(req.Write(requestBytes))
+		requestBuff := new(bytes.Buffer)
+		common.Must(req.Write(requestBuff))
 
 		var uploadConn any
+		var h1UploadConn *H1Conn
 
 		for {
 			uploadConn = c.uploadRawPool.Get()
 			newConnection := uploadConn == nil
 			if newConnection {
-				uploadConn, err = c.dialUploadConn(context.WithoutCancel(ctx))
+				newConn, err := c.dialUploadConn(context.WithoutCancel(ctx))
 				if err != nil {
 					return err
 				}
+				h1UploadConn = NewH1Conn(newConn)
+				uploadConn = h1UploadConn
+			} else {
+				h1UploadConn = uploadConn.(*H1Conn)
+
+				// TODO: Replace 0 here with a config value later
+				// Or add some other condition for optimization purposes
+				if h1UploadConn.UnreadedResponsesCount > 0 {
+					resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req)
+					if err != nil {
+						return fmt.Errorf("error while reading response: %s", err.Error())
+					}
+					if resp.StatusCode != 200 {
+						return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode)
+					}
+				}
 			}
 
-			_, err = uploadConn.(net.Conn).Write(requestBytes.Bytes())
-
+			_, err := h1UploadConn.Write(requestBuff.Bytes())
 			// if the write failed, we try another connection from
 			// the pool, until the write on a new connection fails.
 			// failed writes to a pooled connection are normal when

+ 0 - 1
transport/internet/splithttp/dialer.go

@@ -267,7 +267,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 					&buf.MultiBufferContainer{MultiBuffer: chunk},
 					int64(chunk.Len()),
 				)
-
 				if err != nil {
 					errors.LogInfoInner(ctx, err, "failed to send upload")
 					uploadPipeReader.Interrupt()

+ 19 - 0
transport/internet/splithttp/h1_conn.go

@@ -0,0 +1,19 @@
+package splithttp
+
+import (
+	"bufio"
+	"net"
+)
+
+type H1Conn struct {
+	UnreadedResponsesCount int
+	RespBufReader          *bufio.Reader
+	net.Conn
+}
+
+func NewH1Conn(conn net.Conn) *H1Conn {
+	return &H1Conn{
+		RespBufReader: bufio.NewReader(conn),
+		Conn:          conn,
+	}
+}