1
0
Эх сурвалжийг харах

lib/protocol: Write uncompressible messages uncompressed (#7790)

greatroar 4 жил өмнө
parent
commit
bd363fe0b7

+ 46 - 58
lib/protocol/protocol.go

@@ -523,7 +523,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (
 		// Nothing
 
 	case MessageCompressionLZ4:
-		decomp, err := c.lz4Decompress(buf)
+		decomp, err := lz4Decompress(buf)
 		BufferPool.Put(buf)
 		if err != nil {
 			return nil, errors.Wrap(err, "decompressing message")
@@ -740,100 +740,91 @@ func (c *rawConnection) writerLoop() {
 func (c *rawConnection) writeMessage(msg message) error {
 	msgContext, _ := messageContext(msg)
 	l.Debugf("Writing %v", msgContext)
-	if c.shouldCompressMessage(msg) {
-		return c.writeCompressedMessage(msg)
-	}
-	return c.writeUncompressedMessage(msg)
-}
 
-func (c *rawConnection) writeCompressedMessage(msg message) error {
 	size := msg.ProtoSize()
-	buf := BufferPool.Get(size)
-	if _, err := msg.MarshalTo(buf); err != nil {
-		BufferPool.Put(buf)
-		return errors.Wrap(err, "marshalling message")
-	}
-
-	compressed, err := c.lz4Compress(buf)
-	if err != nil {
-		BufferPool.Put(buf)
-		return errors.Wrap(err, "compressing message")
-	}
-
 	hdr := Header{
-		Type:        c.typeOf(msg),
-		Compression: MessageCompressionLZ4,
+		Type: c.typeOf(msg),
 	}
 	hdrSize := hdr.ProtoSize()
 	if hdrSize > 1<<16-1 {
 		panic("impossibly large header")
 	}
 
-	compressedSize := len(compressed)
-	totSize := 2 + hdrSize + 4 + compressedSize
-	buf = BufferPool.Upgrade(buf, totSize)
+	overhead := 2 + hdrSize + 4
+	totSize := overhead + size
+	buf := BufferPool.Get(totSize)
+	defer BufferPool.Put(buf)
+
+	// Message
+	if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
+		return errors.Wrap(err, "marshalling message")
+	}
+
+	if c.shouldCompressMessage(msg) {
+		ok, err := c.writeCompressedMessage(msg, buf[overhead:], overhead)
+		if ok {
+			return err
+		}
+	}
 
 	// Header length
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
 	// Header
 	if _, err := hdr.MarshalTo(buf[2:]); err != nil {
-		BufferPool.Put(buf)
-		BufferPool.Put(compressed)
 		return errors.Wrap(err, "marshalling header")
 	}
 	// Message length
-	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(compressedSize))
-	// Message
-	copy(buf[2+hdrSize+4:], compressed)
-	BufferPool.Put(compressed)
+	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size))
 
 	n, err := c.cw.Write(buf)
-	BufferPool.Put(buf)
 
-	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, size, err)
+	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err)
 	if err != nil {
 		return errors.Wrap(err, "writing message")
 	}
 	return nil
 }
 
-func (c *rawConnection) writeUncompressedMessage(msg message) error {
-	size := msg.ProtoSize()
-
+// Write msg out compressed, given its uncompressed marshaled payload and overhead.
+//
+// The first return value indicates whether compression succeeded.
+// If not, the caller should retry without compression.
+func (c *rawConnection) writeCompressedMessage(msg message, marshaled []byte, overhead int) (ok bool, err error) {
 	hdr := Header{
-		Type: c.typeOf(msg),
+		Type:        c.typeOf(msg),
+		Compression: MessageCompressionLZ4,
 	}
 	hdrSize := hdr.ProtoSize()
 	if hdrSize > 1<<16-1 {
 		panic("impossibly large header")
 	}
 
-	totSize := 2 + hdrSize + 4 + size
-	buf := BufferPool.Get(totSize)
+	cOverhead := 2 + hdrSize + 4
+	maxCompressed := cOverhead + lz4.CompressBound(len(marshaled))
+	buf := BufferPool.Get(maxCompressed)
+	defer BufferPool.Put(buf)
+
+	compressedSize, err := lz4Compress(marshaled, buf[cOverhead:])
+	totSize := compressedSize + cOverhead
+	if err != nil || totSize >= len(marshaled)+overhead {
+		return false, nil
+	}
 
 	// Header length
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
 	// Header
 	if _, err := hdr.MarshalTo(buf[2:]); err != nil {
-		BufferPool.Put(buf)
-		return errors.Wrap(err, "marshalling header")
+		return true, errors.Wrap(err, "marshalling header")
 	}
 	// Message length
-	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size))
-	// Message
-	if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
-		BufferPool.Put(buf)
-		return errors.Wrap(err, "marshalling message")
-	}
+	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(compressedSize))
 
 	n, err := c.cw.Write(buf[:totSize])
