Bladeren bron

Use v2 of XDR package (actual changes)

Jakob Borg 9 jaren geleden
bovenliggende
commit
e1ac740ac4
5 gewijzigde bestanden met toevoegingen van 108 en 68 verwijderingen
  1. 10 2
      lib/db/leveldb_dbinstance.go
  2. 8 7
      lib/protocol/header.go
  3. 38 21
      lib/protocol/protocol.go
  4. 32 29
      lib/protocol/protocol_test.go
  5. 20 9
      lib/relay/protocol/protocol.go

+ 10 - 2
lib/db/leveldb_dbinstance.go

@@ -267,7 +267,11 @@ func (db *Instance) withHave(folder, device []byte, truncate bool, fn Iterator)
 	defer dbi.Release()
 
 	for dbi.Next() {
-		f, err := unmarshalTrunc(dbi.Value(), truncate)
+		// The iterator function may keep a reference to the unmarshalled
+		// struct, which in turn references the buffer it was unmarshalled
+		// from. dbi.Value() just returns an internal slice that it reuses, so
+		// we need to copy it.
+		f, err := unmarshalTrunc(append([]byte{}, dbi.Value()...), truncate)
 		if err != nil {
 			panic(err)
 		}
@@ -287,7 +291,11 @@ func (db *Instance) withAllFolderTruncated(folder []byte, fn func(device []byte,
 	for dbi.Next() {
 		device := db.deviceKeyDevice(dbi.Key())
 		var f FileInfoTruncated
-		err := f.UnmarshalXDR(dbi.Value())
+		// The iterator function may keep a reference to the unmarshalled
+		// struct, which in turn references the buffer it was unmarshalled
+		// from. dbi.Value() just returns an internal slice that it reuses, so
+		// we need to copy it.
+		err := f.UnmarshalXDR(append([]byte{}, dbi.Value()...))
 		if err != nil {
 			panic(err)
 		}

+ 8 - 7
lib/protocol/header.go

@@ -11,15 +11,16 @@ type header struct {
 	compression bool
 }
 
-func (h header) encodeXDR(xw *xdr.Writer) (int, error) {
-	u := encodeHeader(h)
-	return xw.WriteUint32(u)
+func (h header) MarshalXDRInto(m *xdr.Marshaller) error {
+	v := encodeHeader(h)
+	m.MarshalUint32(v)
+	return m.Error
 }
 
-func (h *header) decodeXDR(xr *xdr.Reader) error {
-	u := xr.ReadUint32()
-	*h = decodeHeader(u)
-	return xr.Error()
+func (h *header) UnmarshalXDRFrom(u *xdr.Unmarshaller) error {
+	v := u.UnmarshalUint32()
+	*h = decodeHeader(v)
+	return u.Error
 }
 
 func encodeHeader(h header) uint32 {

+ 38 - 21
lib/protocol/protocol.go

@@ -12,6 +12,7 @@ import (
 	"time"
 
 	lz4 "github.com/bkaradzic/go-lz4"
+	"github.com/calmh/xdr"
 )
 
 const (
@@ -130,8 +131,7 @@ type rawConnection struct {
 	pool        sync.Pool
 	compression Compression
 
-	rdbuf0 []byte // used & reused by readMessage
-	rdbuf1 []byte // used & reused by readMessage
+	readerBuf []byte // used & reused by readMessage
 }
 
 type asyncResult struct {
@@ -146,7 +146,8 @@ type hdrMsg struct {
 }
 
 type encodable interface {
-	AppendXDR([]byte) ([]byte, error)
+	MarshalXDRInto(m *xdr.Marshaller) error
+	XDRSize() int
 }
 
 type isEofer interface {
@@ -374,18 +375,14 @@ func (c *rawConnection) readerLoop() (err error) {
 }
 
 func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) {
-	if cap(c.rdbuf0) < 8 {
-		c.rdbuf0 = make([]byte, 8)
-	} else {
-		c.rdbuf0 = c.rdbuf0[:8]
-	}
-	_, err = io.ReadFull(c.cr, c.rdbuf0)
+	hdrBuf := make([]byte, 8)
+	_, err = io.ReadFull(c.cr, hdrBuf)
 	if err != nil {
 		return
 	}
 
-	hdr = decodeHeader(binary.BigEndian.Uint32(c.rdbuf0[0:4]))
-	msglen := int(binary.BigEndian.Uint32(c.rdbuf0[4:8]))
+	hdr = decodeHeader(binary.BigEndian.Uint32(hdrBuf[:4]))
+	msglen := int(binary.BigEndian.Uint32(hdrBuf[4:]))
 
 	l.Debugf("read header %v (msglen=%d)", hdr, msglen)
 
@@ -399,27 +396,40 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) {
 		return
 	}
 
-	if cap(c.rdbuf0) < msglen {
-		c.rdbuf0 = make([]byte, msglen)
+	// c.readerBuf contains a buffer we can reuse. But once we've unmarshalled
+	// a message from the buffer we can't reuse it again as the unmarshalled
+	// message refers to the contents of the buffer. The only case we a buffer
+	// ends up in readerBuf for reuse is when the message is compressed, as we
+	// then decompress into a new buffer instead.
+
+	var msgBuf []byte
+	if cap(c.readerBuf) >= msglen {
+		// If we have a buffer ready in rdbuf we just use that.
+		msgBuf = c.readerBuf[:msglen]
 	} else {
-		c.rdbuf0 = c.rdbuf0[:msglen]
+		// Otherwise we allocate a new buffer.
+		msgBuf = make([]byte, msglen)
 	}
-	_, err = io.ReadFull(c.cr, c.rdbuf0)
+
+	_, err = io.ReadFull(c.cr, msgBuf)
 	if err != nil {
 		return
 	}
 
-	l.Debugf("read %d bytes", len(c.rdbuf0))
+	l.Debugf("read %d bytes", len(msgBuf))
 
-	msgBuf := c.rdbuf0
 	if hdr.compression && msglen > 0 {
-		c.rdbuf1 = c.rdbuf1[:cap(c.rdbuf1)]
-		c.rdbuf1, err = lz4.Decode(c.rdbuf1, c.rdbuf0)
+		// We're going to decompress msgBuf into a different newly allocated
+		// buffer, so keep msgBuf around for reuse on the next message.
+		c.readerBuf = msgBuf
+
+		msgBuf, err = lz4.Decode(nil, msgBuf)
 		if err != nil {
 			return
 		}
-		msgBuf = c.rdbuf1
 		l.Debugf("decompressed to %d bytes", len(msgBuf))
+	} else {
+		c.readerBuf = nil
 	}
 
 	if shouldDebug() {
@@ -601,7 +611,14 @@ func (c *rawConnection) writerLoop() {
 		case hm := <-c.outbox:
 			if hm.msg != nil {
 				// Uncompressed message in uncBuf
-				uncBuf, err = hm.msg.AppendXDR(uncBuf[:0])
+				msgLen := hm.msg.XDRSize()
+				if cap(uncBuf) >= msgLen {
+					uncBuf = uncBuf[:msgLen]
+				} else {
+					uncBuf = make([]byte, msgLen)
+				}
+				m := &xdr.Marshaller{Data: uncBuf}
+				err = hm.msg.MarshalXDRInto(m)
 				if hm.done != nil {
 					close(hm.done)
 				}

+ 32 - 29
lib/protocol/protocol_test.go

@@ -3,7 +3,6 @@
 package protocol
 
 import (
-	"bytes"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/json"
@@ -55,14 +54,13 @@ func TestHeaderMarshalUnmarshal(t *testing.T) {
 		ver = int(uint(ver) % 16)
 		id = int(uint(id) % 4096)
 		typ = int(uint(typ) % 256)
-		buf := new(bytes.Buffer)
-		xw := xdr.NewWriter(buf)
+		buf := make([]byte, 4)
+
 		h0 := header{version: ver, msgID: id, msgType: typ}
-		h0.encodeXDR(xw)
+		h0.MarshalXDRInto(&xdr.Marshaller{Data: buf})
 
-		xr := xdr.NewReader(buf)
 		var h1 header
-		h1.decodeXDR(xr)
+		h1.UnmarshalXDRFrom(&xdr.Unmarshaller{Data: buf})
 		return h0 == h1
 	}
 	if err := quick.Check(f, nil); err != nil {
@@ -128,8 +126,7 @@ func TestVersionErr(t *testing.T) {
 	c0.ClusterConfig(ClusterConfigMessage{})
 	c1.ClusterConfig(ClusterConfigMessage{})
 
-	w := xdr.NewWriter(c0.cw)
-	timeoutWriteHeader(w, header{
+	timeoutWriteHeader(c0.cw, header{
 		version: 2, // higher than supported
 		msgID:   0,
 		msgType: messageTypeIndex,
@@ -154,8 +151,7 @@ func TestTypeErr(t *testing.T) {
 	c0.ClusterConfig(ClusterConfigMessage{})
 	c1.ClusterConfig(ClusterConfigMessage{})
 
-	w := xdr.NewWriter(c0.cw)
-	timeoutWriteHeader(w, header{
+	timeoutWriteHeader(c0.cw, header{
 		version: 0,
 		msgID:   0,
 		msgType: 42, // unknown type
@@ -205,7 +201,7 @@ func TestElementSizeExceededNested(t *testing.T) {
 	m := ClusterConfigMessage{
 		ClientName: "longstringlongstringlongstringinglongstringlongstringlonlongstringlongstringlon",
 	}
-	_, err := m.EncodeXDR(ioutil.Discard)
+	_, err := m.MarshalXDR()
 	if err == nil {
 		t.Errorf("ID length %d > max 64, but no error", len(m.Folders[0].ID))
 	}
@@ -213,12 +209,19 @@ func TestElementSizeExceededNested(t *testing.T) {
 
 func TestMarshalIndexMessage(t *testing.T) {
 	f := func(m1 IndexMessage) bool {
+		if len(m1.Options) == 0 {
+			m1.Options = nil
+		}
 		for i, f := range m1.Files {
 			m1.Files[i].CachedSize = 0
-			for j := range f.Blocks {
-				f.Blocks[j].Offset = 0
-				if len(f.Blocks[j].Hash) == 0 {
-					f.Blocks[j].Hash = nil
+			if len(f.Blocks) == 0 {
+				m1.Files[i].Blocks = nil
+			} else {
+				for j := range f.Blocks {
+					f.Blocks[j].Offset = 0
+					if len(f.Blocks[j].Hash) == 0 {
+						f.Blocks[j].Hash = nil
+					}
 				}
 			}
 		}
@@ -233,6 +236,9 @@ func TestMarshalIndexMessage(t *testing.T) {
 
 func TestMarshalRequestMessage(t *testing.T) {
 	f := func(m1 RequestMessage) bool {
+		if len(m1.Options) == 0 {
+			m1.Options = nil
+		}
 		return testMarshal(t, "request", &m1, &RequestMessage{})
 	}
 
@@ -256,6 +262,9 @@ func TestMarshalResponseMessage(t *testing.T) {
 
 func TestMarshalClusterConfigMessage(t *testing.T) {
 	f := func(m1 ClusterConfigMessage) bool {
+		if len(m1.Options) == 0 {
+			m1.Options = nil
+		}
 		return testMarshal(t, "clusterconfig", &m1, &ClusterConfigMessage{})
 	}
 
@@ -275,13 +284,11 @@ func TestMarshalCloseMessage(t *testing.T) {
 }
 
 type message interface {
-	EncodeXDR(io.Writer) (int, error)
-	DecodeXDR(io.Reader) error
+	MarshalXDR() ([]byte, error)
+	UnmarshalXDR([]byte) error
 }
 
 func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
-	var buf bytes.Buffer
-
 	failed := func(bc []byte) {
 		bs, _ := json.MarshalIndent(m1, "", "  ")
 		ioutil.WriteFile(prefix+"-1.txt", bs, 0644)
@@ -294,7 +301,7 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
 		}
 	}
 
-	_, err := m1.EncodeXDR(&buf)
+	buf, err := m1.MarshalXDR()
 	if err != nil && strings.Contains(err.Error(), "exceeds size") {
 		return true
 	}
@@ -303,23 +310,20 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
 		t.Fatal(err)
 	}
 
-	bc := make([]byte, len(buf.Bytes()))
-	copy(bc, buf.Bytes())
-
-	err = m2.DecodeXDR(&buf)
+	err = m2.UnmarshalXDR(buf)
 	if err != nil {
-		failed(bc)
+		failed(buf)
 		t.Fatal(err)
 	}
 
 	ok := reflect.DeepEqual(m1, m2)
 	if !ok {
-		failed(bc)
+		failed(buf)
 	}
 	return ok
 }
 
-func timeoutWriteHeader(w *xdr.Writer, hdr header) {
+func timeoutWriteHeader(w io.Writer, hdr header) {
 	// This tries to write a message header to w, but times out after a while.
 	// This is useful because in testing, with a PipeWriter, it will block
 	// forever if the other side isn't reading any more. On the other hand we
@@ -332,8 +336,7 @@ func timeoutWriteHeader(w *xdr.Writer, hdr header) {
 
 	done := make(chan struct{})
 	go func() {
-		w.WriteRaw(buf[:])
-		l.Infoln("write completed")
+		w.Write(buf[:])
 		close(done)
 	}()
 	select {

+ 20 - 9
lib/relay/protocol/protocol.go

@@ -74,7 +74,13 @@ func WriteMessage(w io.Writer, message interface{}) error {
 
 func ReadMessage(r io.Reader) (interface{}, error) {
 	var header header
-	if err := header.DecodeXDR(r); err != nil {
+
+	buf := make([]byte, header.XDRSize())
+	if _, err := io.ReadFull(r, buf); err != nil {
+		return nil, err
+	}
+
+	if err := header.UnmarshalXDR(buf); err != nil {
 		return nil, err
 	}
 
@@ -82,38 +88,43 @@ func ReadMessage(r io.Reader) (interface{}, error) {
 		return nil, fmt.Errorf("magic mismatch")
 	}
 
+	buf = make([]byte, int(header.messageLength))
+	if _, err := io.ReadFull(r, buf); err != nil {
+		return nil, err
+	}
+
 	switch header.messageType {
 	case messageTypePing:
 		var msg Ping
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypePong:
 		var msg Pong
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeJoinRelayRequest:
 		var msg JoinRelayRequest
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeJoinSessionRequest:
 		var msg JoinSessionRequest
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeResponse:
 		var msg Response
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeConnectRequest:
 		var msg ConnectRequest
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeSessionInvitation:
 		var msg SessionInvitation
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	case messageTypeRelayFull:
 		var msg RelayFull
-		err := msg.DecodeXDR(r)
+		err := msg.UnmarshalXDR(buf)
 		return msg, err
 	}