# claude_parser.py
  1. import json
  2. import struct
  3. import logging
  4. from typing import Optional, Dict, Any, AsyncIterator
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
  6. class EventStreamParser:
  7. """AWS Event Stream binary format parser (v2 style)."""
  8. @staticmethod
  9. def parse_headers(headers_data: bytes) -> Dict[str, str]:
  10. """Parse event stream headers."""
  11. headers = {}
  12. offset = 0
  13. while offset < len(headers_data):
  14. if offset >= len(headers_data):
  15. break
  16. name_length = headers_data[offset]
  17. offset += 1
  18. if offset + name_length > len(headers_data):
  19. break
  20. name = headers_data[offset:offset + name_length].decode('utf-8')
  21. offset += name_length
  22. if offset >= len(headers_data):
  23. break
  24. value_type = headers_data[offset]
  25. offset += 1
  26. if offset + 2 > len(headers_data):
  27. break
  28. value_length = struct.unpack('>H', headers_data[offset:offset + 2])[0]
  29. offset += 2
  30. if offset + value_length > len(headers_data):
  31. break
  32. if value_type == 7:
  33. value = headers_data[offset:offset + value_length].decode('utf-8')
  34. else:
  35. value = headers_data[offset:offset + value_length]
  36. offset += value_length
  37. headers[name] = value
  38. return headers
  39. @staticmethod
  40. def parse_message(data: bytes) -> Optional[Dict[str, Any]]:
  41. """Parse single Event Stream message."""
  42. try:
  43. if len(data) < 16:
  44. return None
  45. total_length = struct.unpack('>I', data[0:4])[0]
  46. headers_length = struct.unpack('>I', data[4:8])[0]
  47. if len(data) < total_length:
  48. logger.warning(f"Incomplete message: expected {total_length} bytes, got {len(data)}")
  49. return None
  50. headers_data = data[12:12 + headers_length]
  51. headers = EventStreamParser.parse_headers(headers_data)
  52. payload_start = 12 + headers_length
  53. payload_end = total_length - 4
  54. payload_data = data[payload_start:payload_end]
  55. payload = None
  56. if payload_data:
  57. try:
  58. payload = json.loads(payload_data.decode('utf-8'))
  59. except (json.JSONDecodeError, UnicodeDecodeError):
  60. payload = payload_data
  61. return {
  62. 'headers': headers,
  63. 'payload': payload,
  64. 'total_length': total_length
  65. }
  66. except Exception as e:
  67. logger.error(f"Failed to parse message: {e}", exc_info=True)
  68. return None
  69. @staticmethod
  70. async def parse_stream(byte_stream: AsyncIterator[bytes]) -> AsyncIterator[Dict[str, Any]]:
  71. """Parse byte stream and extract events."""
  72. buffer = bytearray()
  73. async for chunk in byte_stream:
  74. buffer.extend(chunk)
  75. while len(buffer) >= 12:
  76. try:
  77. total_length = struct.unpack('>I', buffer[0:4])[0]
  78. except struct.error:
  79. break
  80. if len(buffer) < total_length:
  81. break
  82. message_data = bytes(buffer[:total_length])
  83. buffer = buffer[total_length:]
  84. message = EventStreamParser.parse_message(message_data)
  85. if message:
  86. yield message
  87. def extract_event_info(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
  88. """Extract event information from parsed message."""
  89. headers = message.get('headers', {})
  90. payload = message.get('payload')
  91. event_type = headers.get(':event-type') or headers.get('event-type')
  92. content_type = headers.get(':content-type') or headers.get('content-type')
  93. message_type = headers.get(':message-type') or headers.get('message-type')
  94. return {
  95. 'event_type': event_type,
  96. 'content_type': content_type,
  97. 'message_type': message_type,
  98. 'payload': payload
  99. }
  100. def _sse_format(event_type: str, data: Dict[str, Any]) -> str:
  101. """Format SSE event."""
  102. json_data = json.dumps(data, ensure_ascii=False)
  103. return f"event: {event_type}\ndata: {json_data}\n\n"
  104. def build_message_start(conversation_id: str, model: str = "claude-sonnet-4.5", input_tokens: int = 0) -> str:
  105. """Build message_start SSE event."""
  106. data = {
  107. "type": "message_start",
  108. "message": {
  109. "id": conversation_id,
  110. "type": "message",
  111. "role": "assistant",
  112. "content": [],
  113. "model": model,
  114. "stop_reason": None,
  115. "stop_sequence": None,
  116. "usage": {"input_tokens": input_tokens, "output_tokens": 0}
  117. }
  118. }
  119. return _sse_format("message_start", data)
  120. def build_content_block_start(index: int, block_type: str = "text") -> str:
  121. """Build content_block_start SSE event."""
  122. if block_type == "text":
  123. block_payload = {"type": "text", "text": ""}
  124. elif block_type == "thinking":
  125. block_payload = {"type": "thinking", "thinking": ""}
  126. else:
  127. block_payload = {"type": block_type}
  128. data = {
  129. "type": "content_block_start",
  130. "index": index,
  131. "content_block": block_payload
  132. }
  133. return _sse_format("content_block_start", data)
  134. def build_content_block_delta(index: int, text: str, delta_type: str = "text_delta", field_name: str = "text") -> str:
  135. """Build content_block_delta SSE event."""
  136. delta = {"type": delta_type}
  137. if field_name:
  138. delta[field_name] = text
  139. data = {
  140. "type": "content_block_delta",
  141. "index": index,
  142. "delta": delta
  143. }
  144. return _sse_format("content_block_delta", data)
  145. def build_content_block_stop(index: int) -> str:
  146. """Build content_block_stop SSE event."""
  147. data = {
  148. "type": "content_block_stop",
  149. "index": index
  150. }
  151. return _sse_format("content_block_stop", data)
  152. def build_ping() -> str:
  153. """Build ping SSE event."""
  154. data = {"type": "ping"}
  155. return _sse_format("ping", data)
  156. def build_message_stop(input_tokens: int, output_tokens: int, stop_reason: Optional[str] = None) -> str:
  157. """Build message_delta and message_stop SSE events."""
  158. delta_data = {
  159. "type": "message_delta",
  160. "delta": {"stop_reason": stop_reason or "end_turn", "stop_sequence": None},
  161. "usage": {"output_tokens": output_tokens}
  162. }
  163. delta_event = _sse_format("message_delta", delta_data)
  164. stop_data = {
  165. "type": "message_stop"
  166. }
  167. stop_event = _sse_format("message_stop", stop_data)
  168. return delta_event + stop_event
  169. def build_tool_use_start(index: int, tool_use_id: str, tool_name: str) -> str:
  170. """Build tool_use content_block_start SSE event."""
  171. data = {
  172. "type": "content_block_start",
  173. "index": index,
  174. "content_block": {
  175. "type": "tool_use",
  176. "id": tool_use_id,
  177. "name": tool_name,
  178. "input": {}
  179. }
  180. }
  181. return _sse_format("content_block_start", data)
  182. def build_tool_use_input_delta(index: int, input_json_delta: str) -> str:
  183. """Build tool_use input_json_delta SSE event."""
  184. data = {
  185. "type": "content_block_delta",
  186. "index": index,
  187. "delta": {
  188. "type": "input_json_delta",
  189. "partial_json": input_json_delta
  190. }
  191. }
  192. return _sse_format("content_block_delta", data)