Browse Source

Fix naive padding

世界 3 years ago
parent
commit
f87baf08d3
2 changed files with 10 additions and 6 deletions
  1. 9 6
      inbound/naive.go
  2. 1 0
      test/box_test.go

+ 9 - 6
inbound/naive.go

@@ -162,6 +162,7 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
 		conn, _, err := hijacker.Hijack()
 		if err != nil {
+			n.badRequest(ctx, request, E.New("hijack failed"))
 			return
 		}
 		n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination)
@@ -245,7 +246,7 @@ func (c *naiveH1Conn) read(p []byte) (n int, err error) {
 		if err != nil {
 			return
 		}
-		c.readRemaining = 0
+		c.paddingRemaining = 0
 	}
 	if c.readPadding < kFirstPaddings {
 		paddingHdr := p[:3]
@@ -352,14 +353,15 @@ func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
 	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
 }
 
-func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
+// FIXME
+/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
 	if c.readPadding < kFirstPaddings {
 		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
 	} else {
 		n, err = bufio.Copy(w, c.Conn)
 	}
 	return n, wrapHttpError(err)
-}
+}*/
 
 func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
 	if c.writePadding < kFirstPaddings {
@@ -415,7 +417,7 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) {
 		if err != nil {
 			return
 		}
-		c.readRemaining = 0
+		c.paddingRemaining = 0
 	}
 	if c.readPadding < kFirstPaddings {
 		paddingHdr := p[:3]
@@ -529,14 +531,15 @@ func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
 	return wrapHttpError(err)
 }
 
-func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
+// FIXME
+/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
 	if c.readPadding < kFirstPaddings {
 		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
 	} else {
 		n, err = bufio.Copy(w, c.reader)
 	}
 	return n, wrapHttpError(err)
-}
+}*/
 
 func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
 	if c.writePadding < kFirstPaddings {

+ 1 - 0
test/box_test.go

@@ -57,6 +57,7 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) {
 		return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
 	}
 	require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
+	require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
 }
 
 func testSuitQUIC(t *testing.T, clientPort uint16, testPort uint16) {