Browse Source

Factor out XDR en/decoding

Jakob Borg 11 years ago
parent
commit
f89fa6caed
8 changed files with 342 additions and 253 deletions
  1. 0 142
      protocol/marshal.go
  2. 85 55
      protocol/messages.go
  3. 10 10
      protocol/messages_test.go
  4. 30 29
      protocol/protocol.go
  5. 0 17
      protocol/protocol_test.go
  6. 65 0
      xdr/reader.go
  7. 95 0
      xdr/writer.go
  8. 57 0
      xdr/xdr_test.go

+ 0 - 142
protocol/marshal.go

@@ -1,142 +0,0 @@
-package protocol
-
-import (
-	"errors"
-	"io"
-	"sync/atomic"
-
-	"github.com/calmh/syncthing/buffers"
-)
-
-func pad(l int) int {
-	d := l % 4
-	if d == 0 {
-		return 0
-	}
-	return 4 - d
-}
-
-var padBytes = []byte{0, 0, 0}
-
-type marshalWriter struct {
-	w   io.Writer
-	tot uint64
-	err error
-	b   [8]byte
-}
-
-// We will never encode nor expect to decode blobs larger than 10 MB. Check
-// inserted to protect against attempting to allocate arbitrary amounts of
-// memory when reading a corrupt message.
-const maxBytesFieldLength = 10 * 1 << 20
-
-var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit")
-
-func (w *marshalWriter) writeString(s string) {
-	w.writeBytes([]byte(s))
-}
-
-func (w *marshalWriter) writeBytes(bs []byte) {
-	if w.err != nil {
-		return
-	}
-	if len(bs) > maxBytesFieldLength {
-		w.err = ErrFieldLengthExceeded
-		return
-	}
-	w.writeUint32(uint32(len(bs)))
-	if w.err != nil {
-		return
-	}
-	_, w.err = w.w.Write(bs)
-	if p := pad(len(bs)); w.err == nil && p > 0 {
-		_, w.err = w.w.Write(padBytes[:p])
-	}
-	atomic.AddUint64(&w.tot, uint64(len(bs)+pad(len(bs))))
-}
-
-func (w *marshalWriter) writeUint32(v uint32) {
-	if w.err != nil {
-		return
-	}
-	w.b[0] = byte(v >> 24)
-	w.b[1] = byte(v >> 16)
-	w.b[2] = byte(v >> 8)
-	w.b[3] = byte(v)
-	_, w.err = w.w.Write(w.b[:4])
-	atomic.AddUint64(&w.tot, 4)
-}
-
-func (w *marshalWriter) writeUint64(v uint64) {
-	if w.err != nil {
-		return
-	}
-	w.b[0] = byte(v >> 56)
-	w.b[1] = byte(v >> 48)
-	w.b[2] = byte(v >> 40)
-	w.b[3] = byte(v >> 32)
-	w.b[4] = byte(v >> 24)
-	w.b[5] = byte(v >> 16)
-	w.b[6] = byte(v >> 8)
-	w.b[7] = byte(v)
-	_, w.err = w.w.Write(w.b[:8])
-	atomic.AddUint64(&w.tot, 8)
-}
-
-func (w *marshalWriter) getTot() uint64 {
-	return atomic.LoadUint64(&w.tot)
-}
-
-type marshalReader struct {
-	r   io.Reader
-	tot uint64
-	err error
-	b   [8]byte
-}
-
-func (r *marshalReader) readString() string {
-	bs := r.readBytes()
-	defer buffers.Put(bs)
-	return string(bs)
-}
-
-func (r *marshalReader) readBytes() []byte {
-	if r.err != nil {
-		return nil
-	}
-	l := int(r.readUint32())
-	if r.err != nil {
-		return nil
-	}
-	if l > maxBytesFieldLength {
-		r.err = ErrFieldLengthExceeded
-		return nil
-	}
-	b := buffers.Get(l + pad(l))
-	_, r.err = io.ReadFull(r.r, b)
-	atomic.AddUint64(&r.tot, uint64(l+pad(l)))
-	return b[:l]
-}
-
-func (r *marshalReader) readUint32() uint32 {
-	if r.err != nil {
-		return 0
-	}
-	_, r.err = io.ReadFull(r.r, r.b[:4])
-	atomic.AddUint64(&r.tot, 8)
-	return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
-}
-
-func (r *marshalReader) readUint64() uint64 {
-	if r.err != nil {
-		return 0
-	}
-	_, r.err = io.ReadFull(r.r, r.b[:8])
-	atomic.AddUint64(&r.tot, 8)
-	return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
-		uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
-}
-
-func (r *marshalReader) getTot() uint64 {
-	return atomic.LoadUint64(&r.tot)
-}

