Browse Source

lib/protocol: Prioritize close msg and add close timeout (#5746)

Simon Frei 6 years ago
parent
commit
e39d3f95dd
3 changed files with 64 additions and 6 deletions
  1. 6 0
      lib/model/model_test.go
  2. 25 6
      lib/protocol/protocol.go
  3. 33 0
      lib/protocol/protocol_test.go

+ 6 - 0
lib/model/model_test.go

@@ -3266,6 +3266,12 @@ func TestSanitizePath(t *testing.T) {
 // on a protocol connection that has a blocking reader (blocking writer can't
 // be done as the test requires clusterconfigs to go through).
 func TestConnCloseOnRestart(t *testing.T) {
+	oldCloseTimeout := protocol.CloseTimeout
+	protocol.CloseTimeout = 100 * time.Millisecond
+	defer func() {
+		protocol.CloseTimeout = oldCloseTimeout
+	}()
+
 	w, fcfg := tmpDefaultWrapper()
 	m := setupModel(w)
 	defer cleanupModelAndRemoveDir(m, fcfg.Filesystem().URI())

+ 25 - 6
lib/protocol/protocol.go

@@ -184,6 +184,7 @@ type rawConnection struct {
 
 	inbox                 chan message
 	outbox                chan asyncMessage
+	closeBox              chan asyncMessage
 	clusterConfigBox      chan *ClusterConfig
 	dispatcherLoopStopped chan struct{}
 	closed                chan struct{}
@@ -218,6 +219,11 @@ const (
 	ReceiveTimeout = 300 * time.Second
 )
 
+// CloseTimeout is the longest we'll wait when trying to send the close
+// message before just closing the connection.
+// Should not be modified in production code, just for testing.
+var CloseTimeout = 10 * time.Second
+
 func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
 	cr := &countingReader{Reader: reader}
 	cw := &countingWriter{Writer: writer}
@@ -231,6 +237,7 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
 		awaiting:              make(map[int32]chan asyncResult),
 		inbox:                 make(chan message),
 		outbox:                make(chan asyncMessage),
+		closeBox:              make(chan asyncMessage),
 		clusterConfigBox:      make(chan *ClusterConfig),
 		dispatcherLoopStopped: make(chan struct{}),
 		closed:                make(chan struct{}),
@@ -671,6 +678,10 @@ func (c *rawConnection) writerLoop() {
 			c.internalClose(err)
 			return
 		}
+	case hm := <-c.closeBox:
+		_ = c.writeMessage(hm.msg)
+		close(hm.done)
+		return
 	case <-c.closed:
 		return
 	}
@@ -686,6 +697,11 @@ func (c *rawConnection) writerLoop() {
 				return
 			}
 
+		case hm := <-c.closeBox:
+			_ = c.writeMessage(hm.msg)
+			close(hm.done)
+			return
+
 		case <-c.closed:
 			return
 		}
@@ -853,17 +869,20 @@ func (c *rawConnection) shouldCompressMessage(msg message) bool {
 func (c *rawConnection) Close(err error) {
 	c.sendCloseOnce.Do(func() {
 		done := make(chan struct{})
-		c.send(&Close{err.Error()}, done)
+		timeout := time.NewTimer(CloseTimeout)
 		select {
-		case <-done:
+		case c.closeBox <- asyncMessage{&Close{err.Error()}, done}:
+			select {
+			case <-done:
+			case <-timeout.C:
+			case <-c.closed:
+			}
+		case <-timeout.C:
 		case <-c.closed:
 		}
 	})
 
-	// No more sends are necessary, therefore further steps to close the
-	// connection outside of this package can proceed immediately.
-	// And this prevents a potential deadlock due to calling c.receiver.Closed
-	go c.internalClose(err)
+	c.internalClose(err)
 }
 
 // internalClose is called if there is an unexpected error during normal operation.

+ 33 - 0
lib/protocol/protocol_test.go

@@ -86,6 +86,12 @@ func TestClose(t *testing.T) {
 // Close is called while the underlying connection is broken (send blocks).
 // https://github.com/syncthing/syncthing/pull/5442
 func TestCloseOnBlockingSend(t *testing.T) {
+	oldCloseTimeout := CloseTimeout
+	CloseTimeout = 100 * time.Millisecond
+	defer func() {
+		CloseTimeout = oldCloseTimeout
+	}()
+
 	m := newTestModel()
 
 	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
@@ -214,6 +220,33 @@ func TestClusterConfigFirst(t *testing.T) {
 	}
 }
 
+// TestCloseTimeout checks that calling Close times out and proceeds, if sending
+// the close message does not succeed.
+func TestCloseTimeout(t *testing.T) {
+	oldCloseTimeout := CloseTimeout
+	CloseTimeout = 100 * time.Millisecond
+	defer func() {
+		CloseTimeout = oldCloseTimeout
+	}()
+
+	m := newTestModel()
+
+	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c.Start()
+
+	done := make(chan struct{})
+	go func() {
+		c.Close(errManual)
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(5 * CloseTimeout):
+		t.Fatal("timed out before Close returned")
+	}
+}
+
 func TestMarshalIndexMessage(t *testing.T) {
 	if testing.Short() {
 		quickCfg.MaxCount = 10