-	BufferPool.Put(buf)
-
-	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err)
+	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, len(marshaled), err)
 	if err != nil {
-		return errors.Wrap(err, "writing message")
+		return true, errors.Wrap(err, "writing message")
 	}
-	return nil
+	return true, nil
 }
 
 func (c *rawConnection) typeOf(msg message) MessageType {
@@ -1018,23 +1009,20 @@ func (c *rawConnection) Statistics() Statistics {
 	}
 }
 
-func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) {
-	var err error
-	buf := BufferPool.Get(lz4.CompressBound(len(src)))
+func lz4Compress(src, buf []byte) (int, error) {
 	compressed, err := lz4.Encode(buf, src)
 	if err != nil {
-		BufferPool.Put(buf)
-		return nil, err
+		return -1, err
 	}
 	if &compressed[0] != &buf[0] {
 		panic("bug: lz4.Compress allocated, which it must not (should use buffer pool)")
 	}
 
 	binary.BigEndian.PutUint32(compressed, binary.LittleEndian.Uint32(compressed))
-	return compressed, nil
+	return len(compressed), nil
 }
 
-func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) {
+func lz4Decompress(src []byte) ([]byte, error) {
 	size := binary.BigEndian.Uint32(src)
 	binary.LittleEndian.PutUint32(src, size)
 	var err error

+ 40 - 36
lib/protocol/protocol_test.go

@@ -17,6 +17,7 @@ import (
 	"testing/quick"
 	"time"
 
+	lz4 "github.com/bkaradzic/go-lz4"
 	"github.com/syncthing/syncthing/lib/rand"
 	"github.com/syncthing/syncthing/lib/testutils"
 )
@@ -439,9 +440,42 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
 	return true
 }
 
-func TestLZ4Compression(t *testing.T) {
-	c := new(rawConnection)
+func TestWriteCompressed(t *testing.T) {
+	for _, random := range []bool{false, true} {
+		buf := new(bytes.Buffer)
+		c := &rawConnection{
+			cr:          &countingReader{Reader: buf},
+			cw:          &countingWriter{Writer: buf},
+			compression: CompressionAlways,
+		}
+
+		msg := &Response{Data: make([]byte, 10240)}
+		if random {
+			// This should make the message uncompressible.
+			rand.Read(msg.Data)
+		}
+
+		if err := c.writeMessage(msg); err != nil {
+			t.Fatal(err)
+		}
+		got, err := c.readMessage(make([]byte, 4))
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(got.(*Response).Data, msg.Data) {
+			t.Error("received the wrong message")
+		}
+
+		hdr := Header{Type: c.typeOf(msg)}
+		size := int64(2 + hdr.ProtoSize() + 4 + msg.ProtoSize())
+		if c.cr.tot > size {
+			t.Errorf("compression enlarged message from %d to %d",
+				size, c.cr.tot)
+		}
+	}
+}
 
+func TestLZ4Compression(t *testing.T) {
 	for i := 0; i < 10; i++ {
 		dataLen := 150 + rand.Intn(150)
 		data := make([]byte, dataLen)
@@ -449,13 +483,15 @@ func TestLZ4Compression(t *testing.T) {
 		if err != nil {
 			t.Fatal(err)
 		}
-		comp, err := c.lz4Compress(data)
+
+		comp := make([]byte, lz4.CompressBound(dataLen))
+		compLen, err := lz4Compress(data, comp)
 		if err != nil {
 			t.Errorf("compressing %d bytes: %v", dataLen, err)
 			continue
 		}
 
-		res, err := c.lz4Decompress(comp)
+		res, err := lz4Decompress(comp[:compLen])
 		if err != nil {
 			t.Errorf("decompressing %d bytes to %d: %v", len(comp), dataLen, err)
 			continue
@@ -470,38 +506,6 @@ func TestLZ4Compression(t *testing.T) {
 	}
 }
 
-func TestStressLZ4CompressGrows(t *testing.T) {
-	c := new(rawConnection)
-	success := 0
-	for i := 0; i < 100; i++ {
-		// Create a slize that is precisely one min block size, fill it with
-		// random data. This shouldn't compress at all, so will in fact
-		// become larger when LZ4 does its thing.
-		data := make([]byte, MinBlockSize)
-		if _, err := rand.Reader.Read(data); err != nil {
-			t.Fatal("randomness failure")
-		}
-
-		comp, err := c.lz4Compress(data)
-		if err != nil {
-			t.Fatal("unexpected compression error: ", err)
-		}
-		if len(comp) < len(data) {
-			// data size should grow. We must have been really unlucky in
-			// the random generation, try again.
-			continue
-		}
-
-		// Putting it into the buffer pool shouldn't panic because the block
-		// should come from there to begin with.
-		BufferPool.Put(comp)
-		success++
-	}
-	if success == 0 {
-		t.Fatal("unable to find data that grows when compressed")
-	}
-}
-
 func TestCheckFilename(t *testing.T) {
 	cases := []struct {
 		name string