Browse Source

Merge pull request #1236 from syncthing/flag-safe

Reject Index and Request messages with unexpected flags
Audrius Butkevicius 11 years ago
parent
commit
c95812353f
1 changed files with 42 additions and 23 deletions
  1. 42 23
      internal/protocol/protocol.go

+ 42 - 23
internal/protocol/protocol.go

@@ -71,6 +71,10 @@ var (
 	ErrClosed      = errors.New("connection closed")
 )
 
+// Specific variants of empty messages...
+type pingMessage struct{ EmptyMessage }
+type pongMessage struct{ EmptyMessage }
+
 type Model interface {
 	// An index was received from the peer device
 	Index(deviceID DeviceID, folder string, files []FileInfo)
@@ -289,48 +293,60 @@ func (c *rawConnection) readerLoop() (err error) {
 			return err
 		}
 
-		switch hdr.msgType {
-		case messageTypeIndex:
-			if c.state < stateCCRcvd {
-				return fmt.Errorf("protocol error: index message in state %d", c.state)
+		switch msg := msg.(type) {
+		case IndexMessage:
+			if msg.Flags != 0 {
+				// We don't currently support or expect any flags.
+				return fmt.Errorf("protocol error: unknown flags 0x%x in Index(Update) message", msg.Flags)
 			}
-			c.handleIndex(msg.(IndexMessage))
-			c.state = stateIdxRcvd
 
-		case messageTypeIndexUpdate:
-			if c.state < stateIdxRcvd {
-				return fmt.Errorf("protocol error: index update message in state %d", c.state)
+			switch hdr.msgType {
+			case messageTypeIndex:
+				if c.state < stateCCRcvd {
+					return fmt.Errorf("protocol error: index message in state %d", c.state)
+				}
+				c.handleIndex(msg)
+				c.state = stateIdxRcvd
+
+			case messageTypeIndexUpdate:
+				if c.state < stateIdxRcvd {
+					return fmt.Errorf("protocol error: index update message in state %d", c.state)
+				}
+				c.handleIndexUpdate(msg)
 			}
-			c.handleIndexUpdate(msg.(IndexMessage))
 
-		case messageTypeRequest:
+		case RequestMessage:
+			if msg.Flags != 0 {
+				// We don't currently support or expect any flags.
+				return fmt.Errorf("protocol error: unknown flags 0x%x in Request message", msg.Flags)
+			}
 			if c.state < stateIdxRcvd {
 				return fmt.Errorf("protocol error: request message in state %d", c.state)
 			}
 			// Requests are handled asynchronously
-			go c.handleRequest(hdr.msgID, msg.(RequestMessage))
+			go c.handleRequest(hdr.msgID, msg)
 
-		case messageTypeResponse:
+		case ResponseMessage:
 			if c.state < stateIdxRcvd {
 				return fmt.Errorf("protocol error: response message in state %d", c.state)
 			}
-			c.handleResponse(hdr.msgID, msg.(ResponseMessage))
+			c.handleResponse(hdr.msgID, msg)
 
-		case messageTypePing:
-			c.send(hdr.msgID, messageTypePong, EmptyMessage{})
+		case pingMessage:
+			c.send(hdr.msgID, messageTypePong, pongMessage{})
 
-		case messageTypePong:
+		case pongMessage:
 			c.handlePong(hdr.msgID)
 
-		case messageTypeClusterConfig:
+		case ClusterConfigMessage:
 			if c.state != stateInitial {
 				return fmt.Errorf("protocol error: cluster config message in state %d", c.state)
 			}
-			go c.receiver.ClusterConfig(c.id, msg.(ClusterConfigMessage))
+			go c.receiver.ClusterConfig(c.id, msg)
 			c.state = stateCCRcvd
 
-		case messageTypeClose:
-			return errors.New(msg.(CloseMessage).Reason)
+		case CloseMessage:
+			return errors.New(msg.Reason)
 
 		default:
 			return fmt.Errorf("protocol error: %s: unknown message type %#x", c.id, hdr.msgType)
@@ -428,8 +444,11 @@ func (c *rawConnection) readMessage() (hdr header, msg encodable, err error) {
 		}
 		msg = resp
 
-	case messageTypePing, messageTypePong:
-		msg = EmptyMessage{}
+	case messageTypePing:
+		msg = pingMessage{}
+
+	case messageTypePong:
+		msg = pongMessage{}
 
 	case messageTypeClusterConfig:
 		var cc ClusterConfigMessage