Răsfoiți Sursa

lib/protocol: Send Close message on read error (#7141)

Simon Frei 5 ani în urmă
părinte
comite
bbb22c8c80
3 a modificat fișierele cu 94 adăugiri și 23 ștergeri
  1. 27 10
      lib/protocol/protocol.go
  2. 36 5
      lib/protocol/protocol_test.go
  3. 31 8
      lib/testutils/testutils.go

+ 27 - 10
lib/protocol/protocol.go

@@ -170,6 +170,8 @@ type rawConnection struct {
 	closeOnce             sync.Once
 	sendCloseOnce         sync.Once
 	compression           Compression
+
+	loopWG sync.WaitGroup // Need to ensure no leftover routines in testing
 }
 
 type asyncResult struct {
@@ -244,20 +246,35 @@ func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, rec
 		dispatcherLoopStopped: make(chan struct{}),
 		closed:                make(chan struct{}),
 		compression:           compress,
+		loopWG:                sync.WaitGroup{},
 	}
 }
 
 // Start creates the goroutines for sending and receiving of messages. It must
 // be called exactly once after creating a connection.
 func (c *rawConnection) Start() {
-	go c.readerLoop()
+	c.loopWG.Add(5)
+	go func() {
+		c.readerLoop()
+		c.loopWG.Done()
+	}()
 	go func() {
 		err := c.dispatcherLoop()
-		c.internalClose(err)
+		c.Close(err)
+		c.loopWG.Done()
+	}()
+	go func() {
+		c.writerLoop()
+		c.loopWG.Done()
+	}()
+	go func() {
+		c.pingSender()
+		c.loopWG.Done()
+	}()
+	go func() {
+		c.pingReceiver()
+		c.loopWG.Done()
 	}()
-	go c.writerLoop()
-	go c.pingSender()
-	go c.pingReceiver()
 	c.startTime = time.Now()
 }
 
@@ -410,7 +427,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 				state = stateReady
 			}
 			if err := c.receiver.ClusterConfig(c.id, *msg); err != nil {
-				return errors.Wrap(err, "receiver error")
+				return fmt.Errorf("receiving cluster config: %w", err)
 			}
 
 		case *Index:
@@ -422,7 +439,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 				return errors.Wrap(err, "protocol error: index")
 			}
 			if err := c.handleIndex(*msg); err != nil {
-				return errors.Wrap(err, "receiver error")
+				return fmt.Errorf("receiving index: %w", err)
 			}
 			state = stateReady
 
@@ -435,7 +452,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 				return errors.Wrap(err, "protocol error: index update")
 			}
 			if err := c.handleIndexUpdate(*msg); err != nil {
-				return errors.Wrap(err, "receiver error")
+				return fmt.Errorf("receiving index update: %w", err)
 			}
 			state = stateReady
 
@@ -462,7 +479,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 				return fmt.Errorf("protocol error: response message in state %d", state)
 			}
 			if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil {
-				return errors.Wrap(err, "receiver error")
+				return fmt.Errorf("receiving download progress: %w", err)
 			}
 
 		case *Ping:
@@ -474,7 +491,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 
 		case *Close:
 			l.Debugln("read Close message")
-			return errors.New(msg.Reason)
+			return fmt.Errorf("closed by remote: %v", msg.Reason)
 
 		default:
 			l.Debugf("read unknown message: %+T", msg)

+ 36 - 5
lib/protocol/protocol_test.go

