Browse Source

Protocol state machine on receiving side

Jakob Borg 11 years ago
parent
commit
6ade27641d
1 changed files with 25 additions and 0 deletions
  1. 25 0
      protocol/protocol.go

+ 25 - 0
protocol/protocol.go

@@ -28,6 +28,12 @@ const (
 	messageTypeIndexUpdate   = 6
 )
 
+const (
+	stateInitial = iota
+	stateCCRcvd
+	stateIdxRcvd
+)
+
 const (
 	FlagDeleted    uint32 = 1 << 12
 	FlagInvalid           = 1 << 13
@@ -70,6 +76,7 @@ type Connection interface {
 type rawConnection struct {
 	id       string
 	receiver Model
+	state    int
 
 	reader io.ReadCloser
 	cr     *countingReader
@@ -116,6 +123,7 @@ func NewConnection(nodeID string, reader io.Reader, writer io.Writer, receiver M
 	c := rawConnection{
 		id:        nodeID,
 		receiver:  nativeModel{receiver},
+		state:     stateInitial,
 		reader:    flrd,
 		cr:        cr,
 		xr:        xdr.NewReader(flrd),
@@ -257,21 +265,34 @@ func (c *rawConnection) readerLoop() (err error) {
 
 		switch hdr.msgType {
 		case messageTypeIndex:
+			if c.state < stateCCRcvd {
+				return fmt.Errorf("protocol error: index message in state %d", c.state)
+			}
 			if err := c.handleIndex(); err != nil {
 				return err
 			}
+			c.state = stateIdxRcvd
 
 		case messageTypeIndexUpdate:
+			if c.state < stateIdxRcvd {
+				return fmt.Errorf("protocol error: index update message in state %d", c.state)
+			}
 			if err := c.handleIndexUpdate(); err != nil {
 				return err
 			}
 
 		case messageTypeRequest:
+			if c.state < stateIdxRcvd {
+				return fmt.Errorf("protocol error: request message in state %d", c.state)
+			}
 			if err := c.handleRequest(hdr); err != nil {
 				return err
 			}
 
 		case messageTypeResponse:
+			if c.state < stateIdxRcvd {
+				return fmt.Errorf("protocol error: response message in state %d", c.state)
+			}
 			if err := c.handleResponse(hdr); err != nil {
 				return err
 			}
@@ -283,9 +304,13 @@ func (c *rawConnection) readerLoop() (err error) {
 			c.handlePong(hdr)
 
 		case messageTypeClusterConfig:
+			if c.state != stateInitial {
+				return fmt.Errorf("protocol error: cluster config message in state %d", c.state)
+			}
 			if err := c.handleClusterConfig(); err != nil {
 				return err
 			}
+			c.state = stateCCRcvd
 
 		default:
 			return fmt.Errorf("protocol error: %s: unknown message type %#x", c.id, hdr.msgType)