Browse Source

Handle ElementSizeExceeded on nested structs

Jakob Borg 11 years ago
parent
commit
32da1c8d58
6 changed files with 173 additions and 11 deletions
  1. 8 2
      cmd/genxdr/main.go
  2. 12 3
      discover/packets_xdr.go
  3. 4 1
      files/leveldb_xdr.go
  4. 1 0
      protocol/.gitignore
  5. 20 5
      protocol/message_xdr.go
  6. 128 0
      protocol/protocol_test.go

+ 8 - 2
cmd/genxdr/main.go

@@ -84,7 +84,10 @@ func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
 				{{end}}
 				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}})
 			{{else}}
-				o.{{$fieldInfo.Name}}.encodeXDR(xw)
+				_, err := o.{{$fieldInfo.Name}}.encodeXDR(xw)
+				if err != nil {
+					return xw.Tot(), err
+				}
 			{{end}}
 		{{else}}
 			{{if ge $fieldInfo.Max 1}}
@@ -99,7 +102,10 @@ func (o {{.TypeName}}) encodeXDR(xw *xdr.Writer) (int, error) {
 			{{else if $fieldInfo.IsBasic}}
 				xw.Write{{$fieldInfo.Encoder}}(o.{{$fieldInfo.Name}}[i])
 			{{else}}
-				o.{{$fieldInfo.Name}}[i].encodeXDR(xw)
+				_, err := o.{{$fieldInfo.Name}}[i].encodeXDR(xw)
+				if err != nil {
+					return xw.Tot(), err
+				}
 			{{end}}
 			}
 		{{end}}

+ 12 - 3
discover/packets_xdr.go

