# claude_stream.py
import json
import logging
import importlib.util
import uuid
from pathlib import Path
from typing import AsyncGenerator, Optional, Dict, Any, List, Set

# Third-party tokenizer; used only for output-token accounting below.
import tiktoken
logger = logging.getLogger(__name__)

# ------------------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------------------
try:
    # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
    ENCODING = tiktoken.get_encoding("cl100k_base")
except Exception:
    # Encoder unavailable (e.g. offline, BPE file not cached) -> token
    # counting degrades gracefully to 0 in count_tokens().
    ENCODING = None

# Tags that delimit model "thinking" content embedded inline in the text stream.
THINKING_START_TAG = "<thinking>"
THINKING_END_TAG = "</thinking>"
  19. def _pending_tag_suffix(buffer: str, tag: str) -> int:
  20. """Length of the suffix of buffer that matches the prefix of tag (for partial matches)."""
  21. if not buffer or not tag:
  22. return 0
  23. max_len = min(len(buffer), len(tag) - 1)
  24. for length in range(max_len, 0, -1):
  25. if buffer[-length:] == tag[:length]:
  26. return length
  27. return 0
  28. def count_tokens(text: str) -> int:
  29. """Counts tokens with tiktoken."""
  30. if not text or not ENCODING:
  31. return 0
  32. return len(ENCODING.encode(text))
  33. # ------------------------------------------------------------------------------
  34. # Dynamic Loader
  35. # ------------------------------------------------------------------------------
  36. def _load_claude_parser():
  37. """Dynamically load claude_parser module."""
  38. base_dir = Path(__file__).resolve().parent
  39. spec = importlib.util.spec_from_file_location("v2_claude_parser", str(base_dir / "claude_parser.py"))
  40. module = importlib.util.module_from_spec(spec)
  41. spec.loader.exec_module(module)
  42. return module
# Bind the Claude SSE event builders from claude_parser. If loading fails
# (file missing, syntax error, ...), install no-op stubs that return empty
# strings so this module stays importable and the stream degrades silently.
try:
    _parser = _load_claude_parser()
    build_message_start = _parser.build_message_start
    build_content_block_start = _parser.build_content_block_start
    build_content_block_delta = _parser.build_content_block_delta
    build_content_block_stop = _parser.build_content_block_stop
    build_ping = _parser.build_ping
    build_message_stop = _parser.build_message_stop
    build_tool_use_start = _parser.build_tool_use_start
    build_tool_use_input_delta = _parser.build_tool_use_input_delta
except Exception as e:
    logger.error(f"Failed to load claude_parser: {e}")
    # Fallback definitions
    def build_message_start(*args, **kwargs): return ""
    def build_content_block_start(*args, **kwargs): return ""
    def build_content_block_delta(*args, **kwargs): return ""
    def build_content_block_stop(*args, **kwargs): return ""
    def build_ping(*args, **kwargs): return ""
    def build_message_stop(*args, **kwargs): return ""
    def build_tool_use_start(*args, **kwargs): return ""
    def build_tool_use_input_delta(*args, **kwargs): return ""
