Browse Source

Error handling, testing

Jakob Borg 12 years ago
parent
commit
f5987fba32
6 changed files with 207 additions and 35 deletions
  1. 6 6
      model_test.go
  2. 49 0
      protocol/common_test.go
  3. 2 2
      protocol/marshal.go
  4. 49 26
      protocol/protocol.go
  5. 100 0
      protocol/protocol_test.go
  6. 1 1
      walk_test.go

+ 6 - 6
model_test.go

@@ -47,7 +47,7 @@ var testDataExpected = map[string]File{
 
 func TestUpdateLocal(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	if len(m.need) > 0 {
@@ -89,7 +89,7 @@ func TestUpdateLocal(t *testing.T) {
 
 func TestRemoteUpdateExisting(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	newFile := protocol.FileInfo{
@@ -106,7 +106,7 @@ func TestRemoteUpdateExisting(t *testing.T) {
 
 func TestRemoteAddNew(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	newFile := protocol.FileInfo{
@@ -123,7 +123,7 @@ func TestRemoteAddNew(t *testing.T) {
 
 func TestRemoteUpdateOld(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	oldTimeStamp := int64(1234)
@@ -141,7 +141,7 @@ func TestRemoteUpdateOld(t *testing.T) {
 
 func TestDelete(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	if l1, l2 := len(m.local), len(fs); l1 != l2 {
@@ -231,7 +231,7 @@ func TestDelete(t *testing.T) {
 
 func TestForgetNode(t *testing.T) {
 	m := NewModel("foo")
-	fs := Walk("testdata", m)
+	fs := Walk("testdata", m, false)
 	m.ReplaceLocal(fs)
 
 	if l1, l2 := len(m.local), len(fs); l1 != l2 {

+ 49 - 0
protocol/common_test.go

@@ -0,0 +1,49 @@
+package protocol
+
+import "io"
+
+type TestModel struct {
+	data   []byte
+	name   string
+	offset uint64
+	size   uint32
+	hash   []byte
+	closed bool
+}
+
+func (t *TestModel) Index(nodeID string, files []FileInfo) {
+}
+
+func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
+	t.name = name
+	t.offset = offset
+	t.size = size
+	t.hash = hash
+	return t.data, nil
+}
+
+func (t *TestModel) Close(nodeID string) {
+	t.closed = true
+}
+
+type ErrPipe struct {
+	io.PipeWriter
+	written int
+	max     int
+	err     error
+	closed  bool
+}
+
+func (e *ErrPipe) Write(data []byte) (int, error) {
+	if e.closed {
+		return 0, e.err
+	}
+	if e.written+len(data) > e.max {
+		n, _ := e.PipeWriter.Write(data[:e.max-e.written])
+		e.PipeWriter.CloseWithError(e.err)
+		e.closed = true
+		return n, e.err
+	} else {
+		return e.PipeWriter.Write(data)
+	}
+}

+ 2 - 2
protocol/marshal.go

@@ -35,8 +35,8 @@ func (w *marshalWriter) writeBytes(bs []byte) {
 		return
 	}
 	_, w.err = w.w.Write(bs)
-	if p := pad(len(bs)); p > 0 {
-		w.w.Write(padBytes[:p])
+	if p := pad(len(bs)); w.err == nil && p > 0 {
+		_, w.err = w.w.Write(padBytes[:p])
 	}
 	w.tot += len(bs) + pad(len(bs))
 }

+ 49 - 26
protocol/protocol.go

@@ -49,7 +49,6 @@ type Connection struct {
 	mwriter     *marshalWriter
 	wLock       sync.RWMutex
 	closed      bool
-	closedLock  sync.RWMutex
 	awaiting    map[int]chan asyncResult
 	nextId      int
 	lastReceive time.Time
@@ -74,13 +73,14 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 	}
 
 	c := Connection{
-		receiver: receiver,
-		reader:   flrd,
-		mreader:  &marshalReader{flrd, 0, nil},
-		writer:   flwr,
-		mwriter:  &marshalWriter{flwr, 0, nil},
-		awaiting: make(map[int]chan asyncResult),
-		ID:       nodeID,
+		receiver:    receiver,
+		reader:      flrd,
+		mreader:     &marshalReader{flrd, 0, nil},
+		writer:      flwr,
+		mwriter:     &marshalWriter{flwr, 0, nil},
+		awaiting:    make(map[int]chan asyncResult),
+		lastReceive: time.Now(),
+		ID:          nodeID,
 	}
 
 	go c.readerLoop()
@@ -92,12 +92,15 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 // Index writes the list of file information to the connected peer node
 func (c *Connection) Index(idx []FileInfo) {
 	c.wLock.Lock()
-	defer c.wLock.Unlock()
-
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypeIndex})
-	c.nextId = (c.nextId + 1) & 0xfff
 	c.mwriter.writeIndex(idx)
-	c.flush()
+	err := c.flush()
+	c.nextId = (c.nextId + 1) & 0xfff
+	c.wLock.Unlock()
+	if err != nil || c.mwriter.err != nil {
+		c.close()
+		return
+	}
 }
 
 // Request returns the bytes for the specified block after fetching them from the connected peer.
@@ -107,7 +110,17 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt
 	c.awaiting[c.nextId] = rc
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypeRequest})
 	c.mwriter.writeRequest(request{name, offset, size, hash})
-	c.flush()
+	if c.mwriter.err != nil {
+		c.wLock.Unlock()
+		c.close()
+		return nil, c.mwriter.err
+	}
+	err := c.flush()
+	if err != nil {
+		c.wLock.Unlock()
+		c.close()
+		return nil, err
+	}
 	c.nextId = (c.nextId + 1) & 0xfff
 	c.wLock.Unlock()
 
@@ -123,7 +136,12 @@ func (c *Connection) Ping() bool {
 	rc := make(chan asyncResult)
 	c.awaiting[c.nextId] = rc
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypePing})
-	c.flush()
+	err := c.flush()
+	if err != nil || c.mwriter.err != nil {
+		c.wLock.Unlock()
+		c.close()
+		return false
+	}
 	c.nextId = (c.nextId + 1) & 0xfff
 	c.wLock.Unlock()
 
@@ -138,18 +156,20 @@ type flusher interface {
 	Flush() error
 }
 
-func (c *Connection) flush() {
+func (c *Connection) flush() error {
 	if f, ok := c.writer.(flusher); ok {
-		f.Flush()
+		return f.Flush()
 	}
+	return nil
 }
 
 func (c *Connection) close() {
-	c.closedLock.Lock()
-	c.closed = true
-	c.closedLock.Unlock()
-
 	c.wLock.Lock()
+	if c.closed {
+		c.wLock.Unlock()
+		return
+	}
+	c.closed = true
 	for _, ch := range c.awaiting {
 		close(ch)
 	}
@@ -160,8 +180,8 @@ func (c *Connection) close() {
 }
 
 func (c *Connection) isClosed() bool {
-	c.closedLock.RLock()
-	defer c.closedLock.RUnlock()
+	c.wLock.RLock()
+	defer c.wLock.RUnlock()
 	return c.closed
 }
 
@@ -215,9 +235,9 @@ func (c *Connection) readerLoop() {
 		case messageTypePing:
 			c.wLock.Lock()
 			c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
-			c.flush()
+			err := c.flush()
 			c.wLock.Unlock()
-			if c.mwriter.err != nil {
+			if err != nil || c.mwriter.err != nil {
 				c.close()
 			}
 
@@ -248,9 +268,12 @@ func (c *Connection) processRequest(msgID int) {
 			c.wLock.Lock()
 			c.mwriter.writeUint32(encodeHeader(header{0, msgID, messageTypeResponse}))
 			c.mwriter.writeResponse(data)
-			buffers.Put(data)
-			c.flush()
+			err := c.flush()
 			c.wLock.Unlock()
+			buffers.Put(data)
+			if c.mwriter.err != nil || err != nil {
+				c.close()
+			}
 		}()
 	}
 }

+ 100 - 0
protocol/protocol_test.go

@@ -1,8 +1,11 @@
 package protocol
 
 import (
+	"errors"
+	"io"
 	"testing"
 	"testing/quick"
+	"time"
 )
 
 func TestHeaderFunctions(t *testing.T) {
@@ -35,3 +38,100 @@ func TestPad(t *testing.T) {
 		}
 	}
 }
+
+func TestPing(t *testing.T) {
+	ar, aw := io.Pipe()
+	br, bw := io.Pipe()
+
+	c0 := NewConnection("c0", ar, bw, nil)
+	c1 := NewConnection("c1", br, aw, nil)
+
+	if !c0.Ping() {
+		t.Error("c0 ping failed")
+	}
+	if !c1.Ping() {
+		t.Error("c1 ping failed")
+	}
+}
+
+func TestPingErr(t *testing.T) {
+	e := errors.New("Something broke")
+
+	for i := 0; i < 12; i++ {
+		for j := 0; j < 12; j++ {
+			m0 := &TestModel{}
+			m1 := &TestModel{}
+
+			ar, aw := io.Pipe()
+			br, bw := io.Pipe()
+			eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
+			ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
+
+			c0 := NewConnection("c0", ar, ebw, m0)
+			NewConnection("c1", br, eaw, m1)
+
+			res := c0.Ping()
+			if (i < 4 || j < 4) && res {
+				t.Errorf("Unexpected ping success; i=%d, j=%d", i, j)
+			} else if (i >= 8 && j >= 8) && !res {
+				t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j)
+			}
+		}
+	}
+}
+
+func TestRequestResponseErr(t *testing.T) {
+	e := errors.New("Something broke")
+
+	var pass bool
+	for i := 0; i < 36; i++ {
+		for j := 0; j < 26; j++ {
+			m0 := &TestModel{data: []byte("response data")}
+			m1 := &TestModel{}
+
+			ar, aw := io.Pipe()
+			br, bw := io.Pipe()
+			eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
+			ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
+
+			NewConnection("c0", ar, ebw, m0)
+			c1 := NewConnection("c1", br, eaw, m1)
+
+			d, err := c1.Request("tn", 1234, 3456, []byte("hashbytes"))
+			if err == e || err == ErrClosed {
+				t.Logf("Error at %d+%d bytes", i, j)
+				if !m1.closed {
+					t.Error("c1 not closed")
+				}
+				time.Sleep(1 * time.Millisecond)
+				if !m0.closed {
+					t.Error("c0 not closed")
+				}
+				continue
+			}
+			if err != nil {
+				t.Error(err)
+			}
+			if string(d) != "response data" {
+				t.Errorf("Incorrect response data %q", string(d))
+			}
+			if m0.name != "tn" {
+				t.Error("Incorrect name %q", m0.name)
+			}
+			if m0.offset != 1234 {
+				t.Error("Incorrect offset %d", m0.offset)
+			}
+			if m0.size != 3456 {
+				t.Error("Incorrect size %d", m0.size)
+			}
+			if string(m0.hash) != "hashbytes" {
+				t.Error("Incorrect hash %q", m0.hash)
+			}
+			t.Logf("Pass at %d+%d bytes", i, j)
+			pass = true
+		}
+	}
+	if !pass {
+		t.Error("Never passed")
+	}
+}

+ 1 - 1
walk_test.go

@@ -18,7 +18,7 @@ var testdata = []struct {
 
 func TestWalk(t *testing.T) {
 	m := new(Model)
-	files := Walk("testdata", m)
+	files := Walk("testdata", m, false)
 
 	if l1, l2 := len(files), len(testdata); l1 != l2 {
 		t.Fatalf("Incorrect number of walked files %d != %d", l1, l2)