@@ -33,8 +33,10 @@ func TestPing(t *testing.T) {
 
 	c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
+	defer closeAndWait(c0, ar, bw)
 	c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c1.Start()
+	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
 	c1.ClusterConfig(ClusterConfig{})
 
@@ -57,8 +59,10 @@ func TestClose(t *testing.T) {
 
 	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
+	defer closeAndWait(c0, ar, bw)
 	c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways)
 	c1.Start()
+	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
 	c1.ClusterConfig(ClusterConfig{})
 
@@ -97,8 +101,10 @@ func TestCloseOnBlockingSend(t *testing.T) {
 
 	m := newTestModel()
 
-	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	rw := testutils.NewBlockingRW()
+	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
+	defer closeAndWait(c, rw)
 
 	wg := sync.WaitGroup{}
 
@@ -149,8 +155,10 @@ func TestCloseRace(t *testing.T) {
 
 	c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
+	defer closeAndWait(c0, ar, bw)
 	c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever)
 	c1.Start()
+	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
 	c1.ClusterConfig(ClusterConfig{})
 
@@ -184,8 +192,10 @@ func TestCloseRace(t *testing.T) {
 func TestClusterConfigFirst(t *testing.T) {
 	m := newTestModel()
 
-	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	rw := testutils.NewBlockingRW()
+	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
+	defer closeAndWait(c, rw)
 
 	select {
 	case c.outbox <- asyncMessage{&Ping{}, nil}:
@@ -234,8 +244,10 @@ func TestCloseTimeout(t *testing.T) {
 
 	m := newTestModel()
 
-	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	rw := testutils.NewBlockingRW()
+	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
+	defer closeAndWait(c, rw)
 
 	done := make(chan struct{})
 	go func() {
@@ -852,8 +864,10 @@ func TestSha256OfEmptyBlock(t *testing.T) {
 func TestClusterConfigAfterClose(t *testing.T) {
 	m := newTestModel()
 
-	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	rw := testutils.NewBlockingRW()
+	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
+	defer closeAndWait(c, rw)
 
 	c.internalClose(errManual)
 
@@ -874,11 +888,13 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
 	// Verify that we don't deadlock when calling Close() from within one of
 	// the model callbacks (ClusterConfig).
 	m := newTestModel()
-	c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	rw := testutils.NewBlockingRW()
+	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	m.ccFn = func(devID DeviceID, cc ClusterConfig) {
 		c.Close(errManual)
 	}
 	c.Start()
+	defer closeAndWait(c, rw)
 
 	c.inbox <- &ClusterConfig{}
 
@@ -945,3 +961,18 @@ func TestIndexIDString(t *testing.T) {
 		t.Error(i.String())
 	}
 }
+
+func closeAndWait(c Connection, closers ...io.Closer) {
+	for _, closer := range closers {
+		closer.Close()
+	}
+	var raw *rawConnection
+	switch i := c.(type) {
+	case wireFormatConnection:
+		raw = i.Connection.(*rawConnection)
+	case *rawConnection:
+		raw = i
+	}
+	raw.internalClose(ErrClosed)
+	raw.loopWG.Wait()
+}

+ 31 - 8
lib/testutils/testutils.go

@@ -6,17 +6,40 @@
 
 package testutils
 
-// BlockingRW implements io.Reader and Writer but never returns when called
-type BlockingRW struct{ nilChan chan struct{} }
+import (
+	"errors"
+	"sync"
+)
 
-func (rw *BlockingRW) Read(p []byte) (n int, err error) {
-	<-rw.nilChan
-	return
+var ErrClosed = errors.New("closed")
+
+// BlockingRW implements io.Reader, Writer and Closer, but only returns when closed
+type BlockingRW struct {
+	c         chan struct{}
+	closeOnce sync.Once
+}
+
+func NewBlockingRW() *BlockingRW {
+	return &BlockingRW{
+		c:         make(chan struct{}),
+		closeOnce: sync.Once{},
+	}
+}
+func (rw *BlockingRW) Read(p []byte) (int, error) {
+	<-rw.c
+	return 0, ErrClosed
+}
+
+func (rw *BlockingRW) Write(p []byte) (int, error) {
+	<-rw.c
+	return 0, ErrClosed
 }
 
-func (rw *BlockingRW) Write(p []byte) (n int, err error) {
-	<-rw.nilChan
-	return
+func (rw *BlockingRW) Close() error {
+	rw.closeOnce.Do(func() {
+		close(rw.c)
+	})
+	return nil
 }
 
 // NoopRW implements io.Reader and Writer but never returns when called