@@ -126,13 +126,19 @@ func (o Announce) AppendXDR(bs []byte) []byte {
 
 func (o Announce) encodeXDR(xw *xdr.Writer) (int, error) {
 	xw.WriteUint32(o.Magic)
-	o.This.encodeXDR(xw)
+	_, err := o.This.encodeXDR(xw)
+	if err != nil {
+		return xw.Tot(), err
+	}
 	if len(o.Extra) > 16 {
 		return xw.Tot(), xdr.ErrElementSizeExceeded
 	}
 	xw.WriteUint32(uint32(len(o.Extra)))
 	for i := range o.Extra {
-		o.Extra[i].encodeXDR(xw)
+		_, err := o.Extra[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }
@@ -216,7 +222,10 @@ func (o Node) encodeXDR(xw *xdr.Writer) (int, error) {
 	}
 	xw.WriteUint32(uint32(len(o.Addresses)))
 	for i := range o.Addresses {
-		o.Addresses[i].encodeXDR(xw)
+		_, err := o.Addresses[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }

+ 4 - 1
files/leveldb_xdr.go

@@ -120,7 +120,10 @@ func (o versionList) AppendXDR(bs []byte) []byte {
 func (o versionList) encodeXDR(xw *xdr.Writer) (int, error) {
 	xw.WriteUint32(uint32(len(o.versions)))
 	for i := range o.versions {
-		o.versions[i].encodeXDR(xw)
+		_, err := o.versions[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }

+ 1 - 0
protocol/.gitignore

@@ -0,0 +1 @@
+*.txt

+ 20 - 5
protocol/message_xdr.go

@@ -66,7 +66,10 @@ func (o IndexMessage) encodeXDR(xw *xdr.Writer) (int, error) {
 	xw.WriteString(o.Repository)
 	xw.WriteUint32(uint32(len(o.Files)))
 	for i := range o.Files {
-		o.Files[i].encodeXDR(xw)
+		_, err := o.Files[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }
@@ -165,7 +168,10 @@ func (o FileInfo) encodeXDR(xw *xdr.Writer) (int, error) {
 	xw.WriteUint64(o.LocalVersion)
 	xw.WriteUint32(uint32(len(o.Blocks)))
 	for i := range o.Blocks {
-		o.Blocks[i].encodeXDR(xw)
+		_, err := o.Blocks[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }
@@ -476,14 +482,20 @@ func (o ClusterConfigMessage) encodeXDR(xw *xdr.Writer) (int, error) {
 	}
 	xw.WriteUint32(uint32(len(o.Repositories)))
 	for i := range o.Repositories {
-		o.Repositories[i].encodeXDR(xw)
+		_, err := o.Repositories[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	if len(o.Options) > 64 {
 		return xw.Tot(), xdr.ErrElementSizeExceeded
 	}
 	xw.WriteUint32(uint32(len(o.Options)))
 	for i := range o.Options {
-		o.Options[i].encodeXDR(xw)
+		_, err := o.Options[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }
@@ -575,7 +587,10 @@ func (o Repository) encodeXDR(xw *xdr.Writer) (int, error) {
 	}
 	xw.WriteUint32(uint32(len(o.Nodes)))
 	for i := range o.Nodes {
-		o.Nodes[i].encodeXDR(xw)
+		_, err := o.Nodes[i].encodeXDR(xw)
+		if err != nil {
+			return xw.Tot(), err
+		}
 	}
 	return xw.Tot(), xw.Error()
 }

+ 128 - 0
protocol/protocol_test.go

@@ -5,12 +5,19 @@
 package protocol
 
 import (
+	"bytes"
+	"encoding/hex"
 	"errors"
+	"fmt"
 	"io"
+	"io/ioutil"
+	"os"
+	"reflect"
 	"testing"
 	"testing/quick"
 
 	"github.com/calmh/syncthing/xdr"
+	pretty "github.com/tonnerre/golang-pretty"
 )
 
 var (
@@ -230,3 +237,124 @@ func TestClose(t *testing.T) {
 		t.Error("Request should return an error")
 	}
 }
+
+func TestElementSizeExceededNested(t *testing.T) {
+	m := ClusterConfigMessage{
+		Repositories: []Repository{
+			{ID: "longstringlongstringlongstringinglongstringlongstringlonlongstringlongstringlon"},
+		},
+	}
+	_, err := m.EncodeXDR(ioutil.Discard)
+	if err == nil {
+		t.Errorf("ID length %d > max 64, but no error", len(m.Repositories[0].ID))
+	}
+}
+
+func TestMarshalIndexMessage(t *testing.T) {
+	f := func(m1 IndexMessage) bool {
+		for _, f := range m1.Files {
+			for i := range f.Blocks {
+				f.Blocks[i].Offset = 0
+				if len(f.Blocks[i].Hash) == 0 {
+					f.Blocks[i].Hash = nil
+				}
+			}
+		}
+
+		return testMarshal(t, "index", &m1, &IndexMessage{})
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestMarshalRequestMessage(t *testing.T) {
+	f := func(m1 RequestMessage) bool {
+		return testMarshal(t, "request", &m1, &RequestMessage{})
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestMarshalResponseMessage(t *testing.T) {
+	f := func(m1 ResponseMessage) bool {
+		if len(m1.Data) == 0 {
+			m1.Data = nil
+		}
+		return testMarshal(t, "response", &m1, &ResponseMessage{})
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestMarshalClusterConfigMessage(t *testing.T) {
+	f := func(m1 ClusterConfigMessage) bool {
+		return testMarshal(t, "clusterconfig", &m1, &ClusterConfigMessage{})
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestMarshalCloseMessage(t *testing.T) {
+	f := func(m1 CloseMessage) bool {
+		return testMarshal(t, "close", &m1, &CloseMessage{})
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCountScale: 10}); err != nil {
+		t.Error(err)
+	}
+}
+
+type message interface {
+	EncodeXDR(io.Writer) (int, error)
+	DecodeXDR(io.Reader) error
+}
+
+func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
+	var buf bytes.Buffer
+
+	failed := func(bc []byte) {
+		f, _ := os.Create(prefix + "-1.txt")
+		pretty.Fprintf(f, "%# v", m1)
+		f.Close()
+		f, _ = os.Create(prefix + "-2.txt")
+		pretty.Fprintf(f, "%# v", m2)
+		f.Close()
+		if len(bc) > 0 {
+			f, _ := os.Create(prefix + "-data.txt")
+			fmt.Fprint(f, hex.Dump(bc))
+			f.Close()
+		}
+	}
+
+	_, err := m1.EncodeXDR(&buf)
+	if err == xdr.ErrElementSizeExceeded {
+		return true
+	}
+	if err != nil {
+		failed(nil)
+		t.Fatal(err)
+	}
+
+	bc := make([]byte, len(buf.Bytes()))
+	copy(bc, buf.Bytes())
+
+	err = m2.DecodeXDR(&buf)
+	if err != nil {
+		failed(bc)
+		t.Fatal(err)
+	}
+
+	ok := reflect.DeepEqual(m1, m2)
+	if !ok {
+		failed(bc)
+	}
+	return ok
+}