Browse Source

Handle calls on closed connection

Jakob Borg 12 years ago
parent
commit
454e672d42
3 changed files with 42 additions and 0 deletions
  1. 3 0
      protocol/common_test.go
  2. 6 0
      protocol/protocol.go
  3. 33 0
      protocol/protocol_test.go

+ 3 - 0
protocol/common_test.go

@@ -14,6 +14,9 @@ type TestModel struct {
 func (t *TestModel) Index(nodeID string, files []FileInfo) {
 }
 
+func (t *TestModel) IndexUpdate(nodeID string, files []FileInfo) {
+}
+
 func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
 	t.name = name
 	t.offset = offset

+ 6 - 0
protocol/protocol.go

@@ -134,6 +134,9 @@ func (c *Connection) Index(idx []FileInfo) {
 
 // Request returns the bytes for the specified block after fetching them from the connected peer.
 func (c *Connection) Request(name string, offset uint64, size uint32, hash []byte) ([]byte, error) {
+	if c.isClosed() {
+		return nil, ErrClosed
+	}
 	c.Lock()
 	rc := make(chan asyncResult)
 	c.awaiting[c.nextId] = rc
@@ -161,6 +164,9 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt
 }
 
 func (c *Connection) Ping() (time.Duration, bool) {
+	if c.isClosed() {
+		return 0, false
+	}
 	c.Lock()
 	rc := make(chan asyncResult)
 	c.awaiting[c.nextId] = rc

+ 33 - 0
protocol/protocol_test.go

@@ -179,3 +179,36 @@ func TestTypeErr(t *testing.T) {
 		t.Error("Connection should close due to unknown message type")
 	}
 }
+
+func TestClose(t *testing.T) {
+	m0 := &TestModel{}
+	m1 := &TestModel{}
+
+	ar, aw := io.Pipe()
+	br, bw := io.Pipe()
+
+	c0 := NewConnection("c0", ar, bw, m0)
+	NewConnection("c1", br, aw, m1)
+
+	c0.close()
+
+	ok := c0.isClosed()
+	if !ok {
+		t.Fatal("Connection should be closed")
+	}
+
+	// None of these should panic, some should return an error
+
+	_, ok = c0.Ping()
+	if ok {
+		t.Error("Ping should not return true")
+	}
+
+	c0.Index(nil)
+	c0.Index(nil)
+
+	_, err := c0.Request("foo", 0, 0, nil)
+	if err == nil {
+		t.Error("Request should return an error")
+	}
+}