Răsfoiți Sursa

lib/protocol: Fix yet another deadlock (fixes #5678) (#5679)

* lib/protocol: Fix yet another deadlock (fixes #5678)

* more consistency

* read deadlock

* naming

* more naming
Simon Frei 6 ani în urmă
părinte
comite
ec7c88ca55
1 a modificat fișierele cu 52 adăugiri și 32 ștergeri
  1. 52 32
      lib/protocol/protocol.go

+ 52 - 32
lib/protocol/protocol.go

@@ -240,14 +240,21 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
 // Start creates the goroutines for sending and receiving of messages. It must
 // be called exactly once after creating a connection.
 func (c *rawConnection) Start() {
-	c.wg.Add(4)
+	c.startGoroutine(c.readerLoop)
+	c.startGoroutine(c.writerLoop)
+	c.startGoroutine(c.pingSender)
+	c.startGoroutine(c.pingReceiver)
+}
+
+func (c *rawConnection) startGoroutine(loop func() error) {
+	c.wg.Add(1)
 	go func() {
-		err := c.readerLoop()
-		c.internalClose(err)
+		err := loop()
+		c.wg.Done()
+		if err != nil && err != ErrClosed {
+			c.internalClose(err)
+		}
 	}()
-	go c.writerLoop()
-	go c.pingSender()
-	go c.pingReceiver()
 }
 
 func (c *rawConnection) ID() DeviceID {
@@ -363,25 +370,44 @@ func (c *rawConnection) ping() bool {
 	return c.send(&Ping{}, nil)
 }
 
-func (c *rawConnection) readerLoop() (err error) {
-	defer c.wg.Done()
+type messageWithError struct {
+	msg message
+	err error
+}
+
+func (c *rawConnection) readerLoop() error {
 	fourByteBuf := make([]byte, 4)
+	inbox := make(chan messageWithError)
+
+	// Reading from the wire may block until the underlying connection is closed.
+	go func() {
+		for {
+			msg, err := c.readMessage(fourByteBuf)
+			select {
+			case inbox <- messageWithError{msg: msg, err: err}:
+			case <-c.closed:
+				return
+			}
+		}
+	}()
+
 	state := stateInitial
+	var msgWithErr messageWithError
 	for {
-		if c.Closed() {
+		select {
+		case msgWithErr = <-inbox:
+		case <-c.closed:
 			return ErrClosed
 		}
-
-		msg, err := c.readMessage(fourByteBuf)
-		if err == errUnknownMessage {
-			// Unknown message types are skipped, for future extensibility.
-			continue
-		}
-		if err != nil {
-			return err
+		if msgWithErr.err != nil {
+			if msgWithErr.err == errUnknownMessage {
+				// Unknown message types are skipped, for future extensibility.
+				continue
+			}
+			return msgWithErr.err
 		}
 
-		switch msg := msg.(type) {
+		switch msg := msgWithErr.msg.(type) {
 		case *ClusterConfig:
 			l.Debugln("read ClusterConfig message")
 			if state != stateInitial {
@@ -660,8 +686,7 @@ func (c *rawConnection) send(msg message, done chan struct{}) (sent bool) {
 	}
 }
 
-func (c *rawConnection) writerLoop() {
-	defer c.wg.Done()
+func (c *rawConnection) writerLoop() error {
 	for {
 		select {
 		case hm := <-c.outbox:
@@ -670,12 +695,11 @@ func (c *rawConnection) writerLoop() {
 				close(hm.done)
 			}
 			if err != nil {
-				c.internalClose(err)
-				return
+				return err
 			}
 
 		case <-c.closed:
-			return
+			return ErrClosed
 		}
 	}
 }
@@ -882,9 +906,7 @@ func (c *rawConnection) internalClose(err error) {
 // PingSendInterval/2, we do nothing. Otherwise we send a ping message. This
 // results in an effecting ping interval of somewhere between
 // PingSendInterval/2 and PingSendInterval.
-func (c *rawConnection) pingSender() {
-	defer c.wg.Done()
-
+func (c *rawConnection) pingSender() error {
 	ticker := time.NewTicker(PingSendInterval / 2)
 	defer ticker.Stop()
 
@@ -901,7 +923,7 @@ func (c *rawConnection) pingSender() {
 			c.ping()
 
 		case <-c.closed:
-			return
+			return ErrClosed
 		}
 	}
 }
@@ -909,9 +931,7 @@ func (c *rawConnection) pingSender() {
 // The pingReceiver checks that we've received a message (any message will do,
 // but we expect pings in the absence of other messages) within the last
 // ReceiveTimeout. If not, we close the connection with an ErrTimeout.
-func (c *rawConnection) pingReceiver() {
-	defer c.wg.Done()
-
+func (c *rawConnection) pingReceiver() error {
 	ticker := time.NewTicker(ReceiveTimeout / 2)
 	defer ticker.Stop()
 
@@ -921,13 +941,13 @@ func (c *rawConnection) pingReceiver() {
 			d := time.Since(c.cr.Last())
 			if d > ReceiveTimeout {
 				l.Debugln(c.id, "ping timeout", d)
-				c.internalClose(ErrTimeout)
+				return ErrTimeout
 			}
 
 			l.Debugln(c.id, "last read within", d)
 
 		case <-c.closed:
-			return
+			return ErrClosed
 		}
 	}
 }