
Streamline error handling and locking, with fix for close() race

Jakob Borg, 11 years ago
commit 482795bab0
1 changed file with 119 additions and 121 deletions:
  protocol/protocol.go

+ 119 - 121
protocol/protocol.go
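
The heart of the commit: the connection's single embedded sync.RWMutex and its closed channel are replaced by two narrower mutexes plus a plain bool. wmut serializes every write to the wire and guards the closed flag; imut guards message-ID allocation and the awaiting map. A minimal runnable sketch of that locking shape follows; the conn type and send method are illustrative stand-ins for rawConnection, with the wire I/O simulated.

package main

import (
	"fmt"
	"sync"
)

// A minimal stand-in for rawConnection after this commit. wmut
// serializes every write to the wire and guards closed; imut guards
// message-ID allocation (and, in the real code, the awaiting map).
// The only place both locks are held together is close(), which takes
// imut then wmut, so the lock order is fixed.
type conn struct {
	wmut   sync.Mutex
	closed bool

	imut   sync.Mutex
	nextID int
}

func (c *conn) send(msg string) error {
	// Reserve an ID under imut; no I/O happens while it is held.
	c.imut.Lock()
	id := c.nextID
	c.nextID = (c.nextID + 1) & 0xfff
	c.imut.Unlock()

	// Encode and flush under wmut so concurrent senders cannot
	// interleave bytes within a frame.
	c.wmut.Lock()
	defer c.wmut.Unlock()
	if c.closed {
		return fmt.Errorf("connection closed")
	}
	fmt.Printf("-> %03x %q\n", id, msg) // stand-in for encodeXDR + flush
	return nil
}

func main() {
	c := &conn{}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			_ = c.send(fmt.Sprintf("ping %d", i))
		}(i)
	}
	wg.Wait()
}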

@@ -64,24 +64,24 @@ type Connection interface {
 }
 
 type rawConnection struct {
-	sync.RWMutex
-
-	id        string
-	receiver  Model
-	reader    io.ReadCloser
-	cr        *countingReader
-	xr        *xdr.Reader
-	writer    io.WriteCloser
-	cw        *countingWriter
-	wb        *bufio.Writer
-	xw        *xdr.Writer
-	closed    chan struct{}
+	id       string
+	receiver Model
+
+	reader io.ReadCloser
+	cr     *countingReader
+	xr     *xdr.Reader
+	writer io.WriteCloser
+
+	cw     *countingWriter
+	wb     *bufio.Writer
+	xw     *xdr.Writer
+	wmut   sync.Mutex
+	closed bool
+
 	awaiting  map[int]chan asyncResult
 	nextID    int
 	indexSent map[string]map[string][2]int64
-
-	hasSentIndex  bool
-	hasRecvdIndex bool
+	imut      sync.Mutex
 }
 
 type asyncResult struct {
@@ -115,7 +115,6 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 		cw:        cw,
 		wb:        wb,
 		xw:        xdr.NewWriter(wb),
-		closed:    make(chan struct{}),
 		awaiting:  make(map[int]chan asyncResult),
 		indexSent: make(map[string]map[string][2]int64),
 	}
@@ -132,11 +131,11 @@ func (c *rawConnection) ID() string {
 
 // Index writes the list of file information to the connected peer node
 func (c *rawConnection) Index(repo string, idx []FileInfo) {
-	c.Lock()
 	if c.isClosed() {
-		c.Unlock()
 		return
 	}
+
+	c.imut.Lock()
 	var msgType int
 	if c.indexSent[repo] == nil {
 		// This is the first time we send an index.
@@ -159,14 +158,15 @@ func (c *rawConnection) Index(repo string, idx []FileInfo) {
 		idx = diff
 	}
 
-	header{0, c.nextID, msgType}.encodeXDR(c.xw)
-	_, err := IndexMessage{repo, idx}.encodeXDR(c.xw)
-	if err == nil {
-		err = c.flush()
-	}
+	id := c.nextID
 	c.nextID = (c.nextID + 1) & 0xfff
-	c.hasSentIndex = true
-	c.Unlock()
+	c.imut.Unlock()
+
+	c.wmut.Lock()
+	header{0, id, msgType}.encodeXDR(c.xw)
+	IndexMessage{repo, idx}.encodeXDR(c.xw)
+	err := c.flush()
+	c.wmut.Unlock()
 
 	if err != nil {
 		c.close(err)
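
Every sender in this diff now follows the same sequence: reserve an ID under imut, then encode and flush under wmut, with one error check after flush(). The IDs live in a 12-bit space, so (id + 1) & 0xfff wraps 4095 back to 0; the "id taken" panic in Request (next hunk) guards against an ID still being in flight when it comes around again. A trivial demonstration of the wraparound:

package main

import "fmt"

// Message IDs occupy a 12-bit space: masking with 0xfff wraps 4095
// back to 0. With more than 4096 requests in flight an ID would be
// reused, which the "id taken" panic in Request guards against.
func main() {
	id := 4094
	for i := 0; i < 4; i++ {
		fmt.Print(id, " ")
		id = (id + 1) & 0xfff
	}
	fmt.Println() // prints: 4094 4095 0 1
}
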
@@ -176,28 +176,30 @@ func (c *rawConnection) Index(repo string, idx []FileInfo) {
 
 // Request returns the bytes for the specified block after fetching them from the connected peer.
 func (c *rawConnection) Request(repo string, name string, offset int64, size int) ([]byte, error) {
-	c.Lock()
 	if c.isClosed() {
-		c.Unlock()
 		return nil, ErrClosed
 	}
+
+	c.imut.Lock()
+	id := c.nextID
+	c.nextID = (c.nextID + 1) & 0xfff
 	rc := make(chan asyncResult)
-	if _, ok := c.awaiting[c.nextID]; ok {
+	if _, ok := c.awaiting[id]; ok {
 		panic("id taken")
 	}
-	c.awaiting[c.nextID] = rc
-	header{0, c.nextID, messageTypeRequest}.encodeXDR(c.xw)
-	_, err := RequestMessage{repo, name, uint64(offset), uint32(size)}.encodeXDR(c.xw)
-	if err == nil {
-		err = c.flush()
-	}
+	c.awaiting[id] = rc
+	c.imut.Unlock()
+
+	c.wmut.Lock()
+	header{0, id, messageTypeRequest}.encodeXDR(c.xw)
+	RequestMessage{repo, name, uint64(offset), uint32(size)}.encodeXDR(c.xw)
+	err := c.flush()
+	c.wmut.Unlock()
+
 	if err != nil {
-		c.Unlock()
 		c.close(err)
 		return nil, err
 	}
-	c.nextID = (c.nextID + 1) & 0xfff
-	c.Unlock()
 
 	res, ok := <-rc
 	if !ok {
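
Request is the asynchronous half of the protocol: register a reply channel in awaiting under imut, write the request under wmut, then block on the channel until readerLoop delivers a response, or until close() closes the channel, in which case the receive reports !ok and Request returns ErrClosed. A self-contained sketch of that correlation; the pending type, its methods, and errClosed are illustrative, with strings standing in for block data.

package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

var errClosed = errors.New("connection closed")

// pending mirrors the awaiting map: it correlates outstanding request
// IDs with reply channels. Illustrative only.
type pending struct {
	mut      sync.Mutex
	awaiting map[int]chan string
}

// request registers a reply channel, then blocks until deliver (the
// reader loop) answers or close() closes the channel (!ok).
func (p *pending) request(id int) (string, error) {
	rc := make(chan string)
	p.mut.Lock()
	p.awaiting[id] = rc
	p.mut.Unlock()

	res, ok := <-rc
	if !ok {
		return "", errClosed
	}
	return res, nil
}

// deliver removes the entry under the lock, so a concurrent close()
// can never double-close the same channel, then unblocks the waiter.
func (p *pending) deliver(id int, res string) {
	p.mut.Lock()
	rc, ok := p.awaiting[id]
	delete(p.awaiting, id)
	p.mut.Unlock()
	if ok {
		rc <- res
	}
}

func main() {
	p := &pending{awaiting: make(map[int]chan string)}
	go func() {
		time.Sleep(10 * time.Millisecond)
		p.deliver(42, "block data")
	}()
	res, err := p.request(42)
	fmt.Println(res, err)
}
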
@@ -208,46 +210,47 @@ func (c *rawConnection) Request(repo string, name string, offset int64, size int
 
 // ClusterConfig send the cluster configuration message to the peer and returns any error
 func (c *rawConnection) ClusterConfig(config ClusterConfigMessage) {
-	c.Lock()
-	defer c.Unlock()
-
 	if c.isClosed() {
 		return
 	}
 
-	header{0, c.nextID, messageTypeClusterConfig}.encodeXDR(c.xw)
+	c.imut.Lock()
+	id := c.nextID
 	c.nextID = (c.nextID + 1) & 0xfff
+	c.imut.Unlock()
+
+	c.wmut.Lock()
+	header{0, id, messageTypeClusterConfig}.encodeXDR(c.xw)
+	config.encodeXDR(c.xw)
+	err := c.flush()
+	c.wmut.Unlock()
 
-	_, err := config.encodeXDR(c.xw)
-	if err == nil {
-		err = c.flush()
-	}
 	if err != nil {
 		c.close(err)
 	}
 }
 
 func (c *rawConnection) ping() bool {
-	c.Lock()
 	if c.isClosed() {
-		c.Unlock()
 		return false
 	}
+
+	c.imut.Lock()
+	id := c.nextID
+	c.nextID = (c.nextID + 1) & 0xfff
 	rc := make(chan asyncResult, 1)
-	c.awaiting[c.nextID] = rc
-	header{0, c.nextID, messageTypePing}.encodeXDR(c.xw)
+	c.awaiting[id] = rc
+	c.imut.Unlock()
+
+	c.wmut.Lock()
+	header{0, id, messageTypePing}.encodeXDR(c.xw)
 	err := c.flush()
+	c.wmut.Unlock()
+
 	if err != nil {
-		c.Unlock()
 		c.close(err)
 		return false
-	} else if c.xw.Error() != nil {
-		c.Unlock()
-		c.close(c.xw.Error())
-		return false
 	}
-	c.nextID = (c.nextID + 1) & 0xfff
-	c.Unlock()
 
 	res, ok := <-rc
 	return ok && res.err == nil
@@ -258,40 +261,47 @@ type flusher interface {
 }
 
 func (c *rawConnection) flush() error {
-	c.wb.Flush()
+	if err := c.xw.Error(); err != nil {
+		return err
+	}
+
+	if err := c.wb.Flush(); err != nil {
+		return err
+	}
+
 	if f, ok := c.writer.(flusher); ok {
 		return f.Flush()
 	}
+
 	return nil
 }
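
The per-call error checks around encodeXDR disappear throughout the diff because flush() is now the single checkpoint: it first surfaces any error latched by the xdr.Writer during encoding, then flushes the bufio.Writer, then the optional flusher underneath. That only works if the writer is sticky about its first error. Roughly the assumed property, sketched with an illustrative type rather than the real xdr package:

package main

import (
	"bufio"
	"bytes"
	"fmt"
)

// stickyWriter latches the first write error so callers can issue a
// series of writes and check one error afterwards, the property the
// diff assumes of xdr.Writer. Sketch, not the actual xdr package.
type stickyWriter struct {
	w   *bufio.Writer
	err error
}

func (s *stickyWriter) write(p []byte) {
	if s.err != nil {
		return // an earlier error is latched; later writes are no-ops
	}
	_, s.err = s.w.Write(p)
}

// flush mirrors the reworked flush(): report a latched encode error
// first, then any error from flushing the buffer.
func (s *stickyWriter) flush() error {
	if s.err != nil {
		return s.err
	}
	return s.w.Flush()
}

func main() {
	var buf bytes.Buffer
	s := &stickyWriter{w: bufio.NewWriter(&buf)}
	s.write([]byte("header"))
	s.write([]byte("payload"))
	if err := s.flush(); err != nil { // one check covers both writes
		fmt.Println("write failed:", err)
		return
	}
	fmt.Printf("wrote %d bytes\n", buf.Len())
}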
 
 func (c *rawConnection) close(err error) {
-	c.Lock()
-	select {
-	case <-c.closed:
-		c.Unlock()
+	c.imut.Lock()
+	c.wmut.Lock()
+	defer c.imut.Unlock()
+	defer c.wmut.Unlock()
+
+	if c.closed {
 		return
-	default:
 	}
-	close(c.closed)
+
+	c.closed = true
+
 	for _, ch := range c.awaiting {
 		close(ch)
 	}
 	c.awaiting = nil
 	c.writer.Close()
 	c.reader.Close()
-	c.Unlock()
 
 	c.receiver.Close(c.id, err)
 }
 
 func (c *rawConnection) isClosed() bool {
-	select {
-	case <-c.closed:
-		return true
-	default:
-		return false
-	}
+	c.wmut.Lock()
+	defer c.wmut.Unlock()
+	return c.closed
 }
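
Here is the close() race fix from the commit title: close() now takes both mutexes, test-and-sets a plain bool, and only then tears down, so two racing close() calls (say, the reader loop and a failed writer) run the teardown exactly once, and never while another goroutine is mid-frame under wmut. isClosed() reads the same flag under wmut for a consistent view. The shape in isolation, with an illustrative closer type:

package main

import (
	"fmt"
	"sync"
)

// closer reproduces the idempotent close() shape: take both mutexes,
// test-and-set the flag, then tear down. Sketch only.
type closer struct {
	imut     sync.Mutex
	wmut     sync.Mutex
	closed   bool
	awaiting map[int]chan struct{}
}

func (c *closer) close() {
	// Same lock order as the diff: imut, then wmut.
	c.imut.Lock()
	c.wmut.Lock()
	defer c.imut.Unlock()
	defer c.wmut.Unlock()

	if c.closed {
		return // a racing caller already closed us
	}
	c.closed = true

	// Wake every in-flight request; their receives report !ok.
	for _, ch := range c.awaiting {
		close(ch)
	}
	c.awaiting = nil
	fmt.Println("closed once")
}

func main() {
	c := &closer{awaiting: map[int]chan struct{}{1: make(chan struct{})}}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); c.close() }()
	}
	wg.Wait() // "closed once" prints exactly once
}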
 
 func (c *rawConnection) readerLoop() {
@@ -299,8 +309,8 @@ loop:
 	for !c.isClosed() {
 		var hdr header
 		hdr.decodeXDR(c.xr)
-		if c.xr.Error() != nil {
-			c.close(c.xr.Error())
+		if err := c.xr.Error(); err != nil {
+			c.close(err)
 			break loop
 		}
 		if hdr.version != 0 {
@@ -312,8 +322,8 @@ loop:
 		case messageTypeIndex:
 			var im IndexMessage
 			im.decodeXDR(c.xr)
-			if c.xr.Error() != nil {
-				c.close(c.xr.Error())
+			if err := c.xr.Error(); err != nil {
+				c.close(err)
 				break loop
 			} else {
 
@@ -326,15 +336,12 @@ loop:
 
 				go c.receiver.Index(c.id, im.Repository, im.Files)
 			}
-			c.Lock()
-			c.hasRecvdIndex = true
-			c.Unlock()
 
 		case messageTypeIndexUpdate:
 			var im IndexMessage
 			im.decodeXDR(c.xr)
-			if c.xr.Error() != nil {
-				c.close(c.xr.Error())
+			if err := c.xr.Error(); err != nil {
+				c.close(err)
 				break loop
 			} else {
 				go c.receiver.IndexUpdate(c.id, im.Repository, im.Files)
@@ -343,8 +350,8 @@ loop:
 		case messageTypeRequest:
 			var req RequestMessage
 			req.decodeXDR(c.xr)
-			if c.xr.Error() != nil {
-				c.close(c.xr.Error())
+			if err := c.xr.Error(); err != nil {
+				c.close(err)
 				break loop
 			}
 			go c.processRequest(hdr.msgID, req)
@@ -352,16 +359,16 @@ loop:
 		case messageTypeResponse:
 			data := c.xr.ReadBytesMax(256 * 1024) // Sufficiently larger than max expected block size
 
-			if c.xr.Error() != nil {
-				c.close(c.xr.Error())
+			if err := c.xr.Error(); err != nil {
+				c.close(err)
 				break loop
 			}
 
 			go func(hdr header, err error) {
-				c.Lock()
+				c.imut.Lock()
 				rc, ok := c.awaiting[hdr.msgID]
 				delete(c.awaiting, hdr.msgID)
-				c.Unlock()
+				c.imut.Unlock()
 
 				if ok {
 					rc <- asyncResult{data, err}
@@ -370,37 +377,34 @@ loop:
 			}(hdr, c.xr.Error())
 
 		case messageTypePing:
-			c.Lock()
+			c.wmut.Lock()
 			header{0, hdr.msgID, messageTypePong}.encodeXDR(c.xw)
 			err := c.flush()
-			c.Unlock()
+			c.wmut.Unlock()
 			if err != nil {
 				c.close(err)
 				break loop
-			} else if c.xw.Error() != nil {
-				c.close(c.xw.Error())
-				break loop
 			}
 
 		case messageTypePong:
-			c.RLock()
+			c.imut.Lock()
 			rc, ok := c.awaiting[hdr.msgID]
-			c.RUnlock()
 
 			if ok {
-				rc <- asyncResult{}
-				close(rc)
+				go func() {
+					rc <- asyncResult{}
+					close(rc)
+				}()
 
-				c.Lock()
 				delete(c.awaiting, hdr.msgID)
-				c.Unlock()
 			}
+			c.imut.Unlock()
 
 		case messageTypeClusterConfig:
 			var cm ClusterConfigMessage
 			cm.decodeXDR(c.xr)
-			if c.xr.Error() != nil {
-				c.close(c.xr.Error())
+			if err := c.xr.Error(); err != nil {
+				c.close(err)
 				break loop
 			} else {
 				go c.receiver.ClusterConfig(c.id, cm)
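
The Pong branch is the subtle one: the lookup and delete happen while imut is held, so a concurrent close() can never find and close the same channel, but the send itself is pushed into a goroutine so the reader loop never blocks on a waiter (defensive here, since ping() buffers its channel). In isolation, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

// deliverPong mirrors the messageTypePong branch: look up and remove
// the waiter's channel while the lock is held (so close() cannot race
// us to it), but send and close from a goroutine so the reader loop
// never blocks here. Names are illustrative.
func deliverPong(mut *sync.Mutex, awaiting map[int]chan struct{}, id int) {
	mut.Lock()
	if rc, ok := awaiting[id]; ok {
		go func() {
			rc <- struct{}{}
			close(rc)
		}()
		delete(awaiting, id)
	}
	mut.Unlock()
}

func main() {
	var mut sync.Mutex
	awaiting := map[int]chan struct{}{7: make(chan struct{})}
	rc := awaiting[7]
	deliverPong(&mut, awaiting, 7)
	<-rc // the ping() caller unblocks here
	fmt.Println("pong delivered")
}
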
@@ -416,15 +420,14 @@ loop:
 func (c *rawConnection) processRequest(msgID int, req RequestMessage) {
 	data, _ := c.receiver.Request(c.id, req.Repository, req.Name, int64(req.Offset), int(req.Size))
 
-	c.Lock()
+	c.wmut.Lock()
 	header{0, msgID, messageTypeResponse}.encodeXDR(c.xw)
-	_, err := c.xw.WriteBytes(data)
-	if err == nil {
-		err = c.flush()
-	}
-	c.Unlock()
+	c.xw.WriteBytes(data)
+	err := c.flush()
+	c.wmut.Unlock()
 
 	buffers.Put(data)
+
 	if err != nil {
 		c.close(err)
 	}
@@ -434,27 +437,22 @@ func (c *rawConnection) pingerLoop() {
 	var rc = make(chan bool, 1)
 	ticker := time.Tick(pingIdleTime / 2)
 	for {
+		if c.isClosed() {
+			return
+		}
 		select {
 		case <-ticker:
-			c.RLock()
-			ready := c.hasRecvdIndex && c.hasSentIndex
-			c.RUnlock()
-
-			if ready {
-				go func() {
-					rc <- c.ping()
-				}()
-				select {
-				case ok := <-rc:
-					if !ok {
-						c.close(fmt.Errorf("ping failure"))
-					}
-				case <-time.After(pingTimeout):
-					c.close(fmt.Errorf("ping timeout"))
+			go func() {
+				rc <- c.ping()
+			}()
+			select {
+			case ok := <-rc:
+				if !ok {
+					c.close(fmt.Errorf("ping failure"))
 				}
+			case <-time.After(pingTimeout):
+				c.close(fmt.Errorf("ping timeout"))
 			}
-		case <-c.closed:
-			return
 		}
 	}
 }
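
With the closed channel gone, pingerLoop can no longer select on it; instead it polls isClosed() at the top of each iteration, so the goroutine exits on the first tick after close rather than immediately, a reasonable trade for the simpler state. The hasSentIndex/hasRecvdIndex gating is dropped too, so pinging starts as soon as the connection is up. A runnable reduction of the new loop, with illustrative timings:

package main

import (
	"fmt"
	"time"
)

// pingerLoop reduces the new loop to its shape: poll a closed check
// each iteration instead of selecting on a close channel, so exit is
// detected on the next tick. Parameters and durations are illustrative.
func pingerLoop(isClosed func() bool, ping func() bool) {
	rc := make(chan bool, 1)
	ticker := time.Tick(25 * time.Millisecond)
	for {
		if isClosed() {
			return
		}
		select {
		case <-ticker:
			go func() { rc <- ping() }()
			select {
			case ok := <-rc:
				if !ok {
					fmt.Println("ping failure")
					return
				}
			case <-time.After(10 * time.Millisecond):
				fmt.Println("ping timeout")
				return
			}
		}
	}
}

func main() {
	done := make(chan struct{})
	isClosed := func() bool {
		select {
		case <-done:
			return true
		default:
			return false
		}
	}
	go func() { time.Sleep(60 * time.Millisecond); close(done) }()
	pingerLoop(isClosed, func() bool { return true })
	fmt.Println("pinger exited")
}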