浏览代码

Fix SplitHTTP Unix domain socket (#3577)

Co-authored-by: mmmray <[email protected]>
hellokindle 1 年之前
父节点
当前提交
edae38c620

+ 1 - 1
transport/internet/splithttp/client.go

@@ -117,10 +117,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
 
 func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
 	req, err := http.NewRequest("POST", url, payload)
-	req.ContentLength = contentLength
 	if err != nil {
 		return err
 	}
+	req.ContentLength = contentLength
 	req.Header = c.transportConfig.GetRequestHeader()
 
 	if c.isH2 || c.isH3 {

+ 7 - 8
transport/internet/splithttp/hub.go

@@ -314,14 +314,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
 		}
 		errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
-
-		// h2cHandler can handle both plaintext HTTP/1.1 and h2c
-		h2cHandler := h2c.NewHandler(handler, &http2.Server{})
-		l.server = http.Server{
-			Handler:           h2cHandler,
-			ReadHeaderTimeout: time.Second * 4,
-			MaxHeaderBytes:    8192,
-		}
 	}
 
 	// tcp/unix (h1/h2)
@@ -332,7 +324,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			}
 		}
 
+		// h2cHandler can handle both plaintext HTTP/1.1 and h2c
+		h2cHandler := h2c.NewHandler(handler, &http2.Server{})
 		l.listener = listener
+		l.server = http.Server{
+			Handler:           h2cHandler,
+			ReadHeaderTimeout: time.Second * 4,
+			MaxHeaderBytes:    8192,
+		}
 
 		go func() {
 			if err := l.server.Serve(l.listener); err != nil {

+ 62 - 0
transport/internet/splithttp/splithttp_test.go

@@ -298,3 +298,65 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 		t.Error("end: ", end, " start: ", start)
 	}
 }
+
+func Test_listenSHAndDial_Unix(t *testing.T) {
+	tempDir := t.TempDir()
+	tempSocket := tempDir + "/server.sock"
+
+	listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
+		ProtocolName: "splithttp",
+		ProtocolSettings: &Config{
+			Path: "/sh",
+		},
+	}, func(conn stat.Connection) {
+		go func(c stat.Connection) {
+			defer c.Close()
+
+			var b [1024]byte
+			c.SetReadDeadline(time.Now().Add(2 * time.Second))
+			_, err := c.Read(b[:])
+			if err != nil {
+				return
+			}
+
+			common.Must2(c.Write([]byte("Response")))
+		}(conn)
+	})
+	common.Must(err)
+	ctx := context.Background()
+	streamSettings := &internet.MemoryStreamConfig{
+		ProtocolName: "splithttp",
+		ProtocolSettings: &Config{
+			Host: "example.com",
+			Path: "sh",
+		},
+	}
+	conn, err := Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
+
+	common.Must(err)
+	_, err = conn.Write([]byte("Test connection 1"))
+	common.Must(err)
+
+	var b [1024]byte
+	fmt.Println("test2")
+	n, _ := conn.Read(b[:])
+	fmt.Println("string is", n)
+	if string(b[:n]) != "Response" {
+		t.Error("response: ", string(b[:n]))
+	}
+
+	common.Must(conn.Close())
+	conn, err = Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
+
+	common.Must(err)
+	_, err = conn.Write([]byte("Test connection 2"))
+	common.Must(err)
+	n, _ = conn.Read(b[:])
+	common.Must(err)
+	if string(b[:n]) != "Response" {
+		t.Error("response: ", string(b[:n]))
+	}
+	common.Must(conn.Close())
+
+	common.Must(listen.Close())
+}