+ 85 - 55
protocol/messages.go

@@ -3,6 +3,9 @@ package protocol
 import (
 	"errors"
 	"io"
+
+	"github.com/calmh/syncthing/buffers"
+	"github.com/calmh/syncthing/xdr"
 )
 
 const (
@@ -43,60 +46,93 @@ func decodeHeader(u uint32) header {
 	}
 }
 
+func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
+	mw := newMarshalWriter(w)
+	mw.writeIndex(repo, idx)
+	return int(mw.Tot()), mw.Err()
+}
+
+type marshalWriter struct {
+	*xdr.Writer
+}
+
+func newMarshalWriter(w io.Writer) marshalWriter {
+	return marshalWriter{xdr.NewWriter(w)}
+}
+
 func (w *marshalWriter) writeHeader(h header) {
-	w.writeUint32(encodeHeader(h))
+	w.WriteUint32(encodeHeader(h))
 }
 
 func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) {
-	w.writeString(repo)
-	w.writeUint32(uint32(len(idx)))
+	w.WriteString(repo)
+	w.WriteUint32(uint32(len(idx)))
 	for _, f := range idx {
-		w.writeString(f.Name)
-		w.writeUint32(f.Flags)
-		w.writeUint64(uint64(f.Modified))
-		w.writeUint32(f.Version)
-		w.writeUint32(uint32(len(f.Blocks)))
+		w.WriteString(f.Name)
+		w.WriteUint32(f.Flags)
+		w.WriteUint64(uint64(f.Modified))
+		w.WriteUint32(f.Version)
+		w.WriteUint32(uint32(len(f.Blocks)))
 		for _, b := range f.Blocks {
-			w.writeUint32(b.Size)
-			w.writeBytes(b.Hash)
+			w.WriteUint32(b.Size)
+			w.WriteBytes(b.Hash)
 		}
 	}
 }
 
-func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
-	mw := marshalWriter{w: w}
-	mw.writeIndex(repo, idx)
-	return int(mw.getTot()), mw.err
-}
-
 func (w *marshalWriter) writeRequest(r request) {
-	w.writeString(r.repo)
-	w.writeString(r.name)
-	w.writeUint64(uint64(r.offset))
-	w.writeUint32(r.size)
-	w.writeBytes(r.hash)
+	w.WriteString(r.repo)
+	w.WriteString(r.name)
+	w.WriteUint64(uint64(r.offset))
+	w.WriteUint32(r.size)
+	w.WriteBytes(r.hash)
 }
 
 func (w *marshalWriter) writeResponse(data []byte) {
-	w.writeBytes(data)
+	w.WriteBytes(data)
 }
 
 func (w *marshalWriter) writeOptions(opts map[string]string) {
-	w.writeUint32(uint32(len(opts)))
+	w.WriteUint32(uint32(len(opts)))
 	for k, v := range opts {
-		w.writeString(k)
-		w.writeString(v)
+		w.WriteString(k)
+		w.WriteString(v)
 	}
 }
 
