Просмотр исходного кода

Propagate and log reason for connection close

Jakob Borg 12 лет назад
Родитель
Сommit
0f6b34160c
5 измененных файлов с 43 добавлено и 30 удалено
  1. 6 2
      model.go
  2. 1 1
      model_test.go
  3. 1 1
      protocol/common_test.go
  4. 34 25
      protocol/protocol.go
  5. 1 1
      protocol/protocol_test.go

+ 6 - 2
model.go

@@ -188,11 +188,15 @@ func (m *Model) SeedIndex(fs []protocol.FileInfo) {
 	m.printModelStats()
 }
 
-func (m *Model) Close(node string) {
+func (m *Model) Close(node string, err error) {
 	m.Lock()
 	defer m.Unlock()
 
-	infoln("Disconnected from node", node)
+	if err != nil {
+		warnf("Disconnected from node %s: %v", node, err)
+	} else {
+		infoln("Disconnected from node", node)
+	}
 
 	delete(m.remote, node)
 	delete(m.nodes, node)

+ 1 - 1
model_test.go

@@ -294,7 +294,7 @@ func TestForgetNode(t *testing.T) {
 		t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2)
 	}
 
-	m.Close("42")
+	m.Close("42", nil)
 
 	if l1, l2 := len(m.local), len(fs); l1 != l2 {
 		t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2)

+ 1 - 1
protocol/common_test.go

@@ -25,7 +25,7 @@ func (t *TestModel) Request(nodeID, name string, offset uint64, size uint32, has
 	return t.data, nil
 }
 
-func (t *TestModel) Close(nodeID string) {
+func (t *TestModel) Close(nodeID string, err error) {
 	t.closed = true
 }
 

+ 34 - 25
protocol/protocol.go

@@ -3,8 +3,8 @@ package protocol
 import (
 	"compress/flate"
 	"errors"
+	"fmt"
 	"io"
-	"log"
 	"sync"
 	"time"
 
@@ -40,7 +40,7 @@ type Model interface {
 	// A request was made by the peer node
 	Request(nodeID, name string, offset uint64, size uint32, hash []byte) ([]byte, error)
 	// The peer node closed the connection
-	Close(nodeID string)
+	Close(nodeID string, err error)
 }
 
 type Connection struct {
@@ -130,8 +130,11 @@ func (c *Connection) Index(idx []FileInfo) {
 	err := c.flush()
 	c.nextId = (c.nextId + 1) & 0xfff
 	c.Unlock()
-	if err != nil || c.mwriter.err != nil {
-		c.Close()
+	if err != nil {
+		c.Close(err)
+		return
+	} else if c.mwriter.err != nil {
+		c.Close(c.mwriter.err)
 		return
 	}
 }
@@ -149,13 +152,13 @@ func (c *Connection) Request(name string, offset uint64, size uint32, hash []byt
 	c.mwriter.writeRequest(request{name, offset, size, hash})
 	if c.mwriter.err != nil {
 		c.Unlock()
-		c.Close()
+		c.Close(c.mwriter.err)
 		return nil, c.mwriter.err
 	}
 	err := c.flush()
 	if err != nil {
 		c.Unlock()
-		c.Close()
+		c.Close(err)
 		return nil, err
 	}
 	c.nextId = (c.nextId + 1) & 0xfff
@@ -178,9 +181,13 @@ func (c *Connection) Ping() bool {
 	c.awaiting[c.nextId] = rc
 	c.mwriter.writeHeader(header{0, c.nextId, messageTypePing})
 	err := c.flush()
-	if err != nil || c.mwriter.err != nil {
+	if err != nil {
+		c.Unlock()
+		c.Close(err)
+		return false
+	} else if c.mwriter.err != nil {
 		c.Unlock()
-		c.Close()
+		c.Close(c.mwriter.err)
 		return false
 	}
 	c.nextId = (c.nextId + 1) & 0xfff
@@ -204,7 +211,7 @@ func (c *Connection) flush() error {
 	return nil
 }
 
-func (c *Connection) Close() {
+func (c *Connection) Close(err error) {
 	c.Lock()
 	if c.closed {
 		c.Unlock()
@@ -217,7 +224,7 @@ func (c *Connection) Close() {
 	c.awaiting = nil
 	c.Unlock()
 
-	c.receiver.Close(c.ID)
+	c.receiver.Close(c.ID, err)
 }
 
 func (c *Connection) isClosed() bool {
@@ -230,12 +237,11 @@ func (c *Connection) readerLoop() {
 	for !c.isClosed() {
 		hdr := c.mreader.readHeader()
 		if c.mreader.err != nil {
-			c.Close()
+			c.Close(c.mreader.err)
 			break
 		}
 		if hdr.version != 0 {
-			log.Printf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version)
-			c.Close()
+			c.Close(fmt.Errorf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version))
 			break
 		}
 
@@ -247,7 +253,7 @@ func (c *Connection) readerLoop() {
 		case messageTypeIndex:
 			files := c.mreader.readIndex()
 			if c.mreader.err != nil {
-				c.Close()
+				c.Close(c.mreader.err)
 			} else {
 				c.receiver.Index(c.ID, files)
 			}
@@ -255,7 +261,7 @@ func (c *Connection) readerLoop() {
 		case messageTypeIndexUpdate:
 			files := c.mreader.readIndex()
 			if c.mreader.err != nil {
-				c.Close()
+				c.Close(c.mreader.err)
 			} else {
 				c.receiver.IndexUpdate(c.ID, files)
 			}
@@ -267,7 +273,7 @@ func (c *Connection) readerLoop() {
 			data := c.mreader.readResponse()
 
 			if c.mreader.err != nil {
-				c.Close()
+				c.Close(c.mreader.err)
 			} else {
 				c.Lock()
 				rc, ok := c.awaiting[hdr.msgID]
@@ -285,8 +291,10 @@ func (c *Connection) readerLoop() {
 			c.mwriter.writeUint32(encodeHeader(header{0, hdr.msgID, messageTypePong}))
 			err := c.flush()
 			c.Unlock()
-			if err != nil || c.mwriter.err != nil {
-				c.Close()
+			if err != nil {
+				c.Close(err)
+			} else if c.mwriter.err != nil {
+				c.Close(c.mwriter.err)
 			}
 
 		case messageTypePong:
@@ -304,8 +312,7 @@ func (c *Connection) readerLoop() {
 			}
 
 		default:
-			log.Printf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType)
-			c.Close()
+			c.Close(fmt.Errorf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType))
 		}
 	}
 }
@@ -313,7 +320,7 @@ func (c *Connection) readerLoop() {
 func (c *Connection) processRequest(msgID int) {
 	req := c.mreader.readRequest()
 	if c.mreader.err != nil {
-		c.Close()
+		c.Close(c.mreader.err)
 	} else {
 		go func() {
 			data, _ := c.receiver.Request(c.ID, req.name, req.offset, req.size, req.hash)
@@ -323,8 +330,10 @@ func (c *Connection) processRequest(msgID int) {
 			err := c.flush()
 			c.Unlock()
 			buffers.Put(data)
-			if c.mwriter.err != nil || err != nil {
-				c.Close()
+			if err != nil {
+				c.Close(err)
+			} else if c.mwriter.err != nil {
+				c.Close(c.mwriter.err)
 			}
 		}()
 	}
@@ -340,10 +349,10 @@ func (c *Connection) pingerLoop() {
 		select {
 		case ok := <-rc:
 			if !ok {
-				c.Close()
+				c.Close(fmt.Errorf("Ping failure"))
 			}
 		case <-time.After(pingTimeout):
-			c.Close()
+			c.Close(fmt.Errorf("Ping timeout"))
 		}
 	}
 }

+ 1 - 1
protocol/protocol_test.go

@@ -190,7 +190,7 @@ func TestClose(t *testing.T) {
 	c0 := NewConnection("c0", ar, bw, m0)
 	NewConnection("c1", br, aw, m1)
 
-	c0.Close()
+	c0.Close(nil)
 
 	ok := c0.isClosed()
 	if !ok {