Browse Source

Close on unknown message type

Jakob Borg 12 years ago
parent
commit
5c1db4f0f4
3 changed files with 33 additions and 5 deletions
  1. 5 5
      model_test.go
  2. 6 0
      protocol/protocol.go
  3. 22 0
      protocol/protocol_test.go

+ 5 - 5
model_test.go

@@ -97,7 +97,7 @@ func TestRemoteUpdateExisting(t *testing.T) {
 		Modified: time.Now().Unix(),
 		Blocks:   []protocol.BlockInfo{{100, []byte("some hash bytes")}},
 	}
-	m.Index(string("42"), []protocol.FileInfo{newFile})
+	m.Index("42", []protocol.FileInfo{newFile})
 
 	if l := len(m.need); l != 1 {
 		t.Errorf("Model missing Need for one file (%d != 1)", l)
@@ -114,7 +114,7 @@ func TestRemoteAddNew(t *testing.T) {
 		Modified: time.Now().Unix(),
 		Blocks:   []protocol.BlockInfo{{100, []byte("some hash bytes")}},
 	}
-	m.Index(string("42"), []protocol.FileInfo{newFile})
+	m.Index("42", []protocol.FileInfo{newFile})
 
 	if l1, l2 := len(m.need), 1; l1 != l2 {
 		t.Errorf("Model len(m.need) incorrect (%d != %d)", l1, l2)
@@ -132,7 +132,7 @@ func TestRemoteUpdateOld(t *testing.T) {
 		Modified: oldTimeStamp,
 		Blocks:   []protocol.BlockInfo{{100, []byte("some hash bytes")}},
 	}
-	m.Index(string("42"), []protocol.FileInfo{newFile})
+	m.Index("42", []protocol.FileInfo{newFile})
 
 	if l1, l2 := len(m.need), 0; l1 != l2 {
 		t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2)
@@ -249,7 +249,7 @@ func TestForgetNode(t *testing.T) {
 		Modified: time.Now().Unix(),
 		Blocks:   []protocol.BlockInfo{{100, []byte("some hash bytes")}},
 	}
-	m.Index(string("42"), []protocol.FileInfo{newFile})
+	m.Index("42", []protocol.FileInfo{newFile})
 
 	if l1, l2 := len(m.local), len(fs); l1 != l2 {
 		t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2)
@@ -261,7 +261,7 @@ func TestForgetNode(t *testing.T) {
 		t.Errorf("Model len(need) incorrect (%d != %d)", l1, l2)
 	}
 
-	m.Close(string("42"))
+	m.Close("42")
 
 	if l1, l2 := len(m.local), len(fs); l1 != l2 {
 		t.Errorf("Model len(local) incorrect (%d != %d)", l1, l2)

+ 6 - 0
protocol/protocol.go

@@ -4,6 +4,7 @@ import (
 	"compress/flate"
 	"errors"
 	"io"
+	"log"
 	"sync"
 	"time"
 
@@ -193,6 +194,7 @@ func (c *Connection) readerLoop() {
 			break
 		}
 		if hdr.version != 0 {
+			log.Printf("Protocol error: %s: unknown message version %#x", c.ID, hdr.version)
 			c.close()
 			break
 		}
@@ -258,6 +260,10 @@ func (c *Connection) readerLoop() {
 				delete(c.awaiting, hdr.msgID)
 				c.wLock.Unlock()
 			}
+
+		default:
+			log.Printf("Protocol error: %s: unknown message type %#x", c.ID, hdr.msgType)
+			c.close()
 		}
 	}
 }

+ 22 - 0
protocol/protocol_test.go

@@ -157,3 +157,25 @@ func TestVersionErr(t *testing.T) {
 		t.Error("Connection should close due to unknown version")
 	}
 }
+
+func TestTypeErr(t *testing.T) {
+	m0 := &TestModel{}
+	m1 := &TestModel{}
+
+	ar, aw := io.Pipe()
+	br, bw := io.Pipe()
+
+	c0 := NewConnection("c0", ar, bw, m0)
+	NewConnection("c1", br, aw, m1)
+
+	c0.mwriter.writeHeader(header{
+		version: 0,
+		msgID:   0,
+		msgType: 42,
+	})
+	c0.flush()
+
+	if !m1.closed {
+		t.Error("Connection should close due to unknown message type")
+	}
+}