Преглед изворни кода

lib/protocol: Improve messages when an error occurs receiving (ref #7466) (#7470)

Simon Frei пре 4 година
родитељ
комит
d2d4fcc1df
1 измењених фајлова са 68 додато и 58 уклоњено
  1. 68 58
      lib/protocol/protocol.go

+ 68 - 58
lib/protocol/protocol.go

@@ -437,82 +437,61 @@ func (c *rawConnection) dispatcherLoop() (err error) {
 		case <-c.closed:
 			return ErrClosed
 		}
+
+		msgContext, err := messageContext(msg)
+		if err != nil {
+			return fmt.Errorf("protocol error: %w", err)
+		}
+		l.Debugf("handle %v message", msgContext)
+
 		switch msg := msg.(type) {
 		case *ClusterConfig:
-			l.Debugln("read ClusterConfig message")
 			if state == stateInitial {
 				state = stateReady
 			}
-			if err := c.receiver.ClusterConfig(c.id, *msg); err != nil {
-				return fmt.Errorf("receiving cluster config: %w", err)
+		case *Close:
+			return fmt.Errorf("closed by remote: %v", msg.Reason)
+		default:
+			if state != stateReady {
+				return newProtocolError(fmt.Errorf("invalid state %d", state), msgContext)
 			}
+		}
 
+		switch msg := msg.(type) {
 		case *Index:
-			l.Debugln("read Index message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: index message in state %d", state)
-			}
-			if err := checkIndexConsistency(msg.Files); err != nil {
-				return errors.Wrap(err, "protocol error: index")
-			}
-			if err := c.handleIndex(*msg); err != nil {
-				return fmt.Errorf("receiving index: %w", err)
-			}
-			state = stateReady
+			err = checkIndexConsistency(msg.Files)
 
 		case *IndexUpdate:
-			l.Debugln("read IndexUpdate message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: index update message in state %d", state)
-			}
-			if err := checkIndexConsistency(msg.Files); err != nil {
-				return errors.Wrap(err, "protocol error: index update")
-			}
-			if err := c.handleIndexUpdate(*msg); err != nil {
-				return fmt.Errorf("receiving index update: %w", err)
-			}
-			state = stateReady
+			err = checkIndexConsistency(msg.Files)
+
+		case *Request:
+			err = checkFilename(msg.Name)
+		}
+		if err != nil {
+			return newProtocolError(err, msgContext)
+		}
+
+		switch msg := msg.(type) {
+		case *ClusterConfig:
+			err = c.receiver.ClusterConfig(c.id, *msg)
+
+		case *Index:
+			err = c.handleIndex(*msg)
+
+		case *IndexUpdate:
+			err = c.handleIndexUpdate(*msg)
 
 		case *Request:
-			l.Debugln("read Request message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: request message in state %d", state)
-			}
-			if err := checkFilename(msg.Name); err != nil {
-				return errors.Wrapf(err, "protocol error: request: %q", msg.Name)
-			}
 			go c.handleRequest(*msg)
 
 		case *Response:
-			l.Debugln("read Response message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: response message in state %d", state)
-			}
 			c.handleResponse(*msg)
 
 		case *DownloadProgress:
-			l.Debugln("read DownloadProgress message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: response message in state %d", state)
-			}
-			if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil {
-				return fmt.Errorf("receiving download progress: %w", err)
-			}
-
-		case *Ping:
-			l.Debugln("read Ping message")
-			if state != stateReady {
-				return fmt.Errorf("protocol error: ping message in state %d", state)
-			}
-			// Nothing
-
-		case *Close:
-			l.Debugln("read Close message")
-			return fmt.Errorf("closed by remote: %v", msg.Reason)
-
-		default:
-			l.Debugf("read unknown message: %+T", msg)
-			return fmt.Errorf("protocol error: %s: unknown or empty message", c.id)
+			err = c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates)
+		}
+		if err != nil {
+			return newHandleError(err, msgContext)
 		}
 	}
 }
@@ -1078,3 +1057,34 @@ func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) {
 	}
 	return decoded, nil
 }
+
+func newProtocolError(err error, msgContext string) error {
+	return fmt.Errorf("protocol error on %v: %w", msgContext, err)
+}
+
+func newHandleError(err error, msgContext string) error {
+	return fmt.Errorf("handling %v: %w", msgContext, err)
+}
+
+func messageContext(msg message) (string, error) {
+	switch msg := msg.(type) {
+	case *ClusterConfig:
+		return "cluster-config", nil
+	case *Index:
+		return fmt.Sprintf("index for %v", msg.Folder), nil
+	case *IndexUpdate:
+		return fmt.Sprintf("index-update for %v", msg.Folder), nil
+	case *Request:
+		return fmt.Sprintf(`request for "%v" in %v`, msg.Name, msg.Folder), nil
+	case *Response:
+		return "response", nil
+	case *DownloadProgress:
+		return fmt.Sprintf("download-progress for %v", msg.Folder), nil
+	case *Ping:
+		return "ping", nil
+	case *Close:
+		return "close", nil
+	default:
+		return "", errors.New("unknown or empty message")
+	}
+}