Browse Source

Add multi-repository support to protocol (ref #35)

Jakob Borg 11 years ago
parent
commit
21a7f3960a
10 changed files with 104 additions and 59 deletions
  1. 3 3
      main.go
  2. 9 7
      model.go
  3. 4 4
      model_test.go
  4. 6 1
      protocol/PROTOCOL.md
  5. 3 1
      protocol/common_test.go
  6. 1 1
      protocol/marshal.go
  7. 35 9
      protocol/messages.go
  8. 10 6
      protocol/messages_test.go
  9. 24 21
      protocol/protocol.go
  10. 9 6
      protocol/protocol_test.go

+ 3 - 3
main.go

@@ -496,7 +496,7 @@ func saveIndex(m *Model) {
 
 	gzw := gzip.NewWriter(idxf)
 
-	protocol.WriteIndex(gzw, m.ProtocolIndex())
+	protocol.WriteIndex(gzw, "local", m.ProtocolIndex())
 	gzw.Close()
 	idxf.Close()
 	os.Rename(fullName+".tmp", fullName)
@@ -516,8 +516,8 @@ func loadIndex(m *Model) {
 	}
 	defer gzr.Close()
 
-	idx, err := protocol.ReadIndex(gzr)
-	if err != nil {
+	repo, idx, err := protocol.ReadIndex(gzr)
+	if repo != "local" || err != nil {
 		return
 	}
 	m.SeedLocal(idx)

+ 9 - 7
model.go

@@ -55,8 +55,8 @@ type Model struct {
 
 type Connection interface {
 	ID() string
-	Index([]protocol.FileInfo)
-	Request(name string, offset int64, size uint32, hash []byte) ([]byte, error)
+	Index(string, []protocol.FileInfo)
+	Request(repo, name string, offset int64, size uint32, hash []byte) ([]byte, error)
 	Statistics() protocol.Statistics
 	Option(key string) string
 }
@@ -360,6 +360,8 @@ func (m *Model) Close(node string, err error) {
 	}
 	if err == protocol.ErrClusterHash {
 		warnf("Connection to %s closed due to mismatched cluster hash. Ensure that the configured cluster members are identical on both nodes.", node)
+	} else if err != io.EOF {
+		warnf("Connection to %s closed: %v", node, err)
 	}
 
 	m.fq.RemoveAvailable(node)
@@ -385,7 +387,7 @@ func (m *Model) Close(node string, err error) {
 
 // Request returns the specified data segment by reading it from local disk.
 // Implements the protocol.Model interface.
-func (m *Model) Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
+func (m *Model) Request(nodeID, repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
 	// Verify that the requested file exists in the local and global model.
 	m.lmut.RLock()
 	lf, localOk := m.local[name]
@@ -507,7 +509,7 @@ func (m *Model) AddConnection(rawConn io.Closer, protoConn Connection) {
 
 	go func() {
 		idx := m.ProtocolIndex()
-		protoConn.Index(idx)
+		protoConn.Index("default", idx)
 	}()
 
 	m.initmut.Lock()
@@ -539,7 +541,7 @@ func (m *Model) AddConnection(rawConn io.Closer, protoConn Connection) {
 					if m.trace["pull"] {
 						debugln("PULL: Request", nodeID, i, qb.name, qb.block.Offset)
 					}
-					data, _ := protoConn.Request(qb.name, qb.block.Offset, qb.block.Size, qb.block.Hash)
+					data, _ := protoConn.Request("default", qb.name, qb.block.Offset, qb.block.Size, qb.block.Hash)
 					m.fq.Done(qb.name, qb.block.Offset, data)
 				} else {
 					time.Sleep(1 * time.Second)
@@ -585,7 +587,7 @@ func (m *Model) requestGlobal(nodeID, name string, offset int64, size uint32, ha
 		debugf("NET REQ(out): %s: %q o=%d s=%d h=%x", nodeID, name, offset, size, hash)
 	}
 
-	return nc.Request(name, offset, size, hash)
+	return nc.Request("default", name, offset, size, hash)
 }
 
 func (m *Model) broadcastIndexLoop() {
@@ -613,7 +615,7 @@ func (m *Model) broadcastIndexLoop() {
 					debugf("NET IDX(out/loop): %s: %d files", node.ID(), len(idx))
 				}
 				go func() {
-					node.Index(idx)
+					node.Index("default", idx)
 					indexWg.Done()
 				}()
 			}

+ 4 - 4
model_test.go

@@ -345,7 +345,7 @@ func TestRequest(t *testing.T) {
 	fs, _ := m.Walk(false)
 	m.ReplaceLocal(fs)
 
-	bs, err := m.Request("some node", "foo", 0, 6, nil)
+	bs, err := m.Request("some node", "default", "foo", 0, 6, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -353,7 +353,7 @@ func TestRequest(t *testing.T) {
 		t.Errorf("Incorrect data from request: %q", string(bs))
 	}
 
-	bs, err = m.Request("some node", "../walk.go", 0, 6, nil)
+	bs, err = m.Request("some node", "default", "../walk.go", 0, 6, nil)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
@@ -487,9 +487,9 @@ func (f FakeConnection) Option(string) string {
 	return ""
 }
 
-func (FakeConnection) Index([]protocol.FileInfo) {}
+func (FakeConnection) Index(string, []protocol.FileInfo) {}
 
-func (f FakeConnection) Request(name string, offset int64, size uint32, hash []byte) ([]byte, error) {
+func (f FakeConnection) Request(repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
 	return f.requestData, nil
 }
 

+ 6 - 1
protocol/PROTOCOL.md

@@ -84,6 +84,7 @@ an empty Index message must be sent. There is no response to the Index
 message.
 
     struct IndexMessage {
+        string Repository<>;
         FileInfo Files<>;
     }
 
@@ -100,6 +101,10 @@ message.
         opaque Hash<>
     }
 
+The Repository field identifies the repository that the index message
+pertains to. For single-repository implementations, an empty repository
+ID is acceptable.
+
 The file name is the part relative to the repository root. The
 modification time is expressed as the number of seconds since the Unix
 Epoch. The version field is a counter that increments each time the file
@@ -143,6 +148,7 @@ before transmitting data. Each Request message must be met with a Response
 message.
 
     struct RequestMessage {
+        string Repository<>;
         string Name<>;
         unsigned hyper Offset;
         unsigned int Length;
@@ -248,4 +254,3 @@ their repository contents and transmits an updated Index message (10).
 Both peers enter idle state after 10. At some later time 11, peer A
 determines that it has not seen data from B for some time and sends a
 Ping request. A response is sent at 12.
-

+ 3 - 1
protocol/common_test.go

@@ -4,6 +4,7 @@ import "io"
 
 type TestModel struct {
 	data   []byte
+	repo   string
 	name   string
 	offset int64
 	size   uint32
@@ -17,7 +18,8 @@ func (t *TestModel) Index(nodeID string, files []FileInfo) {
 func (t *TestModel) IndexUpdate(nodeID string, files []FileInfo) {
 }
 
-func (t *TestModel) Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
+func (t *TestModel) Request(nodeID, repo, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
+	t.repo = repo
 	t.name = name
 	t.offset = offset
 	t.size = size

+ 1 - 1
protocol/marshal.go

@@ -30,7 +30,7 @@ type marshalWriter struct {
 // memory when reading a corrupt message.
 const maxBytesFieldLength = 10 * 1 << 20
 
-var ErrFieldLengthExceeded = errors.New("Raw bytes field size exceeds limit")
+var ErrFieldLengthExceeded = errors.New("Protocol error: raw bytes field size exceeds limit")
 
 func (w *marshalWriter) writeString(s string) {
 	w.writeBytes([]byte(s))

+ 35 - 9
protocol/messages.go

@@ -1,8 +1,22 @@
 package protocol
 
-import "io"
+import (
+	"errors"
+	"io"
+)
+
+const (
+	maxNumFiles  = 100000 // More than 100000 files is a protocol error
+	maxNumBlocks = 100000 // 100000 * 128KB = 12.5 GB max acceptable file size
+)
+
+var (
+	ErrMaxFilesExceeded  = errors.New("Protocol error: number of files per index exceeds limit")
+	ErrMaxBlocksExceeded = errors.New("Protocol error: number of blocks per file exceeds limit")
+)
 
 type request struct {
+	repo   string
 	name   string
 	offset int64
 	size   uint32
@@ -33,7 +47,8 @@ func (w *marshalWriter) writeHeader(h header) {
 	w.writeUint32(encodeHeader(h))
 }
 
-func (w *marshalWriter) writeIndex(idx []FileInfo) {
+func (w *marshalWriter) writeIndex(repo string, idx []FileInfo) {
+	w.writeString(repo)
 	w.writeUint32(uint32(len(idx)))
 	for _, f := range idx {
 		w.writeString(f.Name)
@@ -48,13 +63,14 @@ func (w *marshalWriter) writeIndex(idx []FileInfo) {
 	}
 }
 
-func WriteIndex(w io.Writer, idx []FileInfo) (int, error) {
+func WriteIndex(w io.Writer, repo string, idx []FileInfo) (int, error) {
 	mw := marshalWriter{w: w}
-	mw.writeIndex(idx)
+	mw.writeIndex(repo, idx)
 	return int(mw.getTot()), mw.err
 }
 
 func (w *marshalWriter) writeRequest(r request) {
+	w.writeString(r.repo)
 	w.writeString(r.name)
 	w.writeUint64(uint64(r.offset))
 	w.writeUint32(r.size)
@@ -77,9 +93,14 @@ func (r *marshalReader) readHeader() header {
 	return decodeHeader(r.readUint32())
 }
 
-func (r *marshalReader) readIndex() []FileInfo {
+func (r *marshalReader) readIndex() (string, []FileInfo) {
 	var files []FileInfo
+	repo := r.readString()
 	nfiles := r.readUint32()
+	if nfiles > maxNumFiles {
+		r.err = ErrMaxFilesExceeded
+		return "", nil
+	}
 	if nfiles > 0 {
 		files = make([]FileInfo, nfiles)
 		for i := range files {
@@ -88,6 +109,10 @@ func (r *marshalReader) readIndex() []FileInfo {
 			files[i].Modified = int64(r.readUint64())
 			files[i].Version = r.readUint32()
 			nblocks := r.readUint32()
+			if nblocks > maxNumBlocks {
+				r.err = ErrMaxBlocksExceeded
+				return "", nil
+			}
 			blocks := make([]BlockInfo, nblocks)
 			for j := range blocks {
 				blocks[j].Size = r.readUint32()
@@ -96,17 +121,18 @@ func (r *marshalReader) readIndex() []FileInfo {
 			files[i].Blocks = blocks
 		}
 	}
-	return files
+	return repo, files
 }
 
-func ReadIndex(r io.Reader) ([]FileInfo, error) {
+func ReadIndex(r io.Reader) (string, []FileInfo, error) {
 	mr := marshalReader{r: r}
-	idx := mr.readIndex()
-	return idx, mr.err
+	repo, idx := mr.readIndex()
+	return repo, idx, mr.err
 }
 
 func (r *marshalReader) readRequest() request {
 	var req request
+	req.repo = r.readString()
 	req.name = r.readString()
 	req.offset = int64(r.readUint64())
 	req.size = r.readUint32()

+ 10 - 6
protocol/messages_test.go

@@ -35,10 +35,14 @@ func TestIndex(t *testing.T) {
 
 	var buf = new(bytes.Buffer)
 	var wr = marshalWriter{w: buf}
-	wr.writeIndex(idx)
+	wr.writeIndex("default", idx)
 
 	var rd = marshalReader{r: buf}
-	var idx2 = rd.readIndex()
+	var repo, idx2 = rd.readIndex()
+
+	if repo != "default" {
+		t.Error("Incorrect repo", repo)
+	}
 
 	if !reflect.DeepEqual(idx, idx2) {
 		t.Errorf("Index marshal error:\n%#v\n%#v\n", idx, idx2)
@@ -46,9 +50,9 @@ func TestIndex(t *testing.T) {
 }
 
 func TestRequest(t *testing.T) {
-	f := func(name string, offset int64, size uint32, hash []byte) bool {
+	f := func(repo, name string, offset int64, size uint32, hash []byte) bool {
 		var buf = new(bytes.Buffer)
-		var req = request{name, offset, size, hash}
+		var req = request{repo, name, offset, size, hash}
 		var wr = marshalWriter{w: buf}
 		wr.writeRequest(req)
 		var rd = marshalReader{r: buf}
@@ -105,12 +109,12 @@ func BenchmarkWriteIndex(b *testing.B) {
 	var wr = marshalWriter{w: ioutil.Discard}
 
 	for i := 0; i < b.N; i++ {
-		wr.writeIndex(idx)
+		wr.writeIndex("default", idx)
 	}
 }
 
 func BenchmarkWriteRequest(b *testing.B) {
-	var req = request{"blah blah", 1231323, 13123123, []byte("hash hash hash")}
+	var req = request{"default", "blah blah", 1231323, 13123123, []byte("hash hash hash")}
 	var wr = marshalWriter{w: ioutil.Discard}
 
 	for i := 0; i < b.N; i++ {

+ 24 - 21
protocol/protocol.go

@@ -50,7 +50,7 @@ type Model interface {
 	// An index update was received from the peer node
 	IndexUpdate(nodeID string, files []FileInfo)
 	// A request was made by the peer node
-	Request(nodeID, name string, offset int64, size uint32, hash []byte) ([]byte, error)
+	Request(nodeID, repo string, name string, offset int64, size uint32, hash []byte) ([]byte, error)
 	// The peer node closed the connection
 	Close(nodeID string, err error)
 }
@@ -67,7 +67,7 @@ type Connection struct {
 	closed      bool
 	awaiting    map[int]chan asyncResult
 	nextId      int
-	indexSent   map[string][2]int64
+	indexSent   map[string]map[string][2]int64
 	peerOptions map[string]string
 	myOptions   map[string]string
 	optionsLock sync.Mutex
@@ -98,13 +98,14 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 	}
 
 	c := Connection{
-		id:       nodeID,
-		receiver: receiver,
-		reader:   flrd,
-		mreader:  &marshalReader{r: flrd},
-		writer:   flwr,
-		mwriter:  &marshalWriter{w: flwr},
-		awaiting: make(map[int]chan asyncResult),
+		id:        nodeID,
+		receiver:  receiver,
+		reader:    flrd,
+		mreader:   &marshalReader{r: flrd},
+		writer:    flwr,
+		mwriter:   &marshalWriter{w: flwr},
+		awaiting:  make(map[int]chan asyncResult),
+		indexSent: make(map[string]map[string][2]int64),
 	}
 
 	go c.readerLoop()
@@ -133,32 +134,32 @@ func (c *Connection) ID() string {
 }
 
 // Index writes the list of file information to the connected peer node
-func (c *Connection) Index(idx []FileInfo) {
+func (c *Connection) Index(repo string, idx []FileInfo) {
 	c.Lock()
 	var msgType int
-	if c.indexSent == nil {
+	if c.indexSent[repo] == nil {
 		// This is the first time we send an index.
 		msgType = messageTypeIndex
 
-		c.indexSent = make(map[string][2]int64)
+		c.indexSent[repo] = make(map[string][2]int64)
 		for _, f := range idx {
-			c.indexSent[f.Name] = [2]int64{f.Modified, int64(f.Version)}
+			c.indexSent[repo][f.Name] = [2]int64{f.Modified, int64(f.Version)}
 		}
 	} else {
 		// We have sent one full index. Only send updates now.
 		msgType = messageTypeIndexUpdate
 		var diff []FileInfo
 		for _, f := range idx {
-			if vs, ok := c.indexSent[f.Name]; !ok || f.Modified != vs[0] || int64(f.Version) != vs[1] {
+			if vs, ok := c.indexSent[repo][f.Name]; !ok || f.Modified != vs[0] || int64(f.Version) != vs[1] {
 				diff = append(diff, f)
-				c.indexSent[f.Name] = [2]int64{f.Modified, int64(f.Version)}
+				c.indexSent[repo][f.Name] = [2]int64{f.Modified, int64(f.Version)}
 			}
 		}
 		idx = diff
 	}
 
 	c.mwriter.writeHeader(header{0, c.nextId, msgType})
-	c.mwriter.writeIndex(idx)
+	c.mwriter.writeIndex(repo, idx)
 	err := c.flush()
 	c.nextId = (c.nextId + 1) & 0xfff
 	c.hasSentIndex = true
@@ -174,7 +175,7 @@ func (c *Connection) Index(idx []FileInfo) {
 }
 
 // Request returns the bytes for the specified block after fetching them from the connected peer.
-func (c *Connection) Request(name string, offset int64, size uint32, hash []byte) ([]byte, error) {
+func (c *Connection) Request(repo string, name string, offset int64, size uint32, hash []byte) ([]byte, error) {
 	c.Lock()
 	if c.closed {
 		c.Unlock()
@@ -183,7 +184,7 @@ func (c *Connection) Request(name string, offset int64, size uint32, hash []byte
 	rc := make(chan asyncResult)
 	c.awaiting[c.nextId] = rc
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
-	c.mwriter.writeRequest(request{name, offset, size, hash})
+	c.mwriter.writeRequest(request{repo, name, offset, size, hash})
 	if c.mwriter.err != nil {
 		c.Unlock()
 		c.close(c.mwriter.err)
@@ -279,7 +280,8 @@ loop:
 
 		switch hdr.msgType {
 		case messageTypeIndex:
-			files := c.mreader.readIndex()
+			repo, files := c.mreader.readIndex()
+			_ = repo
 			if c.mreader.err != nil {
 				c.close(c.mreader.err)
 				break loop
@@ -291,7 +293,8 @@ loop:
 			c.Unlock()
 
 		case messageTypeIndexUpdate:
-			files := c.mreader.readIndex()
+			repo, files := c.mreader.readIndex()
+			_ = repo
 			if c.mreader.err != nil {
 				c.close(c.mreader.err)
 				break loop
@@ -370,7 +373,7 @@ loop:
 }
 
 func (c *Connection) processRequest(msgID int, req request) {
-	data, _ := c.receiver.Request(c.id, req.name, req.offset, req.size, req.hash)
+	data, _ := c.receiver.Request(c.id, req.repo, req.name, req.offset, req.size, req.hash)
 
 	c.Lock()
 	c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))

+ 9 - 6
protocol/protocol_test.go

@@ -84,8 +84,8 @@ func TestRequestResponseErr(t *testing.T) {
 	e := errors.New("Something broke")
 
 	var pass bool
-	for i := 0; i < 36; i++ {
-		for j := 0; j < 26; j++ {
+	for i := 0; i < 48; i++ {
+		for j := 0; j < 38; j++ {
 			m0 := &TestModel{data: []byte("response data")}
 			m1 := &TestModel{}
 
@@ -97,7 +97,7 @@ func TestRequestResponseErr(t *testing.T) {
 			NewConnection("c0", ar, ebw, m0, nil)
 			c1 := NewConnection("c1", br, eaw, m1, nil)
 
-			d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes"))
+			d, err := c1.Request("default", "tn", 1234, 3456, []byte("hashbytes"))
 			if err == e || err == ErrClosed {
 				t.Logf("Error at %d+%d bytes", i, j)
 				if !m1.closed {
@@ -115,6 +115,9 @@ func TestRequestResponseErr(t *testing.T) {
 			if string(d) != "response data" {
 				t.Errorf("Incorrect response data %q", string(d))
 			}
+			if m0.repo != "default" {
+				t.Error("Incorrect repo %q", m0.repo)
+			}
 			if m0.name != "tn" {
 				t.Error("Incorrect name %q", m0.name)
 			}
@@ -204,10 +207,10 @@ func TestClose(t *testing.T) {
 		t.Error("Ping should not return true")
 	}
 
-	c0.Index(nil)
-	c0.Index(nil)
+	c0.Index("default", nil)
+	c0.Index("default", nil)
 
-	_, err := c0.Request("foo", 0, 0, nil)
+	_, err := c0.Request("default", "foo", 0, 0, nil)
 	if err == nil {
 		t.Error("Request should return an error")
 	}