Procházet zdrojové kódy

lib/protocol: Alwasy return buffers to the pool (#7409)

Simon Frei před 4 roky
rodič
revize
0ffd80f380
1 změnil soubory, kde provedl 19 přidání a 5 odebrání
  1. 19 5
      lib/protocol/protocol.go

+ 19 - 5
lib/protocol/protocol.go

@@ -536,6 +536,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (
 
 	buf := BufferPool.Get(int(msgLen))
 	if _, err := io.ReadFull(c.cr, buf); err != nil {
+		BufferPool.Put(buf)
 		return nil, errors.Wrap(err, "reading message")
 	}
 
@@ -561,9 +562,11 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (
 
 	msg, err := c.newMessage(hdr.Type)
 	if err != nil {
+		BufferPool.Put(buf)
 		return nil, err
 	}
 	if err := msg.Unmarshal(buf); err != nil {
+		BufferPool.Put(buf)
 		return nil, errors.Wrap(err, "unmarshalling message")
 	}
 	BufferPool.Put(buf)
@@ -586,15 +589,17 @@ func (c *rawConnection) readHeader(fourByteBuf []byte) (Header, error) {
 
 	buf := BufferPool.Get(int(hdrLen))
 	if _, err := io.ReadFull(c.cr, buf); err != nil {
+		BufferPool.Put(buf)
 		return Header{}, errors.Wrap(err, "reading header")
 	}
 
 	var hdr Header
-	if err := hdr.Unmarshal(buf); err != nil {
+	err := hdr.Unmarshal(buf)
+	BufferPool.Put(buf)
+	if err != nil {
 		return Header{}, errors.Wrap(err, "unmarshalling header")
 	}
 
-	BufferPool.Put(buf)
 	return hdr, nil
 }
 
@@ -767,11 +772,13 @@ func (c *rawConnection) writeCompressedMessage(msg message) error {
 	size := msg.ProtoSize()
 	buf := BufferPool.Get(size)
 	if _, err := msg.MarshalTo(buf); err != nil {
+		BufferPool.Put(buf)
 		return errors.Wrap(err, "marshalling message")
 	}
 
 	compressed, err := c.lz4Compress(buf)
 	if err != nil {
+		BufferPool.Put(buf)
 		return errors.Wrap(err, "compressing message")
 	}
 
@@ -784,17 +791,20 @@ func (c *rawConnection) writeCompressedMessage(msg message) error {
 		panic("impossibly large header")
 	}
 
-	totSize := 2 + hdrSize + 4 + len(compressed)
+	compressedSize := len(compressed)
+	totSize := 2 + hdrSize + 4 + compressedSize
 	buf = BufferPool.Upgrade(buf, totSize)
 
 	// Header length
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
 	// Header
 	if _, err := hdr.MarshalTo(buf[2:]); err != nil {
+		BufferPool.Put(buf)
+		BufferPool.Put(compressed)
 		return errors.Wrap(err, "marshalling header")
 	}
 	// Message length
-	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(len(compressed)))
+	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(compressedSize))
 	// Message
 	copy(buf[2+hdrSize+4:], compressed)
 	BufferPool.Put(compressed)
@@ -802,7 +812,7 @@ func (c *rawConnection) writeCompressedMessage(msg message) error {
 	n, err := c.cw.Write(buf)
 	BufferPool.Put(buf)
 
-	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, len(compressed), size, err)
+	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, size, err)
 	if err != nil {
 		return errors.Wrap(err, "writing message")
 	}
@@ -827,12 +837,14 @@ func (c *rawConnection) writeUncompressedMessage(msg message) error {
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
 	// Header
 	if _, err := hdr.MarshalTo(buf[2:]); err != nil {
+		BufferPool.Put(buf)
 		return errors.Wrap(err, "marshalling header")
 	}
 	// Message length
 	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size))
 	// Message
 	if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
+		BufferPool.Put(buf)
 		return errors.Wrap(err, "marshalling message")
 	}
 
@@ -1033,6 +1045,7 @@ func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) {
 	buf := BufferPool.Get(lz4.CompressBound(len(src)))
 	compressed, err := lz4.Encode(buf, src)
 	if err != nil {
+		BufferPool.Put(buf)
 		return nil, err
 	}
 	if &compressed[0] != &buf[0] {
@@ -1050,6 +1063,7 @@ func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) {
 	buf := BufferPool.Get(int(size))
 	decoded, err := lz4.Decode(buf, src)
 	if err != nil {
+		BufferPool.Put(buf)
 		return nil, err
 	}
 	if &decoded[0] != &buf[0] {