class ClaudeStreamHandler:
    """Translates an Amazon Q event stream into Claude-style SSE events.

    State machine notes:
    - Content blocks (text / thinking / tool_use) are tracked with three flags:
      ``content_block_started``, ``content_block_start_sent`` and
      ``content_block_stop_sent``, plus a running ``content_block_index``.
    - ``<thinking>...</thinking>`` tags embedded in the assistant text are
      split out into separate "thinking" content blocks; partial tags at chunk
      boundaries are held back in ``think_buffer``.
    - ``assistantResponseEnd`` emits message_stop immediately and sets
      ``response_ended`` so later events (and ``finish()``) become no-ops.
    """

    def __init__(self, model: str, input_tokens: int = 0, conversation_id: Optional[str] = None):
        # Model name echoed back in message_start; input token count is
        # provided by the caller (this class only counts output tokens).
        self.model = model
        self.input_tokens = input_tokens
        self.response_buffer: List[str] = []      # emitted text chunks, joined for token accounting
        self.content_block_index: int = -1        # -1 so the first block becomes index 0
        self.content_block_started: bool = False
        self.content_block_start_sent: bool = False
        self.content_block_stop_sent: bool = False
        self.message_start_sent: bool = False
        self.conversation_id: Optional[str] = conversation_id
        # Tool use state
        self.current_tool_use: Optional[Dict[str, Any]] = None
        self.tool_input_buffer: List[str] = []    # JSON fragments for the active tool call
        self.tool_use_id: Optional[str] = None
        self.tool_name: Optional[str] = None
        self._processed_tool_use_ids: Set[str] = set()  # for duplicate-event suppression
        self.all_tool_inputs: List[str] = []      # completed tool inputs, for token accounting
        self.has_tool_use: bool = False  # Track if any tool_use was emitted
        # Think tag state
        self.in_think_block: bool = False
        self.think_buffer: str = ""               # unemitted text, may end in a partial tag
        self.pending_start_tag_chars: int = 0     # remaining chars of a committed-but-incomplete <thinking>
        # Response termination flag
        self.response_ended: bool = False

    async def handle_event(self, event_type: str, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Process a single Amazon Q event and yield Claude SSE events."""
        # Early return if response has already ended
        if self.response_ended:
            return
        # 1. Message Start (initial-response)
        if event_type == "initial-response":
            if not self.message_start_sent:
                # Use conversation_id from payload if available, otherwise use the one passed to constructor
                conv_id = payload.get('conversationId') or self.conversation_id or str(uuid.uuid4())
                self.conversation_id = conv_id
                yield build_message_start(conv_id, self.model, self.input_tokens)
                self.message_start_sent = True
                yield build_ping()
        # 2. Content Block Delta (assistantResponseEvent)
        elif event_type == "assistantResponseEvent":
            content = payload.get("content", "")
            # Close any open tool use block
            if self.current_tool_use and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
                self.current_tool_use = None
            # Process content with think tag detection
            if content:
                self.think_buffer += content
                while self.think_buffer:
                    # A <thinking> start tag was committed earlier on a partial
                    # match; swallow its remaining characters first.
                    # NOTE(review): these characters are skipped without
                    # verifying they actually spell the rest of the tag —
                    # assumes upstream always completes the tag; confirm.
                    if self.pending_start_tag_chars > 0:
                        if len(self.think_buffer) < self.pending_start_tag_chars:
                            self.pending_start_tag_chars -= len(self.think_buffer)
                            self.think_buffer = ""
                            break
                        else:
                            self.think_buffer = self.think_buffer[self.pending_start_tag_chars:]
                            self.pending_start_tag_chars = 0
                            if not self.think_buffer:
                                break
                            continue
                    if not self.in_think_block:
                        think_start = self.think_buffer.find(THINKING_START_TAG)
                        if think_start == -1:
                            pending = _pending_tag_suffix(self.think_buffer, THINKING_START_TAG)
                            # Entire buffer is a prefix of <thinking>: commit to a
                            # thinking block now and remember how many tag chars
                            # are still owed by the next chunk.
                            # NOTE(review): this commits eagerly on a partial
                            # match (e.g. a lone "<") and cannot back out if the
                            # tag never completes — verify this is acceptable.
                            if pending == len(self.think_buffer) and pending > 0:
                                if self.content_block_start_sent:
                                    yield build_content_block_stop(self.content_block_index)
                                    self.content_block_stop_sent = True
                                    self.content_block_start_sent = False
                                self.content_block_index += 1
                                yield build_content_block_start(self.content_block_index, "thinking")
                                self.content_block_start_sent = True
                                self.content_block_started = True
                                self.content_block_stop_sent = False
                                self.in_think_block = True
                                self.pending_start_tag_chars = len(THINKING_START_TAG) - pending
                                self.think_buffer = ""
                                break
                            # Emit everything except a possible partial tag tail.
                            emit_len = len(self.think_buffer) - pending
                            if emit_len <= 0:
                                break
                            text_chunk = self.think_buffer[:emit_len]
                            if text_chunk:
                                if not self.content_block_start_sent:
                                    self.content_block_index += 1
                                    yield build_content_block_start(self.content_block_index, "text")
                                    self.content_block_start_sent = True
                                    self.content_block_started = True
                                    self.content_block_stop_sent = False
                                self.response_buffer.append(text_chunk)
                                yield build_content_block_delta(self.content_block_index, text_chunk)
                            self.think_buffer = self.think_buffer[emit_len:]
                        else:
                            # Full <thinking> tag found: flush preceding text,
                            # close the text block, open a thinking block.
                            before_text = self.think_buffer[:think_start]
                            if before_text:
                                if not self.content_block_start_sent:
                                    self.content_block_index += 1
                                    yield build_content_block_start(self.content_block_index, "text")
                                    self.content_block_start_sent = True
                                    self.content_block_started = True
                                    self.content_block_stop_sent = False
                                self.response_buffer.append(before_text)
                                yield build_content_block_delta(self.content_block_index, before_text)
                            self.think_buffer = self.think_buffer[think_start + len(THINKING_START_TAG):]
                            if self.content_block_start_sent:
                                yield build_content_block_stop(self.content_block_index)
                                self.content_block_stop_sent = True
                                self.content_block_start_sent = False
                            self.content_block_index += 1
                            yield build_content_block_start(self.content_block_index, "thinking")
                            self.content_block_start_sent = True
                            self.content_block_started = True
                            self.content_block_stop_sent = False
                            self.in_think_block = True
                            self.pending_start_tag_chars = 0
                    else:
                        think_end = self.think_buffer.find(THINKING_END_TAG)
                        if think_end == -1:
                            # No end tag yet: emit thinking content, holding
                            # back a possible partial </thinking> tail.
                            pending = _pending_tag_suffix(self.think_buffer, THINKING_END_TAG)
                            emit_len = len(self.think_buffer) - pending
                            if emit_len <= 0:
                                break
                            thinking_chunk = self.think_buffer[:emit_len]
                            if thinking_chunk:
                                yield build_content_block_delta(
                                    self.content_block_index,
                                    thinking_chunk,
                                    delta_type="thinking_delta",
                                    field_name="thinking"
                                )
                            self.think_buffer = self.think_buffer[emit_len:]
                        else:
                            # End tag found: emit the remainder and close the
                            # thinking block; loop continues on trailing text.
                            thinking_chunk = self.think_buffer[:think_end]
                            if thinking_chunk:
                                yield build_content_block_delta(
                                    self.content_block_index,
                                    thinking_chunk,
                                    delta_type="thinking_delta",
                                    field_name="thinking"
                                )
                            self.think_buffer = self.think_buffer[think_end + len(THINKING_END_TAG):]
                            yield build_content_block_stop(self.content_block_index)
                            self.content_block_stop_sent = True
                            self.content_block_start_sent = False
                            self.in_think_block = False
        # 3. Tool Use (toolUseEvent)
        elif event_type == "toolUseEvent":
            tool_use_id = payload.get("toolUseId")
            tool_name = payload.get("name")
            tool_input = payload.get("input", {})
            is_stop = payload.get("stop", False)
            # Deduplication: skip if this tool_use_id was already processed and no tool is active
            # (allows input deltas to pass through when current_tool_use is set)
            if tool_use_id and tool_use_id in self._processed_tool_use_ids and not self.current_tool_use:
                logger.warning(f"Detected duplicate tool use event, toolUseId={tool_use_id}, skipping")
                return
            # Start new tool use
            if tool_use_id and tool_name and not self.current_tool_use:
                # Close previous text block if open
                if self.content_block_start_sent and not self.content_block_stop_sent:
                    yield build_content_block_stop(self.content_block_index)
                    self.content_block_stop_sent = True
                self._processed_tool_use_ids.add(tool_use_id)
                self.content_block_index += 1
                yield build_tool_use_start(self.content_block_index, tool_use_id, tool_name)
                self.content_block_started = True
                self.current_tool_use = {"toolUseId": tool_use_id, "name": tool_name}
                self.tool_use_id = tool_use_id
                self.tool_name = tool_name
                self.tool_input_buffer = []
                self.content_block_stop_sent = False
                self.content_block_start_sent = True
                self.has_tool_use = True  # Mark that we have tool_use
            # Accumulate input
            if self.current_tool_use and tool_input:
                fragment = ""
                if isinstance(tool_input, str):
                    fragment = tool_input
                else:
                    # Non-string payloads are re-serialized as JSON fragments.
                    fragment = json.dumps(tool_input, ensure_ascii=False)
                self.tool_input_buffer.append(fragment)
                yield build_tool_use_input_delta(self.content_block_index, fragment)
            # Stop tool use
            if is_stop and self.current_tool_use:
                full_input = "".join(self.tool_input_buffer)
                self.all_tool_inputs.append(full_input)
                yield build_content_block_stop(self.content_block_index)
                # Reset state to allow next content block
                self.content_block_stop_sent = False  # Reset to False to allow next block
                self.content_block_started = False
                self.content_block_start_sent = False  # Important: reset start flag for next block
                self.current_tool_use = None
                self.tool_use_id = None
                self.tool_name = None
                self.tool_input_buffer = []
        # 4. Assistant Response End (assistantResponseEnd)
        elif event_type == "assistantResponseEnd":
            # Close any open block
            if self.content_block_started and not self.content_block_stop_sent:
                yield build_content_block_stop(self.content_block_index)
                self.content_block_stop_sent = True
            # Mark as finished to prevent processing further events
            self.response_ended = True
            # Immediately send message_stop (instead of waiting for finish())
            full_text = "".join(self.response_buffer)
            full_tool_input = "".join(self.all_tool_inputs)
            output_tokens = count_tokens(full_text) + count_tokens(full_tool_input)
            # Use "tool_use" stop_reason if any tool_use was emitted, otherwise "end_turn"
            stop_reason = "tool_use" if self.has_tool_use else "end_turn"
            yield build_message_stop(self.input_tokens, output_tokens, stop_reason)

    async def finish(self) -> AsyncGenerator[str, None]:
        """Send final events."""
        # Skip if response already ended (message_stop already sent)
        if self.response_ended:
            return
        # Flush any remaining think_buffer content
        if self.think_buffer:
            if self.in_think_block:
                # Emit remaining thinking content
                yield build_content_block_delta(
                    self.content_block_index,
                    self.think_buffer,
                    delta_type="thinking_delta",
                    field_name="thinking"
                )
            else:
                # Emit remaining text content
                if not self.content_block_start_sent:
                    self.content_block_index += 1
                    yield build_content_block_start(self.content_block_index, "text")
                    self.content_block_start_sent = True
                    self.content_block_started = True
                    self.content_block_stop_sent = False
                self.response_buffer.append(self.think_buffer)
                yield build_content_block_delta(self.content_block_index, self.think_buffer)
            self.think_buffer = ""
        # Ensure last block is closed
        if self.content_block_started and not self.content_block_stop_sent:
            yield build_content_block_stop(self.content_block_index)
            self.content_block_stop_sent = True
        # Calculate output tokens via tiktoken (0 if the encoder is unavailable)
        full_text = "".join(self.response_buffer)
        full_tool_input = "".join(self.all_tool_inputs)
        output_tokens = count_tokens(full_text) + count_tokens(full_tool_input)
        # Use "tool_use" stop_reason if any tool_use was emitted, otherwise "end_turn"
        stop_reason = "tool_use" if self.has_tool_use else "end_turn"
        yield build_message_stop(self.input_tokens, output_tokens, stop_reason)