|
|
@@ -240,14 +240,21 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
|
|
|
// Start creates the goroutines for sending and receiving of messages. It must
|
|
|
// be called exactly once after creating a connection.
|
|
|
func (c *rawConnection) Start() {
|
|
|
- c.wg.Add(4)
|
|
|
+ c.startGoroutine(c.readerLoop)
|
|
|
+ c.startGoroutine(c.writerLoop)
|
|
|
+ c.startGoroutine(c.pingSender)
|
|
|
+ c.startGoroutine(c.pingReceiver)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *rawConnection) startGoroutine(loop func() error) {
|
|
|
+ c.wg.Add(1)
|
|
|
go func() {
|
|
|
- err := c.readerLoop()
|
|
|
- c.internalClose(err)
|
|
|
+ err := loop()
|
|
|
+ c.wg.Done()
|
|
|
+ if err != nil && err != ErrClosed {
|
|
|
+ c.internalClose(err)
|
|
|
+ }
|
|
|
}()
|
|
|
- go c.writerLoop()
|
|
|
- go c.pingSender()
|
|
|
- go c.pingReceiver()
|
|
|
}
|
|
|
|
|
|
func (c *rawConnection) ID() DeviceID {
|
|
|
@@ -363,25 +370,44 @@ func (c *rawConnection) ping() bool {
|
|
|
return c.send(&Ping{}, nil)
|
|
|
}
|
|
|
|
|
|
-func (c *rawConnection) readerLoop() (err error) {
|
|
|
- defer c.wg.Done()
|
|
|
+type messageWithError struct {
|
|
|
+ msg message
|
|
|
+ err error
|
|
|
+}
|
|
|
+
|
|
|
+func (c *rawConnection) readerLoop() error {
|
|
|
fourByteBuf := make([]byte, 4)
|
|
|
+ inbox := make(chan messageWithError)
|
|
|
+
|
|
|
+ // Reading from the wire may block until the underlying connection is closed.
|
|
|
+ go func() {
|
|
|
+ for {
|
|
|
+ msg, err := c.readMessage(fourByteBuf)
|
|
|
+ select {
|
|
|
+ case inbox <- messageWithError{msg: msg, err: err}:
|
|
|
+ case <-c.closed:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
state := stateInitial
|
|
|
+ var msgWithErr messageWithError
|
|
|
for {
|
|
|
- if c.Closed() {
|
|
|
+ select {
|
|
|
+ case msgWithErr = <-inbox:
|
|
|
+ case <-c.closed:
|
|
|
return ErrClosed
|
|
|
}
|
|
|
-
|
|
|
- msg, err := c.readMessage(fourByteBuf)
|
|
|
- if err == errUnknownMessage {
|
|
|
- // Unknown message types are skipped, for future extensibility.
|
|
|
- continue
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ if msgWithErr.err != nil {
|
|
|
+ if msgWithErr.err == errUnknownMessage {
|
|
|
+ // Unknown message types are skipped, for future extensibility.
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return msgWithErr.err
|
|
|
}
|
|
|
|
|
|
- switch msg := msg.(type) {
|
|
|
+ switch msg := msgWithErr.msg.(type) {
|
|
|
case *ClusterConfig:
|
|
|
l.Debugln("read ClusterConfig message")
|
|
|
if state != stateInitial {
|
|
|
@@ -660,8 +686,7 @@ func (c *rawConnection) send(msg message, done chan struct{}) (sent bool) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *rawConnection) writerLoop() {
|
|
|
- defer c.wg.Done()
|
|
|
+func (c *rawConnection) writerLoop() error {
|
|
|
for {
|
|
|
select {
|
|
|
case hm := <-c.outbox:
|
|
|
@@ -670,12 +695,11 @@ func (c *rawConnection) writerLoop() {
|
|
|
close(hm.done)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- c.internalClose(err)
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
case <-c.closed:
|
|
|
- return
|
|
|
+ return ErrClosed
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -882,9 +906,7 @@ func (c *rawConnection) internalClose(err error) {
|
|
|
// PingSendInterval/2, we do nothing. Otherwise we send a ping message. This
|
|
|
// results in an effecting ping interval of somewhere between
|
|
|
// PingSendInterval/2 and PingSendInterval.
|
|
|
-func (c *rawConnection) pingSender() {
|
|
|
- defer c.wg.Done()
|
|
|
-
|
|
|
+func (c *rawConnection) pingSender() error {
|
|
|
ticker := time.NewTicker(PingSendInterval / 2)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
@@ -901,7 +923,7 @@ func (c *rawConnection) pingSender() {
|
|
|
c.ping()
|
|
|
|
|
|
case <-c.closed:
|
|
|
- return
|
|
|
+ return ErrClosed
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -909,9 +931,7 @@ func (c *rawConnection) pingSender() {
|
|
|
// The pingReceiver checks that we've received a message (any message will do,
|
|
|
// but we expect pings in the absence of other messages) within the last
|
|
|
// ReceiveTimeout. If not, we close the connection with an ErrTimeout.
|
|
|
-func (c *rawConnection) pingReceiver() {
|
|
|
- defer c.wg.Done()
|
|
|
-
|
|
|
+func (c *rawConnection) pingReceiver() error {
|
|
|
ticker := time.NewTicker(ReceiveTimeout / 2)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
@@ -921,13 +941,13 @@ func (c *rawConnection) pingReceiver() {
|
|
|
d := time.Since(c.cr.Last())
|
|
|
if d > ReceiveTimeout {
|
|
|
l.Debugln(c.id, "ping timeout", d)
|
|
|
- c.internalClose(ErrTimeout)
|
|
|
+ return ErrTimeout
|
|
|
}
|
|
|
|
|
|
l.Debugln(c.id, "last read within", d)
|
|
|
|
|
|
case <-c.closed:
|
|
|
- return
|
|
|
+ return ErrClosed
|
|
|
}
|
|
|
}
|
|
|
}
|