Browse Source

Add more fine grained compression control

Jakob Borg 10 years ago
parent
commit
108b4e2e10
4 changed files with 135 additions and 27 deletions
  1. 53 0
      compression.go
  2. 51 0
      compression_test.go
  3. 21 17
      protocol.go
  4. 10 10
      protocol_test.go

+ 53 - 0
compression.go

@@ -0,0 +1,53 @@
+// Copyright (C) 2015 The Protocol Authors.
+
+package protocol
+
+import "fmt"
+
+type Compression int
+
+const (
+	CompressMetadata Compression = iota // zero value is the default, default should be "metadata"
+	CompressNever
+	CompressAlways
+
+	compressionThreshold = 128 // don't bother compressing messages smaller than this many bytes
+)
+
+var compressionMarshal = map[Compression]string{
+	CompressNever:    "never",
+	CompressMetadata: "metadata",
+	CompressAlways:   "always",
+}
+
+var compressionUnmarshal = map[string]Compression{
+	// Legacy
+	"false": CompressNever,
+	"true":  CompressMetadata,
+
+	// Current
+	"never":    CompressNever,
+	"metadata": CompressMetadata,
+	"always":   CompressAlways,
+}
+
+func (c Compression) String() string {
+	s, ok := compressionMarshal[c]
+	if !ok {
+		return fmt.Sprintf("unknown:%d", c)
+	}
+	return s
+}
+
+func (c Compression) GoString() string {
+	return fmt.Sprintf("%q", c.String())
+}
+
+func (c Compression) MarshalText() ([]byte, error) {
+	return []byte(compressionMarshal[c]), nil
+}
+
+func (c *Compression) UnmarshalText(bs []byte) error {
+	*c = compressionUnmarshal[string(bs)]
+	return nil
+}

+ 51 - 0
compression_test.go

@@ -0,0 +1,51 @@
+// Copyright (C) 2015 The Protocol Authors.
+
+package protocol
+
+import "testing"
+
+func TestCompressionMarshal(t *testing.T) {
+	uTestcases := []struct {
+		s string
+		c Compression
+	}{
+		{"true", CompressMetadata},
+		{"false", CompressNever},
+		{"never", CompressNever},
+		{"metadata", CompressMetadata},
+		{"filedata", CompressFiledata},
+		{"always", CompressAlways},
+		{"whatever", CompressNever},
+	}
+
+	mTestcases := []struct {
+		s string
+		c Compression
+	}{
+		{"never", CompressNever},
+		{"metadata", CompressMetadata},
+		{"filedata", CompressFiledata},
+		{"always", CompressAlways},
+	}
+
+	var c Compression
+	for _, tc := range uTestcases {
+		err := c.UnmarshalText([]byte(tc.s))
+		if err != nil {
+			t.Error(err)
+		}
+		if c != tc.c {
+			t.Errorf("%s unmarshalled to %d, not %d", tc.s, c, tc.c)
+		}
+	}
+
+	for _, tc := range mTestcases {
+		bs, err := tc.c.MarshalText()
+		if err != nil {
+			t.Error(err)
+		}
+		if s := string(bs); s != tc.s {
+			t.Errorf("%d marshalled to %q, not %q", tc.c, s, tc.s)
+		}
+	}
+}

+ 21 - 17
protocol.go

@@ -106,7 +106,7 @@ type rawConnection struct {
 	closed chan struct{}
 	once   sync.Once
 
-	compressionThreshold int // compress messages larger than this many bytes
+	compression Compression
 
 	rdbuf0 []byte // used & reused by readMessage
 	rdbuf1 []byte // used & reused by readMessage
@@ -135,25 +135,21 @@ const (
 	pingIdleTime = 60 * time.Second
 )
 
-func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress bool) Connection {
+func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
 	cr := &countingReader{Reader: reader}
 	cw := &countingWriter{Writer: writer}
 
-	compThres := 1<<31 - 1 // compression disabled
-	if compress {
-		compThres = 128 // compress messages that are 128 bytes long or larger
-	}
 	c := rawConnection{
-		id:                   deviceID,
-		name:                 name,
-		receiver:             nativeModel{receiver},
-		state:                stateInitial,
-		cr:                   cr,
-		cw:                   cw,
-		outbox:               make(chan hdrMsg),
-		nextID:               make(chan int),
-		closed:               make(chan struct{}),
-		compressionThreshold: compThres,
+		id:          deviceID,
+		name:        name,
+		receiver:    nativeModel{receiver},
+		state:       stateInitial,
+		cr:          cr,
+		cw:          cw,
+		outbox:      make(chan hdrMsg),
+		nextID:      make(chan int),
+		closed:      make(chan struct{}),
+		compression: compress,
 	}
 
 	go c.readerLoop()
@@ -571,7 +567,15 @@ func (c *rawConnection) writerLoop() {
 					return
 				}
 
-				if len(uncBuf) >= c.compressionThreshold {
+				compress := false
+				switch c.compression {
+				case CompressAlways:
+					compress = true
+				case CompressMetadata:
+					compress = hm.hdr.msgType != messageTypeResponse
+				}
+
+				if compress && len(uncBuf) >= compressionThreshold {
 					// Use compression for large messages
 					hm.hdr.compression = true
 

+ 10 - 10
protocol_test.go

@@ -67,8 +67,8 @@ func TestPing(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, nil, "name", true).(wireFormatConnection).next.(*rawConnection)
-	c1 := NewConnection(c1ID, br, aw, nil, "name", true).(wireFormatConnection).next.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c1 := NewConnection(c1ID, br, aw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
 
 	if ok := c0.ping(); !ok {
 		t.Error("c0 ping failed")
@@ -91,8 +91,8 @@ func TestPingErr(t *testing.T) {
 			eaw := &ErrPipe{PipeWriter: *aw, max: i, err: e}
 			ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
 
-			c0 := NewConnection(c0ID, ar, ebw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
-			NewConnection(c1ID, br, eaw, m1, "name", true)
+			c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+			NewConnection(c1ID, br, eaw, m1, "name", CompressAlways)
 
 			res := c0.ping()
 			if (i < 8 || j < 8) && res {
@@ -167,8 +167,8 @@ func TestVersionErr(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
-	NewConnection(c1ID, br, aw, m1, "name", true)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 
 	w := xdr.NewWriter(c0.cw)
 	w.WriteUint32(encodeHeader(header{
@@ -190,8 +190,8 @@ func TestTypeErr(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
-	NewConnection(c1ID, br, aw, m1, "name", true)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 
 	w := xdr.NewWriter(c0.cw)
 	w.WriteUint32(encodeHeader(header{
@@ -213,8 +213,8 @@ func TestClose(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", true).(wireFormatConnection).next.(*rawConnection)
-	NewConnection(c1ID, br, aw, m1, "name", true)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 
 	c0.close(nil)