-func (r *marshalReader) readHeader() header {
-	return decodeHeader(r.readUint32())
+func ReadIndex(r io.Reader) (string, []FileInfo, error) {
+	mr := newMarshalReader(r)
+	repo, idx := mr.readIndex()
+	return repo, idx, mr.Err()
 }
 
-func (r *marshalReader) readIndex() (string, []FileInfo) {
+type marshalReader struct {
+	*xdr.Reader
+	err error
+}
+
+func newMarshalReader(r io.Reader) marshalReader {
+	return marshalReader{
+		Reader: xdr.NewReader(r),
+		err:    nil,
+	}
+}
+
+func (r marshalReader) Err() error {
+	if r.err != nil {
+		return r.err
+	}
+	return r.Reader.Err()
+}
+
+func (r marshalReader) readHeader() header {
+	return decodeHeader(r.ReadUint32())
+}
+
+func (r marshalReader) readIndex() (string, []FileInfo) {
 	var files []FileInfo
-	repo := r.readString()
-	nfiles := r.readUint32()
+	repo := r.ReadString()
+	nfiles := r.ReadUint32()
 	if nfiles > maxNumFiles {
 		r.err = ErrMaxFilesExceeded
 		return "", nil
@@ -104,19 +140,19 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
 	if nfiles > 0 {
 		files = make([]FileInfo, nfiles)
 		for i := range files {
-			files[i].Name = r.readString()
-			files[i].Flags = r.readUint32()
-			files[i].Modified = int64(r.readUint64())
-			files[i].Version = r.readUint32()
-			nblocks := r.readUint32()
+			files[i].Name = r.ReadString()
+			files[i].Flags = r.ReadUint32()
+			files[i].Modified = int64(r.ReadUint64())
+			files[i].Version = r.ReadUint32()
+			nblocks := r.ReadUint32()
 			if nblocks > maxNumBlocks {
 				r.err = ErrMaxBlocksExceeded
 				return "", nil
 			}
 			blocks := make([]BlockInfo, nblocks)
 			for j := range blocks {
-				blocks[j].Size = r.readUint32()
-				blocks[j].Hash = r.readBytes()
+				blocks[j].Size = r.ReadUint32()
+				blocks[j].Hash = r.ReadBytes(buffers.Get(32))
 			}
 			files[i].Blocks = blocks
 		}
@@ -124,32 +160,26 @@ func (r *marshalReader) readIndex() (string, []FileInfo) {
 	return repo, files
 }
 
-func ReadIndex(r io.Reader) (string, []FileInfo, error) {
-	mr := marshalReader{r: r}
-	repo, idx := mr.readIndex()
-	return repo, idx, mr.err
-}
-
-func (r *marshalReader) readRequest() request {
+func (r marshalReader) readRequest() request {
 	var req request
-	req.repo = r.readString()
-	req.name = r.readString()
-	req.offset = int64(r.readUint64())
-	req.size = r.readUint32()
-	req.hash = r.readBytes()
+	req.repo = r.ReadString()
+	req.name = r.ReadString()
+	req.offset = int64(r.ReadUint64())
+	req.size = r.ReadUint32()
+	req.hash = r.ReadBytes(buffers.Get(32))
 	return req
 }
 
-func (r *marshalReader) readResponse() []byte {
-	return r.readBytes()
+func (r marshalReader) readResponse() []byte {
+	return r.ReadBytes(buffers.Get(128 * 1024))
 }
 
-func (r *marshalReader) readOptions() map[string]string {
-	n := r.readUint32()
+func (r marshalReader) readOptions() map[string]string {
+	n := r.ReadUint32()
 	opts := make(map[string]string, n)
 	for i := 0; i < int(n); i++ {
-		k := r.readString()
-		v := r.readString()
+		k := r.ReadString()
+		v := r.ReadString()
 		opts[k] = v
 	}
 	return opts

+ 10 - 10
protocol/messages_test.go

@@ -34,10 +34,10 @@ func TestIndex(t *testing.T) {
 	}
 
 	var buf = new(bytes.Buffer)
-	var wr = marshalWriter{w: buf}
+	var wr = newMarshalWriter(buf)
 	wr.writeIndex("default", idx)
 
-	var rd = marshalReader{r: buf}
+	var rd = newMarshalReader(buf)
 	var repo, idx2 = rd.readIndex()
 
 	if repo != "default" {
@@ -53,9 +53,9 @@ func TestRequest(t *testing.T) {
 	f := func(repo, name string, offset int64, size uint32, hash []byte) bool {
 		var buf = new(bytes.Buffer)
 		var req = request{repo, name, offset, size, hash}
-		var wr = marshalWriter{w: buf}
+		var wr = newMarshalWriter(buf)
 		wr.writeRequest(req)
-		var rd = marshalReader{r: buf}
+		var rd = newMarshalReader(buf)
 		var req2 = rd.readRequest()
 		return req.name == req2.name &&
 			req.offset == req2.offset &&
@@ -70,9 +70,9 @@ func TestRequest(t *testing.T) {
 func TestResponse(t *testing.T) {
 	f := func(data []byte) bool {
 		var buf = new(bytes.Buffer)
-		var wr = marshalWriter{w: buf}
+		var wr = newMarshalWriter(buf)
 		wr.writeResponse(data)
-		var rd = marshalReader{r: buf}
+		var rd = newMarshalReader(buf)
 		var read = rd.readResponse()
 		return bytes.Compare(read, data) == 0
 	}
@@ -106,7 +106,7 @@ func BenchmarkWriteIndex(b *testing.B) {
 		},
 	}
 
-	var wr = marshalWriter{w: ioutil.Discard}
+	var wr = newMarshalWriter(ioutil.Discard)
 
 	for i := 0; i < b.N; i++ {
 		wr.writeIndex("default", idx)
@@ -115,7 +115,7 @@ func BenchmarkWriteIndex(b *testing.B) {
 
 func BenchmarkWriteRequest(b *testing.B) {
 	var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")}
-	var wr = marshalWriter{w: ioutil.Discard}
+	var wr = newMarshalWriter(ioutil.Discard)
 
 	for i := 0; i < b.N; i++ {
 		wr.writeRequest(req)
@@ -131,10 +131,10 @@ func TestOptions(t *testing.T) {
 	}
 
 	var buf = new(bytes.Buffer)
-	var wr = marshalWriter{w: buf}
+	var wr = newMarshalWriter(buf)
 	wr.writeOptions(opts)
 
-	var rd = marshalReader{r: buf}
+	var rd = newMarshalReader(buf)
 	var ropts = rd.readOptions()
 
 	if !reflect.DeepEqual(opts, ropts) {

+ 30 - 29
protocol/protocol.go

@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/calmh/syncthing/buffers"
+	"github.com/calmh/syncthing/xdr"
 )
 
 const (
@@ -61,9 +62,9 @@ type Connection struct {
 	id          string
 	receiver    Model
 	reader      io.Reader
-	mreader     *marshalReader
+	mreader     marshalReader
 	writer      io.Writer
-	mwriter     *marshalWriter
+	mwriter     marshalWriter
 	closed      bool
 	awaiting    map[int]chan asyncResult
 	nextId      int
@@ -101,9 +102,9 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 		id:        nodeID,
 		receiver:  receiver,
 		reader:    flrd,
-		mreader:   &marshalReader{r: flrd},
+		mreader:   marshalReader{Reader: xdr.NewReader(flrd)},
 		writer:    flwr,
-		mwriter:   &marshalWriter{w: flwr},
+		mwriter:   marshalWriter{Writer: xdr.NewWriter(flwr)},
 		awaiting:  make(map[int]chan asyncResult),
 		indexSent: make(map[string]map[string][2]int64),
 	}
@@ -168,8 +169,8 @@ func (c *Connection) Index(repo string, idx []FileInfo) {
 	if err != nil {
 		c.close(err)
 		return
-	} else if c.mwriter.err != nil {
-		c.close(c.mwriter.err)
+	} else if c.mwriter.Err() != nil {
+		c.close(c.mwriter.Err())
 		return
 	}
 }
@@ -185,10 +186,10 @@ func (c *Connection) Request(repo string, name string, offset int64, size uint32
 	c.awaiting[c.nextId] = rc
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
 	c.mwriter.writeRequest(request{repo, name, offset, size, hash})
-	if c.mwriter.err != nil {
+	if c.mwriter.Err() != nil {
 		c.Unlock()
-		c.close(c.mwriter.err)
-		return nil, c.mwriter.err
+		c.close(c.mwriter.Err())
+		return nil, c.mwriter.Err()
 	}
 	err := c.flush()
 	if err != nil {
@@ -220,9 +221,9 @@ func (c *Connection) ping() bool {
 		c.Unlock()
 		c.close(err)
 		return false
-	} else if c.mwriter.err != nil {
+	} else if c.mwriter.Err() != nil {
 		c.Unlock()
-		c.close(c.mwriter.err)
+		c.close(c.mwriter.Err())
 		return false
 	}
 	c.nextId = (c.nextId + 1) & 0xfff
@@ -269,8 +270,8 @@ func (c *Connection) readerLoop() {
 loop:
 	for {
 		hdr := c.mreader.readHeader()
-		if c.mreader.err != nil {
-			c.close(c.mreader.err)
+		if c.mreader.Err() != nil {
+			c.close(c.mreader.Err())
 			break loop
 		}
 		if hdr.version != 0 {
@@ -282,8 +283,8 @@ loop:
 		case messageTypeIndex:
 			repo, files := c.mreader.readIndex()
 			_ = repo
-			if c.mreader.err != nil {
-				c.close(c.mreader.err)
+			if c.mreader.Err() != nil {
+				c.close(c.mreader.Err())
 				break loop
 			} else {
 				c.receiver.Index(c.id, files)
@@ -295,8 +296,8 @@ loop:
 		case messageTypeIndexUpdate:
 			repo, files := c.mreader.readIndex()
 			_ = repo
-			if c.mreader.err != nil {
-				c.close(c.mreader.err)
+			if c.mreader.Err() != nil {
+				c.close(c.mreader.Err())
 				break loop
 			} else {
 				c.receiver.IndexUpdate(c.id, files)
@@ -304,8 +305,8 @@ loop:
 
 		case messageTypeRequest:
 			req := c.mreader.readRequest()
-			if c.mreader.err != nil {
-				c.close(c.mreader.err)
+			if c.mreader.Err() != nil {
+				c.close(c.mreader.Err())
 				break loop
 			}
 			go c.processRequest(hdr.msgID, req)
@@ -313,8 +314,8 @@ loop:
 		case messageTypeResponse:
 			data := c.mreader.readResponse()
 
-			if c.mreader.err != nil {
-				c.close(c.mreader.err)
+			if c.mreader.Err() != nil {
+				c.close(c.mreader.Err())
 				break loop
 			} else {
 				c.Lock()
@@ -323,21 +324,21 @@ loop:
 				c.Unlock()
 
 				if ok {
-					rc <- asyncResult{data, c.mreader.err}
+					rc <- asyncResult{data, c.mreader.Err()}
 					close(rc)
 				}
 			}
 
 		case messageTypePing:
 			c.Lock()
-			c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
+			c.mwriter.WriteUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
 			err := c.flush()
 			c.Unlock()
 			if err != nil {
 				c.close(err)
 				break loop
-			} else if c.mwriter.err != nil {
-				c.close(c.mwriter.err)
+			} else if c.mwriter.Err() != nil {
+				c.close(c.mwriter.Err())
 				break loop
 			}
 
@@ -376,9 +377,9 @@ func (c *Connection) processRequest(msgID int, req request) {
 	data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash)
 
 	c.Lock()
-	c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
+	c.mwriter.WriteUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
 	c.mwriter.writeResponse(data)
-	err := c.mwriter.err
+	err := c.mwriter.Err()
 	if err == nil {
 		err = c.flush()
 	}
@@ -427,8 +428,8 @@ func (c *Connection) Statistics() Statistics {
 
 	stats := Statistics{
 		At:            time.Now(),
-		InBytesTotal:  int(c.mreader.getTot()),
-		OutBytesTotal: int(c.mwriter.getTot()),
+		InBytesTotal:  int(c.mreader.Tot()),
+		OutBytesTotal: int(c.mwriter.Tot()),
 	}
 
 	return stats

+ 0 - 17
protocol/protocol_test.go

@@ -22,23 +22,6 @@ func TestHeaderFunctions(t *testing.T) {
 	}
 }
 
-func TestPad(t *testing.T) {
-	tests := [][]int{
-		{0, 0},
-		{1, 3},
-		{2, 2},
-		{3, 1},
-		{4, 0},
-		{32, 0},
-		{33, 3},
-	}
-	for _, tc := range tests {
-		if p := pad(tc[0]); p != tc[1] {
-			t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
-		}
-	}
-}
-
 func TestPing(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()

+ 65 - 0
xdr/reader.go

@@ -0,0 +1,65 @@
+package xdr
+
+import "io"
+
+type Reader struct {
+	r   io.Reader
+	tot uint64
+	err error
+	b   [8]byte
+}
+
+func NewReader(r io.Reader) *Reader {
+	return &Reader{
+		r: r,
+	}
+}
+
+func (r *Reader) ReadString() string {
+	return string(r.ReadBytes(nil))
+}
+
+func (r *Reader) ReadBytes(dst []byte) []byte {
+	if r.err != nil {
+		return nil
+	}
+	l := int(r.ReadUint32())
+	if r.err != nil {
+		return nil
+	}
+	if l+pad(l) > len(dst) {
+		dst = make([]byte, l+pad(l))
+	} else {
+		dst = dst[:l+pad(l)]
+	}
+	_, r.err = io.ReadFull(r.r, dst)
+	r.tot += uint64(l + pad(l))
+	return dst[:l]
+}
+
+func (r *Reader) ReadUint32() uint32 {
+	if r.err != nil {
+		return 0
+	}
+	_, r.err = io.ReadFull(r.r, r.b[:4])
+	r.tot += 8
+	return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
+}
+
+func (r *Reader) ReadUint64() uint64 {
+	if r.err != nil {
+		return 0
+	}
+	_, r.err = io.ReadFull(r.r, r.b[:8])
+	r.tot += 8
+	return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
+		uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
+}
+
+func (r *Reader) Tot() uint64 {
+	return r.tot
+}
+
+func (r *Reader) Err() error {
+	return r.err
+}

+ 95 - 0
xdr/writer.go

@@ -0,0 +1,95 @@
+package xdr
+
+import "io"
+
+func pad(l int) int {
+	d := l % 4
+	if d == 0 {
+		return 0
+	}
+	return 4 - d
+}
+
+var padBytes = []byte{0, 0, 0}
+
+type Writer struct {
+	w   io.Writer
+	tot uint64
+	err error
+	b   [8]byte
+}
+
+func NewWriter(w io.Writer) *Writer {
+	return &Writer{
+		w: w,
+	}
+}
+
+func (w *Writer) WriteString(s string) (int, error) {
+	return w.WriteBytes([]byte(s))
+}
+
+func (w *Writer) WriteBytes(bs []byte) (int, error) {
+	if w.err != nil {
+		return 0, w.err
+	}
+
+	w.WriteUint32(uint32(len(bs)))
+	if w.err != nil {
+		return 0, w.err
+	}
+
+	var l, n int
+	n, w.err = w.w.Write(bs)
+	l += n
+
+	if p := pad(len(bs)); w.err == nil && p > 0 {
+		n, w.err = w.w.Write(padBytes[:p])
+		l += n
+	}
+
+	w.tot += uint64(l)
+	return l, w.err
+}
+
+func (w *Writer) WriteUint32(v uint32) (int, error) {
+	if w.err != nil {
+		return 0, w.err
+	}
+	w.b[0] = byte(v >> 24)
+	w.b[1] = byte(v >> 16)
+	w.b[2] = byte(v >> 8)
+	w.b[3] = byte(v)
+
+	var l int
+	l, w.err = w.w.Write(w.b[:4])
+	w.tot += uint64(l)
+	return l, w.err
+}
+
+func (w *Writer) WriteUint64(v uint64) (int, error) {
+	if w.err != nil {
+		return 0, w.err
+	}
+	w.b[0] = byte(v >> 56)
+	w.b[1] = byte(v >> 48)
+	w.b[2] = byte(v >> 40)
+	w.b[3] = byte(v >> 32)
+	w.b[4] = byte(v >> 24)
+	w.b[5] = byte(v >> 16)
+	w.b[6] = byte(v >> 8)
+	w.b[7] = byte(v)
+
+	var l int
+	l, w.err = w.w.Write(w.b[:8])
+	w.tot += uint64(l)
+	return l, w.err
+}
+
+func (w *Writer) Tot() uint64 {
+	return w.tot
+}
+
+func (w *Writer) Err() error {
+	return w.err
+}

+ 57 - 0
xdr/xdr_test.go

@@ -0,0 +1,57 @@
+package xdr
+
+import (
+	"bytes"
+	"testing"
+	"testing/quick"
+)
+
+func TestPad(t *testing.T) {
+	tests := [][]int{
+		{0, 0},
+		{1, 3},
+		{2, 2},
+		{3, 1},
+		{4, 0},
+		{32, 0},
+		{33, 3},
+	}
+	for _, tc := range tests {
+		if p := pad(tc[0]); p != tc[1] {
+			t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
+		}
+	}
+}
+
+func TestBytesNil(t *testing.T) {
+	fn := func(bs []byte) bool {
+		var b = new(bytes.Buffer)
+		var w = NewWriter(b)
+		var r = NewReader(b)
+		w.WriteBytes(bs)
+		w.WriteBytes(bs)
+		r.ReadBytes(nil)
+		res := r.ReadBytes(nil)
+		return bytes.Compare(bs, res) == 0
+	}
+	if err := quick.Check(fn, nil); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestBytesGiven(t *testing.T) {
+	fn := func(bs []byte) bool {
+		var b = new(bytes.Buffer)
+		var w = NewWriter(b)
+		var r = NewReader(b)
+		w.WriteBytes(bs)
+		w.WriteBytes(bs)
+		res := make([]byte, 12)
+		res = r.ReadBytes(res)
+		res = r.ReadBytes(res)
+		return bytes.Compare(bs, res) == 0
+	}
+	if err := quick.Check(fn, nil); err != nil {
+		t.Error(err)
+	}
+}