Browse Source

Mend protocol tests, for sure

Jakob Borg 9 years ago
parent
commit
357089a438
1 changed files with 30 additions and 7 deletions
  1. 30 7
      lib/protocol/protocol_test.go

+ 30 - 7
lib/protocol/protocol_test.go

@@ -4,6 +4,7 @@ package protocol
 
 import (
 	"bytes"
+	"encoding/binary"
 	"encoding/hex"
 	"encoding/json"
 	"errors"
@@ -16,6 +17,7 @@ import (
 	"strings"
 	"testing"
 	"testing/quick"
+	"time"
 
 	"github.com/calmh/xdr"
 )
@@ -107,12 +109,11 @@ func TestVersionErr(t *testing.T) {
 	c1.ClusterConfig(ClusterConfigMessage{})
 
 	w := xdr.NewWriter(c0.cw)
-	w.WriteUint32(encodeHeader(header{
+	timeoutWriteHeader(w, header{
 		version: 2, // higher than supported
 		msgID:   0,
 		msgType: messageTypeIndex,
-	}))
-	w.WriteUint32(0) // Avoids reader closing due to EOF
+	})
 
 	if err := m1.closedError(); err == nil || !strings.Contains(err.Error(), "unknown protocol version") {
 		t.Error("Connection should close due to unknown version, not", err)
@@ -134,12 +135,11 @@ func TestTypeErr(t *testing.T) {
 	c1.ClusterConfig(ClusterConfigMessage{})
 
 	w := xdr.NewWriter(c0.cw)
-	w.WriteUint32(encodeHeader(header{
+	timeoutWriteHeader(w, header{
 		version: 0,
 		msgID:   0,
-		msgType: 42,
-	}))
-	w.WriteUint32(0) // Avoids reader closing due to EOF
+		msgType: 42, // unknown type
+	})
 
 	if err := m1.closedError(); err == nil || !strings.Contains(err.Error(), "unknown message type") {
 		t.Error("Connection should close due to unknown message type, not", err)
@@ -298,3 +298,26 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
 	}
 	return ok
 }
+
+func timeoutWriteHeader(w *xdr.Writer, hdr header) {
+	// This tries to write a message header to w, but times out after a while.
+	// This is useful because in testing, with a PipeWriter, it will block
+	// forever if the other side isn't reading any more. On the other hand we
+	// can't just "go" it into the background, because if the other side is
+	// still there we should wait for the write to complete. Yay.
+
+	var buf [8]byte // header and message length
+	binary.BigEndian.PutUint32(buf[:], encodeHeader(hdr))
+	binary.BigEndian.PutUint32(buf[4:], 0) // zero message length, explicitly
+
+	done := make(chan struct{})
+	go func() {
+		w.WriteRaw(buf[:])
+		l.Infoln("write completed")
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-time.After(250 * time.Millisecond):
+	}
+}