|
|
@@ -12,13 +12,6 @@ def merge_user_messages(user_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
|
"""
|
|
|
合并多个 userInputMessage 的内容
|
|
|
|
|
|
- IMPORTANT: This function properly merges toolResults from all messages to prevent
|
|
|
- losing tool execution history, which would cause infinite loops where the model
|
|
|
- keeps responding to the same user message.
|
|
|
-
|
|
|
- Key fix: Deduplicate toolResults by toolUseId to prevent duplicate tool_result
|
|
|
- entries that cause the model to repeatedly respond to the same user message.
|
|
|
-
|
|
|
Args:
|
|
|
user_messages: userInputMessage 列表
|
|
|
|
|
|
@@ -33,30 +26,19 @@ def merge_user_messages(user_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
|
base_context = None
|
|
|
base_origin = None
|
|
|
base_model = None
|
|
|
- # Use dict to deduplicate toolResults by toolUseId
|
|
|
- tool_results_by_id: Dict[str, Dict[str, Any]] = {}
|
|
|
all_images = [] # Collect images from messages
|
|
|
|
|
|
for msg in user_messages:
|
|
|
content = msg.get("content", "")
|
|
|
msg_ctx = msg.get("userInputMessageContext", {})
|
|
|
|
|
|
- # Initialize base context from first message
|
|
|
+ # 保留第一个消息的上下文信息
|
|
|
if base_context is None:
|
|
|
base_context = msg_ctx.copy() if msg_ctx else {}
|
|
|
- # Remove toolResults from base to merge them separately
|
|
|
- if "toolResults" in base_context:
|
|
|
- for tr in base_context.pop("toolResults"):
|
|
|
- _merge_tool_result(tool_results_by_id, tr)
|
|
|
- else:
|
|
|
- # Collect toolResults from subsequent messages
|
|
|
- if "toolResults" in msg_ctx:
|
|
|
- for tr in msg_ctx["toolResults"]:
|
|
|
- _merge_tool_result(tool_results_by_id, tr)
|
|
|
|
|
|
# 保留第一个消息的 origin
|
|
|
if base_origin is None:
|
|
|
- base_origin = msg.get("origin", "CLI")
|
|
|
+ base_origin = msg.get("origin", "KIRO_CLI")
|
|
|
|
|
|
# 保留第一个消息的 modelId
|
|
|
if base_model is None and "modelId" in msg:
|
|
|
@@ -78,13 +60,9 @@ def merge_user_messages(user_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
|
merged_msg = {
|
|
|
"content": merged_content,
|
|
|
"userInputMessageContext": base_context or {},
|
|
|
- "origin": base_origin or "CLI"
|
|
|
+ "origin": base_origin or "KIRO_CLI"
|
|
|
}
|
|
|
|
|
|
- # Add deduplicated toolResults if any
|
|
|
- if tool_results_by_id:
|
|
|
- merged_msg["userInputMessageContext"]["toolResults"] = list(tool_results_by_id.values())
|
|
|
-
|
|
|
# 如果原始消息有 modelId,也保留
|
|
|
if base_model:
|
|
|
merged_msg["modelId"] = base_model
|
|
|
@@ -100,46 +78,6 @@ def merge_user_messages(user_messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
|
return merged_msg
|
|
|
|
|
|
|
|
|
-def _merge_tool_result(tool_results_by_id: Dict[str, Dict[str, Any]], tool_result: Dict[str, Any]) -> None:
|
|
|
- """
|
|
|
- Merge a tool_result into the deduplicated dict.
|
|
|
- If toolUseId already exists, merge the content arrays.
|
|
|
-
|
|
|
- Args:
|
|
|
- tool_results_by_id: Dict mapping toolUseId to tool_result
|
|
|
- tool_result: The tool_result to merge
|
|
|
- """
|
|
|
- tool_use_id = tool_result.get("toolUseId")
|
|
|
- if not tool_use_id:
|
|
|
- return
|
|
|
-
|
|
|
- if tool_use_id in tool_results_by_id:
|
|
|
- # Merge content arrays
|
|
|
- existing = tool_results_by_id[tool_use_id]
|
|
|
- existing_content = existing.get("content", [])
|
|
|
- new_content = tool_result.get("content", [])
|
|
|
-
|
|
|
- # Deduplicate content by text value
|
|
|
- existing_texts = {item.get("text", "") for item in existing_content if isinstance(item, dict)}
|
|
|
- for item in new_content:
|
|
|
- if isinstance(item, dict):
|
|
|
- text = item.get("text", "")
|
|
|
- if text and text not in existing_texts:
|
|
|
- existing_content.append(item)
|
|
|
- existing_texts.add(text)
|
|
|
-
|
|
|
- existing["content"] = existing_content
|
|
|
-
|
|
|
- # If any result has error status, keep error
|
|
|
- if tool_result.get("status") == "error":
|
|
|
- existing["status"] = "error"
|
|
|
-
|
|
|
- logger.debug(f"Merged duplicate toolUseId {tool_use_id}")
|
|
|
- else:
|
|
|
- # New toolUseId, add to dict
|
|
|
- tool_results_by_id[tool_use_id] = tool_result.copy()
|
|
|
-
|
|
|
-
|
|
|
def process_claude_history_for_amazonq(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
处理 Claude Code 历史记录,使其符合 Amazon Q 要求
|