Просмотр исходного кода

Update v2: add Claude integration modules and modify existing files

CassiopeiaCode 2 месяцев назад
Родитель
Commit
3b0a37742a
7 измененных файлов с 1112 добавлено и 32 удалено
  1. 2 0
      .gitignore
  2. 234 1
      app.py
  3. 393 0
      claude_converter.py
  4. 222 0
      claude_parser.py
  5. 164 0
      claude_stream.py
  6. 20 0
      claude_types.py
  7. 77 31
      replicate.py

+ 2 - 0
.gitignore

@@ -30,6 +30,8 @@ wheels/
 
 # Database
 *.sqlite3
+*.sqlite3-shm
+*.sqlite3-wal
 *.db
 
 # IDE

+ 234 - 1
app.py

@@ -51,6 +51,60 @@ def _load_replicate_module():
 _replicate = _load_replicate_module()
 send_chat_request = _replicate.send_chat_request
 
+# ------------------------------------------------------------------------------
+# Dynamic import of Claude modules
+# ------------------------------------------------------------------------------
+
def _load_claude_modules():
    """Load the Claude integration modules that live alongside app.py.

    The v2 directory is not a Python package, so the modules are loaded
    straight from their files. ``claude_converter.py`` falls back to looking
    up ``sys.modules["v2.claude_types"]`` when its relative import fails, so
    claude_types is registered under that dotted name *before* the converter
    module is executed.

    Returns:
        Tuple of (claude_types, claude_converter, claude_stream) module objects.

    Raises:
        Whatever the loaded modules raise while executing at import time.
    """
    import sys  # local import: app.py may not import sys at module level -- TODO confirm

    def _load_from_file(alias, filename):
        # Load one sibling file as a module under a unique alias.
        spec = importlib.util.spec_from_file_location(alias, str(BASE_DIR / filename))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    mod_types = _load_from_file("v2_claude_types", "claude_types.py")
    # Must be registered before claude_converter runs its import fallback.
    sys.modules["v2.claude_types"] = mod_types

    mod_conv = _load_from_file("v2_claude_converter", "claude_converter.py")
    mod_stream = _load_from_file("v2_claude_stream", "claude_stream.py")
    return mod_types, mod_conv, mod_stream
+
# Expose the dynamically loaded Claude symbols at module level. If loading
# fails, the server still starts: the /v1/messages route is defined, but
# requests fail later because the converter/handler are None.
try:
    _claude_types, _claude_converter, _claude_stream = _load_claude_modules()
    ClaudeRequest = _claude_types.ClaudeRequest
    convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
    ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
except Exception as e:
    print(f"Failed to load Claude modules: {e}")
    traceback.print_exc()
    # Define dummy classes to avoid NameError on startup if loading fails.
    # NOTE(review): this fallback ClaudeRequest accepts any JSON body, so
    # validation is effectively disabled until the real modules load.
    class ClaudeRequest(BaseModel):
        pass
    convert_claude_to_amazonq_request = None
    ClaudeStreamHandler = None
+
 # ------------------------------------------------------------------------------
 # Global HTTP Client
 # ------------------------------------------------------------------------------
@@ -424,6 +478,181 @@ def _openai_non_streaming_response(text: str, model: Optional[str]) -> Dict[str,
 def _sse_format(obj: Dict[str, Any]) -> str:
     return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
 
@app.post("/v1/messages")
async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(require_account)):
    """Claude-compatible ``/v1/messages`` endpoint.

    Converts the Claude request into an Amazon Q payload, streams events from
    upstream, and either relays them as SSE (when ``req.stream``) or
    aggregates them into a single Claude-style message object.

    Fixes vs. the previous version: the endpoint used to issue the upstream
    request twice (the first response was discarded) and contained dead
    ``try/pass`` scaffolding; the second token path also skipped the
    "token unavailable after refresh" check.
    """
    # 1. Translate the Claude request into an Amazon Q conversation payload.
    try:
        aq_request = convert_claude_to_amazonq_request(req)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")

    try:
        # 2. Ensure a usable access token, refreshing it if necessary.
        access = account.get("accessToken")
        if not access:
            refreshed = await refresh_access_token_in_db(account["id"])
            access = refreshed.get("accessToken")
            if not access:
                raise HTTPException(status_code=502, detail="Access token unavailable after refresh")

        # 3. Always stream from upstream: tool-use details are only visible in
        #    the event stream, so a non-streaming response is aggregated below.
        _, _, tracker, event_iter = await send_chat_request(
            access_token=access,
            messages=[],  # Unused when raw_payload is supplied.
            model=req.model,
            stream=True,
            client=GLOBAL_CLIENT,
            raw_payload=aq_request,
        )
        if not event_iter:
            raise HTTPException(status_code=502, detail="No event stream returned")

        # TODO: estimate real input tokens instead of reporting 0.
        handler = ClaudeStreamHandler(model=req.model, input_tokens=0)

        async def event_generator():
            # Relay upstream events through the Claude SSE handler, recording
            # per-account success/failure stats when the stream ends.
            try:
                async for event_type, payload in event_iter:
                    async for sse in handler.handle_event(event_type, payload):
                        yield sse
                async for sse in handler.finish():
                    yield sse
                await _update_stats(account["id"], True)
            except Exception:
                await _update_stats(account["id"], False)
                raise

        if req.stream:
            return StreamingResponse(event_generator(), media_type="text/event-stream")
        return await _aggregate_claude_sse(event_generator(), req.model)

    except Exception:
        await _update_stats(account["id"], False)
        raise


async def _aggregate_claude_sse(sse_stream, model: str) -> Dict[str, Any]:
    """Consume Claude SSE lines and rebuild a non-streaming message response.

    Parses the handler's own SSE output back into content blocks; tool_use
    inputs arrive as ``input_json_delta`` fragments and are decoded on
    ``content_block_stop``.
    """
    final_content = []
    usage = {"input_tokens": 0, "output_tokens": 0}
    stop_reason = None

    async for sse_line in sse_stream:
        if not sse_line.startswith("data: "):
            continue
        data_str = sse_line[6:].strip()
        if data_str == "[DONE]":
            continue
        try:
            data = json.loads(data_str)
        except Exception:
            continue  # Ignore malformed lines rather than failing the response.

        dtype = data.get("type")
        if dtype == "content_block_start":
            idx = data.get("index", 0)
            while len(final_content) <= idx:
                final_content.append(None)
            final_content[idx] = data.get("content_block")
        elif dtype == "content_block_delta":
            idx = data.get("index", 0)
            delta = data.get("delta", {})
            # Guard against deltas arriving before their block_start.
            if idx < len(final_content) and final_content[idx]:
                if delta.get("type") == "text_delta":
                    final_content[idx]["text"] += delta.get("text", "")
                elif delta.get("type") == "input_json_delta":
                    # tool_use input streams as partial JSON text fragments.
                    final_content[idx]["partial_json"] = (
                        final_content[idx].get("partial_json", "")
                        + delta.get("partial_json", "")
                    )
        elif dtype == "content_block_stop":
            idx = data.get("index", 0)
            block = final_content[idx] if idx < len(final_content) else None
            if block and block.get("type") == "tool_use" and "partial_json" in block:
                try:
                    block["input"] = json.loads(block["partial_json"])
                except Exception:
                    pass  # Keep the empty input rather than failing the response.
                del block["partial_json"]
        elif dtype == "message_delta":
            usage = data.get("usage", usage)
            stop_reason = data.get("delta", {}).get("stop_reason")

    return {
        "id": f"msg_{uuid.uuid4()}",
        "type": "message",
        "role": "assistant",
        "model": model,
        "content": [c for c in final_content if c is not None],
        "stop_reason": stop_reason,
        "stop_sequence": None,
        "usage": usage,
    }
+
 @app.post("/v1/chat/completions")
 async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
     """
@@ -442,7 +671,11 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
             access = refreshed.get("accessToken")
             if not access:
                 raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
-        return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
+        # Note: send_chat_request signature changed, but we use keyword args so it should be fine if we don't pass raw_payload
+        # But wait, the return signature changed too! It now returns 4 values.
+        # We need to unpack 4 values.
+        result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
+        return result[0], result[1], result[2] # Ignore the 4th value (event_stream) for OpenAI endpoint
 
     if not do_stream:
         try:

+ 393 - 0
claude_converter.py

@@ -0,0 +1,393 @@
+import json
+import uuid
+import time
+import logging
+from datetime import datetime
+from typing import List, Dict, Any, Optional, Union
+
+try:
+    from .claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
+except ImportError:
+    # Fallback for dynamic loading where relative import might fail
+    # We assume claude_types is available in sys.modules or we can import it directly if in same dir
+    import sys
+    if "v2.claude_types" in sys.modules:
+        from v2.claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
+    else:
+        # Try absolute import assuming v2 is in path or current dir
+        try:
+            from claude_types import ClaudeRequest, ClaudeMessage, ClaudeTool
+        except ImportError:
+             # Last resort: if loaded via importlib in app.py, we might need to rely on app.py injecting it
+             # But app.py loads this module.
+             pass
+
+logger = logging.getLogger(__name__)
+
def get_current_timestamp() -> str:
    """Return the local time formatted the way Amazon Q expects.

    Format: ``"<Weekday>, <ISO-8601 timestamp with milliseconds and offset>"``.
    """
    moment = datetime.now().astimezone()
    return "{}, {}".format(moment.strftime("%A"), moment.isoformat(timespec="milliseconds"))
+
def map_model_name(claude_model: str) -> str:
    """Map a Claude model name onto the corresponding Amazon Q model id."""
    lowered = claude_model.lower()
    # Both the dotted and the dashed 4.5 spellings select the 4.5 backend;
    # everything else falls back to the base sonnet-4 model.
    if lowered.startswith(("claude-sonnet-4.5", "claude-sonnet-4-5")):
        return "claude-sonnet-4.5"
    return "claude-sonnet-4"
+
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Collect the text portions of a Claude ``content`` value.

    A plain string is returned unchanged. A list of content blocks yields
    the ``text`` fields of all ``type == "text"`` dict blocks joined with
    newlines. Any other value produces the empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        ]
        return "\n".join(texts)
    return ""
+
def extract_images_from_content(content: Union[str, List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
    """Convert Claude base64 image blocks into Amazon Q image dicts.

    Returns ``None`` when *content* is not a list or contains no base64
    image blocks; otherwise a list of ``{"format", "source": {"bytes"}}``
    entries.
    """
    if not isinstance(content, list):
        return None

    converted = []
    for block in content:
        if not (isinstance(block, dict) and block.get("type") == "image"):
            continue
        source = block.get("source", {})
        if source.get("type") != "base64":
            continue
        media_type = source.get("media_type", "image/png")
        # "image/jpeg" -> "jpeg"; a media type with no slash defaults to png.
        if "/" in media_type:
            fmt = media_type.rsplit("/", 1)[-1]
        else:
            fmt = "png"
        converted.append({"format": fmt, "source": {"bytes": source.get("data", "")}})

    return converted or None
+
def convert_tool(tool: "ClaudeTool") -> Dict[str, Any]:
    """Translate a Claude tool definition into Amazon Q's toolSpecification."""
    description = tool.description or ""
    # Amazon Q rejects descriptions over 10240 chars; oversized ones are
    # truncated here and the full text travels in the TOOL DOCUMENTATION
    # section of the formatted prompt instead.
    if len(description) > 10240:
        description = description[:10100] + "\n\n...(Full description provided in TOOL DOCUMENTATION section)"

    spec = {
        "name": tool.name,
        "description": description,
        "inputSchema": {"json": tool.input_schema},
    }
    return {"toolSpecification": spec}
+
def merge_user_messages(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collapse a run of consecutive user messages into one AQ user message.

    The context/origin/modelId of the first message that provides them win;
    non-empty text contents are concatenated with blank lines between them.
    Returns ``{}`` for an empty input.
    """
    if not messages:
        return {}

    contents = []
    context = None
    origin = None
    model = None

    for msg in messages:
        if context is None:
            context = msg.get("userInputMessageContext", {})
        if origin is None:
            origin = msg.get("origin", "CLI")
        if model is None:
            model = msg.get("modelId")
        text = msg.get("content", "")
        if text:
            contents.append(text)

    return {
        "content": "\n\n".join(contents),
        "userInputMessageContext": context or {},
        "origin": origin or "CLI",
        "modelId": model,
    }
+
def _convert_tool_result_content(raw_c) -> List[Dict[str, str]]:
    """Normalize a Claude tool_result ``content`` value into AQ text items.

    Accepts a plain string or a list of mixed str/dict items; anything
    without usable text collapses to a cancellation notice, since upstream
    rejects empty tool results.
    """
    aq_content = []
    if isinstance(raw_c, str):
        aq_content = [{"text": raw_c}]
    elif isinstance(raw_c, list):
        for item in raw_c:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    aq_content.append({"text": item.get("text", "")})
                elif "text" in item:
                    aq_content.append({"text": item["text"]})
            elif isinstance(item, str):
                aq_content.append({"text": item})
    if not any(i.get("text", "").strip() for i in aq_content):
        aq_content = [{"text": "Tool use was cancelled by the user"}]
    return aq_content


def _append_tool_result(tool_results: List[Dict[str, Any]], block: Dict[str, Any]) -> None:
    """Add *block* (a Claude tool_result) to *tool_results*, merging entries
    that share the same toolUseId."""
    tool_use_id = block.get("tool_use_id")
    aq_content = _convert_tool_result_content(block.get("content", []))
    existing = next((r for r in tool_results if r["toolUseId"] == tool_use_id), None)
    if existing:
        existing["content"].extend(aq_content)
    else:
        tool_results.append({
            "toolUseId": tool_use_id,
            "content": aq_content,
            "status": block.get("status", "success"),
        })


def process_history(messages: List["ClaudeMessage"]) -> List[Dict[str, Any]]:
    """Convert Claude history messages into Amazon Q's history format.

    Two passes: first each Claude message becomes a userInputMessage or
    assistantResponseMessage entry; then runs of consecutive user messages
    are merged so the history strictly alternates user/assistant, which
    Amazon Q requires. tool_use ids are de-duplicated across the history.
    """
    raw_history = []
    seen_tool_use_ids = set()

    # Pass 1: per-message conversion.
    for msg in messages:
        if msg.role == "user":
            content = msg.content
            images = extract_images_from_content(content)
            tool_results = None

            if isinstance(content, list):
                text_parts = []
                for block in content:
                    if not isinstance(block, dict):
                        continue
                    btype = block.get("type")
                    if btype == "text":
                        text_parts.append(block.get("text", ""))
                    elif btype == "tool_result":
                        if tool_results is None:
                            tool_results = []
                        _append_tool_result(tool_results, block)
                text_content = "\n".join(text_parts)
            else:
                text_content = extract_text_from_content(content)

            user_ctx = {
                "envState": {
                    "operatingSystem": "macos",
                    "currentWorkingDirectory": "/"
                }
            }
            if tool_results:
                user_ctx["toolResults"] = tool_results

            u_msg = {
                "content": text_content,
                "userInputMessageContext": user_ctx,
                "origin": "CLI"
            }
            if images:
                u_msg["images"] = images
            raw_history.append({"userInputMessage": u_msg})

        elif msg.role == "assistant":
            content = msg.content
            entry = {
                "assistantResponseMessage": {
                    "messageId": str(uuid.uuid4()),
                    "content": extract_text_from_content(content)
                }
            }
            if isinstance(content, list):
                tool_uses = []
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "tool_use":
                        tid = block.get("id")
                        # Drop repeated tool_use ids: AQ rejects duplicates.
                        if tid and tid not in seen_tool_use_ids:
                            seen_tool_use_ids.add(tid)
                            tool_uses.append({
                                "toolUseId": tid,
                                "name": block.get("name"),
                                "input": block.get("input", {})
                            })
                if tool_uses:
                    entry["assistantResponseMessage"]["toolUses"] = tool_uses
            raw_history.append(entry)

    # Pass 2: merge consecutive user messages to enforce alternation.
    history = []
    pending_user_msgs = []
    for item in raw_history:
        if "userInputMessage" in item:
            pending_user_msgs.append(item["userInputMessage"])
        elif "assistantResponseMessage" in item:
            if pending_user_msgs:
                history.append({"userInputMessage": merge_user_messages(pending_user_msgs)})
                pending_user_msgs = []
            history.append(item)

    if pending_user_msgs:
        history.append({"userInputMessage": merge_user_messages(pending_user_msgs)})
    return history
+
def convert_claude_to_amazonq_request(req: ClaudeRequest, conversation_id: Optional[str] = None) -> Dict[str, Any]:
    """Convert a ClaudeRequest into an Amazon Q ``conversationState`` body.

    The last user message becomes ``currentMessage``; all earlier messages
    go through process_history(). The prompt text is wrapped in labelled
    sections (system prompt, oversized tool docs, context, user message).

    Args:
        req: Parsed Claude /v1/messages request.
        conversation_id: Existing conversation id; a fresh UUID when None.

    Returns:
        Dict shaped for the Amazon Q chat API.
    """
    if conversation_id is None:
        conversation_id = str(uuid.uuid4())

    # 1. Tools. Descriptions over 10240 chars are truncated by convert_tool;
    #    their full text is collected here and emitted as TOOL DOCUMENTATION.
    aq_tools = []
    long_desc_tools = []
    if req.tools:
        for t in req.tools:
            if t.description and len(t.description) > 10240:
                long_desc_tools.append({"name": t.name, "full_description": t.description})
            aq_tools.append(convert_tool(t))

    # 2. Current message: only the last message, and only if it is a user
    #    message, feeds the prompt. NOTE(review): if the last message were
    #    an assistant turn it would be silently dropped (it is excluded from
    #    history below too) — confirm callers never send that shape.
    last_msg = req.messages[-1] if req.messages else None
    prompt_content = ""
    tool_results = None
    has_tool_result = False
    images = None

    if last_msg and last_msg.role == "user":
        content = last_msg.content
        images = extract_images_from_content(content)

        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict):
                    btype = block.get("type")
                    if btype == "text":
                        text_parts.append(block.get("text", ""))
                    elif btype == "tool_result":
                        has_tool_result = True
                        if tool_results is None:
                            tool_results = []

                        tid = block.get("tool_use_id")
                        raw_c = block.get("content", [])

                        # Normalize str / list-of-dicts / list-of-strs into
                        # AQ {"text": ...} items.
                        aq_content = []
                        if isinstance(raw_c, str):
                            aq_content = [{"text": raw_c}]
                        elif isinstance(raw_c, list):
                            for item in raw_c:
                                if isinstance(item, dict):
                                    if item.get("type") == "text":
                                        aq_content.append({"text": item.get("text", "")})
                                    elif "text" in item:
                                        aq_content.append({"text": item["text"]})
                                elif isinstance(item, str):
                                    aq_content.append({"text": item})

                        # Upstream rejects empty tool results.
                        if not any(i.get("text", "").strip() for i in aq_content):
                            aq_content = [{"text": "Tool use was cancelled by the user"}]

                        # Merge results that share a toolUseId.
                        existing = next((r for r in tool_results if r["toolUseId"] == tid), None)
                        if existing:
                            existing["content"].extend(aq_content)
                        else:
                            tool_results.append({
                                "toolUseId": tid,
                                "content": aq_content,
                                "status": block.get("status", "success")
                            })
            prompt_content = "\n".join(text_parts)
        else:
            prompt_content = extract_text_from_content(content)

    # 3. Context: fixed environment state plus tools / tool results.
    user_ctx = {
        "envState": {
            "operatingSystem": "macos",
            "currentWorkingDirectory": "/"
        }
    }
    if aq_tools:
        user_ctx["tools"] = aq_tools
    if tool_results:
        user_ctx["toolResults"] = tool_results

    # 4. Format content. A pure tool-result turn sends no wrapped text.
    formatted_content = ""
    if has_tool_result and not prompt_content:
        formatted_content = ""
    else:
        formatted_content = (
            "--- CONTEXT ENTRY BEGIN ---\n"
            f"Current time: {get_current_timestamp()}\n"
            "--- CONTEXT ENTRY END ---\n\n"
            "--- USER MESSAGE BEGIN ---\n"
            f"{prompt_content}\n"
            "--- USER MESSAGE END ---"
        )

    # Full text of truncated tool descriptions is prepended here.
    if long_desc_tools:
        docs = []
        for info in long_desc_tools:
            docs.append(f"Tool: {info['name']}\nFull Description:\n{info['full_description']}\n")
        formatted_content = (
            "--- TOOL DOCUMENTATION BEGIN ---\n"
            f"{''.join(docs)}"
            "--- TOOL DOCUMENTATION END ---\n\n"
            f"{formatted_content}"
        )

    # System prompt (string or list of text blocks) goes first of all.
    if req.system and formatted_content:
        sys_text = ""
        if isinstance(req.system, str):
            sys_text = req.system
        elif isinstance(req.system, list):
            parts = []
            for b in req.system:
                if isinstance(b, dict) and b.get("type") == "text":
                    parts.append(b.get("text", ""))
            sys_text = "\n".join(parts)

        if sys_text:
            formatted_content = (
                "--- SYSTEM PROMPT BEGIN ---\n"
                f"{sys_text}\n"
                "--- SYSTEM PROMPT END ---\n\n"
                f"{formatted_content}"
            )

    # 5. Model
    model_id = map_model_name(req.model)

    # 6. User Input Message
    user_input_msg = {
        "content": formatted_content,
        "userInputMessageContext": user_ctx,
        "origin": "CLI",
        "modelId": model_id
    }
    if images:
        user_input_msg["images"] = images

    # 7. History: everything except the final message.
    history_msgs = req.messages[:-1] if len(req.messages) > 1 else []
    aq_history = process_history(history_msgs)

    # 8. Final Body
    return {
        "conversationState": {
            "conversationId": conversation_id,
            "history": aq_history,
            "currentMessage": {
                "userInputMessage": user_input_msg
            },
            "chatTriggerType": "MANUAL"
        }
    }

+ 222 - 0
claude_parser.py

@@ -0,0 +1,222 @@
+import json
+import struct
+import logging
+from typing import Optional, Dict, Any, AsyncIterator
+
+logger = logging.getLogger(__name__)
+
class EventStreamParser:
    """AWS event-stream (application/vnd.amazon.eventstream) binary parser.

    Wire format of one message:
        4 bytes  total message length (big-endian)
        4 bytes  headers section length (big-endian)
        4 bytes  prelude CRC   (not verified here)
        N bytes  headers
        M bytes  payload
        4 bytes  message CRC   (not verified here)
    """

    @staticmethod
    def parse_headers(headers_data: bytes) -> Dict[str, Any]:
        """Parse the headers section into a name -> value mapping.

        Values with wire type 7 are decoded as UTF-8 strings; any other
        type is kept as raw bytes (hence ``Any`` — the previous
        ``Dict[str, str]`` annotation was wrong). Truncated input stops
        parsing and returns whatever was read so far.
        """
        headers: Dict[str, Any] = {}
        offset = 0

        while offset < len(headers_data):
            # 1-byte header-name length, then the name itself.
            name_length = headers_data[offset]
            offset += 1
            if offset + name_length > len(headers_data):
                break
            name = headers_data[offset:offset + name_length].decode('utf-8')
            offset += name_length

            # 1-byte value type tag.
            if offset >= len(headers_data):
                break
            value_type = headers_data[offset]
            offset += 1

            # 2-byte big-endian value length, then the value bytes.
            if offset + 2 > len(headers_data):
                break
            value_length = struct.unpack('>H', headers_data[offset:offset + 2])[0]
            offset += 2
            if offset + value_length > len(headers_data):
                break

            if value_type == 7:  # wire type 7 == string
                value = headers_data[offset:offset + value_length].decode('utf-8')
            else:
                value = headers_data[offset:offset + value_length]
            offset += value_length
            headers[name] = value

        return headers

    @staticmethod
    def parse_message(data: bytes) -> Optional[Dict[str, Any]]:
        """Parse a single complete event-stream message.

        Returns ``{'headers', 'payload', 'total_length'}`` or ``None`` when
        *data* is truncated or malformed. A JSON payload is decoded to a
        Python object; anything else is returned as raw bytes.
        """
        try:
            if len(data) < 16:  # smallest valid message: 12-byte prelude + 4-byte CRC
                return None

            total_length = struct.unpack('>I', data[0:4])[0]
            headers_length = struct.unpack('>I', data[4:8])[0]

            if len(data) < total_length:
                logger.warning(f"Incomplete message: expected {total_length} bytes, got {len(data)}")
                return None

            payload_start = 12 + headers_length
            payload_end = total_length - 4  # strip trailing message CRC
            if payload_end < payload_start:
                # headers_length is inconsistent with total_length; reject.
                return None

            headers = EventStreamParser.parse_headers(data[12:payload_start])
            payload_data = data[payload_start:payload_end]

            payload = None
            if payload_data:
                try:
                    payload = json.loads(payload_data.decode('utf-8'))
                except (json.JSONDecodeError, UnicodeDecodeError):
                    payload = payload_data

            return {
                'headers': headers,
                'payload': payload,
                'total_length': total_length
            }

        except Exception as e:
            logger.error(f"Failed to parse message: {e}", exc_info=True)
            return None

    @staticmethod
    async def parse_stream(byte_stream: AsyncIterator[bytes]) -> AsyncIterator[Dict[str, Any]]:
        """Re-assemble and yield messages from an arbitrarily chunked byte stream."""
        buffer = bytearray()

        async for chunk in byte_stream:
            buffer.extend(chunk)

            # Drain every complete message currently sitting in the buffer.
            while len(buffer) >= 12:
                try:
                    total_length = struct.unpack('>I', buffer[0:4])[0]
                except struct.error:
                    break

                if total_length < 16:
                    # Corrupt prelude: a valid message is at least 16 bytes.
                    # Previously this spun forever (buffer never shrank);
                    # discard the bad data instead.
                    logger.error(f"Corrupt event-stream prelude (total_length={total_length}); dropping buffer")
                    buffer.clear()
                    break

                if len(buffer) < total_length:
                    break  # wait for more bytes

                message_data = bytes(buffer[:total_length])
                buffer = buffer[total_length:]

                message = EventStreamParser.parse_message(message_data)
                if message:
                    yield message
+
def extract_event_info(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pull the event/content/message types plus payload out of a parsed message."""
    hdrs = message.get('headers', {})

    def _hdr(name: str):
        # Header names may arrive with or without the leading colon.
        return hdrs.get(f":{name}") or hdrs.get(name)

    return {
        'event_type': _hdr('event-type'),
        'content_type': _hdr('content-type'),
        'message_type': _hdr('message-type'),
        'payload': message.get('payload'),
    }
+
+def _sse_format(event_type: str, data: Dict[str, Any]) -> str:
+    """Format SSE event."""
+    json_data = json.dumps(data, ensure_ascii=False)
+    return f"event: {event_type}\ndata: {json_data}\n\n"
+
def build_message_start(conversation_id: str, model: str = "claude-sonnet-4.5", input_tokens: int = 0) -> str:
    """Emit the Anthropic `message_start` SSE event opening an assistant turn."""
    message = {
        "id": conversation_id,
        "type": "message",
        "role": "assistant",
        "content": [],
        "model": model,
        "stop_reason": None,
        "stop_sequence": None,
        "usage": {"input_tokens": input_tokens, "output_tokens": 0},
    }
    return _sse_format("message_start", {"type": "message_start", "message": message})
+
def build_content_block_start(index: int, block_type: str = "text") -> str:
    """Emit `content_block_start`; text blocks carry an empty initial string."""
    block = {"type": block_type}
    if block_type == "text":
        block["text"] = ""
    return _sse_format(
        "content_block_start",
        {"type": "content_block_start", "index": index, "content_block": block},
    )
+
def build_content_block_delta(index: int, text: str) -> str:
    """Emit a `content_block_delta` SSE event carrying a text fragment."""
    frame = {
        "type": "content_block_delta",
        "index": index,
        "delta": {"type": "text_delta", "text": text},
    }
    return _sse_format("content_block_delta", frame)
+
def build_content_block_stop(index: int) -> str:
    """Emit a `content_block_stop` SSE event for block *index*."""
    return _sse_format("content_block_stop", {"type": "content_block_stop", "index": index})
+
def build_ping() -> str:
    """Emit a keep-alive `ping` SSE event."""
    return _sse_format("ping", {"type": "ping"})
+
def build_message_stop(input_tokens: int, output_tokens: int, stop_reason: Optional[str] = None) -> str:
    """Emit the trailing `message_delta` + `message_stop` SSE pair.

    NOTE(review): *input_tokens* is accepted but never emitted here — the
    message_delta usage block only reports output_tokens; confirm callers
    do not expect it in the payload.
    """
    reason = stop_reason if stop_reason else "end_turn"
    frames = [
        _sse_format("message_delta", {
            "type": "message_delta",
            "delta": {"stop_reason": reason, "stop_sequence": None},
            "usage": {"output_tokens": output_tokens},
        }),
        _sse_format("message_stop", {"type": "message_stop"}),
    ]
    return "".join(frames)
+
def build_tool_use_start(index: int, tool_use_id: str, tool_name: str) -> str:
    """Emit a `content_block_start` SSE event opening a tool_use block."""
    block = {
        "type": "tool_use",
        "id": tool_use_id,
        "name": tool_name,
        "input": {},
    }
    return _sse_format("content_block_start", {
        "type": "content_block_start",
        "index": index,
        "content_block": block,
    })
+
def build_tool_use_input_delta(index: int, input_json_delta: str) -> str:
    """Emit a `content_block_delta` SSE event with a partial tool-input JSON string."""
    delta = {"type": "input_json_delta", "partial_json": input_json_delta}
    return _sse_format("content_block_delta", {
        "type": "content_block_delta",
        "index": index,
        "delta": delta,
    })

+ 164 - 0
claude_stream.py

@@ -0,0 +1,164 @@
+import json
+import logging
+import importlib.util
+from pathlib import Path
+from typing import AsyncGenerator, Optional, Dict, Any, List, Set
+
+logger = logging.getLogger(__name__)
+
def _load_claude_parser():
    """Dynamically load the sibling claude_parser.py as a standalone module."""
    parser_path = Path(__file__).resolve().parent / "claude_parser.py"
    spec = importlib.util.spec_from_file_location("v2_claude_parser", str(parser_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
+
try:
    _parser = _load_claude_parser()
    build_message_start = _parser.build_message_start
    build_content_block_start = _parser.build_content_block_start
    build_content_block_delta = _parser.build_content_block_delta
    build_content_block_stop = _parser.build_content_block_stop
    build_ping = _parser.build_ping
    build_message_stop = _parser.build_message_stop
    build_tool_use_start = _parser.build_tool_use_start
    build_tool_use_input_delta = _parser.build_tool_use_input_delta
except Exception as e:
    logger.error(f"Failed to load claude_parser: {e}")
    # Degraded mode: every builder becomes a no-op returning an empty frame,
    # so the stream handler stays callable and simply emits nothing.
    build_message_start = lambda *a, **k: ""
    build_content_block_start = lambda *a, **k: ""
    build_content_block_delta = lambda *a, **k: ""
    build_content_block_stop = lambda *a, **k: ""
    build_ping = lambda *a, **k: ""
    build_message_stop = lambda *a, **k: ""
    build_tool_use_start = lambda *a, **k: ""
    build_tool_use_input_delta = lambda *a, **k: ""
+
class ClaudeStreamHandler:
    """Translates Amazon Q streaming events into Anthropic-style SSE frames.

    Invariant: at most one content block (text or tool_use) is open at a
    time; ``content_block_index`` increments each time a new block opens.

    Bug fixed vs. the original: closing a block never reset
    ``content_block_start_sent`` / ``content_block_stop_sent``, so text
    arriving after a tool_use block was emitted as deltas into the
    already-stopped block (and without a new ``content_block_start``),
    and the final close in ``finish`` could be suppressed. All closes now
    go through ``_close_block_frame`` which resets the flags.
    """

    def __init__(self, model: str, input_tokens: int = 0):
        # Request metadata echoed back in message_start / final usage.
        self.model = model
        self.input_tokens = input_tokens
        # Accumulated assistant text fragments (used for token estimation).
        self.response_buffer: List[str] = []
        # Content block bookkeeping. `started` and `start_sent` are kept
        # in lockstep and both retained for backward compatibility with
        # external readers of this object's state.
        self.content_block_index: int = -1
        self.content_block_started: bool = False
        self.content_block_start_sent: bool = False
        self.content_block_stop_sent: bool = False
        self.message_start_sent: bool = False
        self.conversation_id: Optional[str] = None

        # Tool use state
        self.current_tool_use: Optional[Dict[str, Any]] = None
        self.tool_input_buffer: List[str] = []
        self.tool_use_id: Optional[str] = None
        self.tool_name: Optional[str] = None
        self._processed_tool_use_ids: Set[str] = set()
        self.all_tool_inputs: List[str] = []

    def _close_block_frame(self) -> Optional[str]:
        """Close the currently open content block, if any.

        Returns the `content_block_stop` frame to emit, or None when no
        block is open. Resets the start flags so the next delta opens a
        brand-new block.
        """
        if self.content_block_start_sent and not self.content_block_stop_sent:
            self.content_block_stop_sent = True
            self.content_block_started = False
            self.content_block_start_sent = False
            return build_content_block_stop(self.content_block_index)
        return None

    async def handle_event(self, event_type: str, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """Process a single Amazon Q event and yield Claude SSE events."""

        # 1. Message Start (initial-response)
        if event_type == "initial-response":
            if not self.message_start_sent:
                conv_id = payload.get('conversationId', self.conversation_id or 'unknown')
                self.conversation_id = conv_id
                yield build_message_start(conv_id, self.model, self.input_tokens)
                self.message_start_sent = True
                yield build_ping()

        # 2. Content Block Delta (assistantResponseEvent)
        elif event_type == "assistantResponseEvent":
            content = payload.get("content", "")

            # An open tool_use block must be closed before text can flow.
            # NOTE(review): its buffered input is intentionally NOT added to
            # all_tool_inputs here, matching the original behavior — confirm
            # whether implicitly-ended tool blocks should count toward usage.
            if self.current_tool_use:
                frame = self._close_block_frame()
                if frame:
                    yield frame
                self.current_tool_use = None

            # Open a fresh text block if none is in flight.
            if not self.content_block_start_sent:
                self.content_block_index += 1
                yield build_content_block_start(self.content_block_index, "text")
                self.content_block_start_sent = True
                self.content_block_started = True
                self.content_block_stop_sent = False

            if content:
                self.response_buffer.append(content)
                yield build_content_block_delta(self.content_block_index, content)

        # 3. Tool Use (toolUseEvent)
        elif event_type == "toolUseEvent":
            tool_use_id = payload.get("toolUseId")
            tool_name = payload.get("name")
            tool_input = payload.get("input", {})
            is_stop = payload.get("stop", False)

            # Start a new tool_use block.
            if tool_use_id and tool_name and not self.current_tool_use:
                # Close any preceding text block first.
                frame = self._close_block_frame()
                if frame:
                    yield frame

                self._processed_tool_use_ids.add(tool_use_id)
                self.content_block_index += 1
                yield build_tool_use_start(self.content_block_index, tool_use_id, tool_name)

                self.content_block_started = True
                self.content_block_start_sent = True
                self.content_block_stop_sent = False
                self.current_tool_use = {"toolUseId": tool_use_id, "name": tool_name}
                self.tool_use_id = tool_use_id
                self.tool_name = tool_name
                self.tool_input_buffer = []

            # Accumulate streamed input fragments.
            if self.current_tool_use and tool_input:
                if isinstance(tool_input, str):
                    fragment = tool_input
                else:
                    fragment = json.dumps(tool_input, ensure_ascii=False)
                self.tool_input_buffer.append(fragment)
                yield build_tool_use_input_delta(self.content_block_index, fragment)

            # Explicit end of the tool_use block.
            if is_stop and self.current_tool_use:
                self.all_tool_inputs.append("".join(self.tool_input_buffer))
                frame = self._close_block_frame()
                if frame:
                    yield frame
                self.current_tool_use = None
                self.tool_use_id = None
                self.tool_name = None
                self.tool_input_buffer = []

        # 4. Assistant Response End (assistantResponseEnd)
        elif event_type == "assistantResponseEnd":
            frame = self._close_block_frame()
            if frame:
                yield frame

    async def finish(self) -> AsyncGenerator[str, None]:
        """Flush any open block, then emit message_delta + message_stop."""
        frame = self._close_block_frame()
        if frame:
            yield frame

        full_text = "".join(self.response_buffer)
        full_tool_input = "".join(self.all_tool_inputs)
        # Rough heuristic: ~4 characters per token.
        output_tokens = max(1, (len(full_text) + len(full_tool_input)) // 4)

        yield build_message_stop(self.input_tokens, output_tokens, "end_turn")

+ 20 - 0
claude_types.py

@@ -0,0 +1,20 @@
+from typing import List, Optional, Union, Dict, Any, Literal
+from pydantic import BaseModel, Field
+
class ClaudeMessage(BaseModel):
    """A single turn in an Anthropic Messages API conversation."""
    # Speaker role; presumably "user" or "assistant" — not validated here.
    role: str
    # Either a plain string or a list of content-block dicts
    # (e.g. {"type": "text", ...}, {"type": "tool_use", ...}).
    content: Union[str, List[Dict[str, Any]]]
+
class ClaudeTool(BaseModel):
    """A tool definition supplied in an Anthropic Messages API request."""
    # Tool name referenced by the model in tool_use blocks.
    name: str
    # Human-readable description of the tool; may be empty.
    description: Optional[str] = ""
    # JSON Schema describing the tool's expected input object.
    input_schema: Dict[str, Any]
+
class ClaudeRequest(BaseModel):
    """Top-level request body of the Anthropic Messages API."""
    # Model identifier as sent by the client (mapped downstream).
    model: str
    # Conversation turns, oldest first.
    messages: List[ClaudeMessage]
    # Upper bound on generated tokens.
    max_tokens: int = 4096
    temperature: Optional[float] = None
    tools: Optional[List[ClaudeTool]] = None
    # When True the response is served as an SSE stream.
    stream: bool = False
    # System prompt: a plain string or a list of {"type": "text", ...} blocks.
    system: Optional[Union[str, List[Dict[str, Any]]]] = None
+ 77 - 31
replicate.py

@@ -5,6 +5,24 @@ from pathlib import Path
 from typing import Dict, Optional, Tuple, Iterator, List, AsyncGenerator, Any
 import struct
 import httpx
+import importlib.util
+
def _load_claude_parser():
    """Dynamically load claude_parser module."""
    parser_file = Path(__file__).resolve().parent / "claude_parser.py"
    spec = importlib.util.spec_from_file_location("v2_claude_parser", str(parser_file))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
+
try:
    _parser = _load_claude_parser()
    EventStreamParser, extract_event_info = _parser.EventStreamParser, _parser.extract_event_info
except Exception as e:
    # Any failure (missing file, import error, missing attribute) degrades
    # to the legacy parser path guarded by `if EventStreamParser` checks.
    print(f"Warning: Failed to load claude_parser: {e}")
    EventStreamParser = extract_event_info = None
 
 class StreamTracker:
     def __init__(self):
@@ -198,17 +216,28 @@ async def send_chat_request(
     model: Optional[str] = None,
     stream: bool = False,
     timeout: Tuple[int,int] = (15,300),
-    client: Optional[httpx.AsyncClient] = None
-) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker]:
+    client: Optional[httpx.AsyncClient] = None,
+    raw_payload: Optional[Dict[str, Any]] = None
+) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker, Optional[AsyncGenerator[Any, None]]]:
     url, headers_from_log, body_json = load_template()
     headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
-    try:
-        body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
-    except Exception:
-        pass
-    history_text = openai_messages_to_text(messages)
-    inject_history(body_json, history_text)
-    inject_model(body_json, model)
+    
+    if raw_payload:
+        # Use raw payload if provided (for Claude API)
+        body_json = raw_payload
+        # Ensure conversationId is set if missing
+        if "conversationState" in body_json and "conversationId" not in body_json["conversationState"]:
+             body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
+    else:
+        # Standard OpenAI-compatible logic
+        try:
+            body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
+        except Exception:
+            pass
+        history_text = openai_messages_to_text(messages)
+        inject_history(body_json, history_text)
+        inject_model(body_json, model)
+
     payload_str = json.dumps(body_json, ensure_ascii=False)
     headers = _merge_headers(headers_from_log, access_token)
     
@@ -246,37 +275,54 @@ async def send_chat_request(
         parser = AwsEventStreamParser()
         tracker = StreamTracker()
         
-        async def _iter_text() -> AsyncGenerator[str, None]:
+        async def _iter_events() -> AsyncGenerator[Any, None]:
             try:
-                async for chunk in resp.aiter_bytes():
-                    if not chunk:
-                        continue
-                    events = parser.feed(chunk)
-                    for _ev_headers, payload in events:
-                        parsed = _try_decode_event_payload(payload)
-                        if parsed is not None:
-                            text = _extract_text_from_event(parsed)
-                            if isinstance(text, str) and text:
-                                yield text
-                        else:
-                            try:
-                                txt = payload.decode("utf-8", errors="ignore")
-                                if txt:
-                                    yield txt
-                            except Exception:
-                                pass
+                if EventStreamParser and extract_event_info:
+                    # Use proper EventStreamParser
+                    async def byte_gen():
+                        async for chunk in resp.aiter_bytes():
+                            if chunk:
+                                yield chunk
+                    
+                    async for message in EventStreamParser.parse_stream(byte_gen()):
+                        event_info = extract_event_info(message)
+                        if event_info:
+                            event_type = event_info.get('event_type')
+                            payload = event_info.get('payload')
+                            if event_type and payload:
+                                yield (event_type, payload)
+                else:
+                    # Fallback to old parser
+                    async for chunk in resp.aiter_bytes():
+                        if not chunk:
+                            continue
+                        events = parser.feed(chunk)
+                        for ev_headers, payload in events:
+                            parsed = _try_decode_event_payload(payload)
+                            if parsed is not None:
+                                event_type = None
+                                if ":event-type" in ev_headers:
+                                    event_type = ev_headers[":event-type"]
+                                yield (event_type, parsed)
             except Exception:
-                # If we have already yielded content, suppress the error to allow partial success.
-                # If no content has been yielded yet (tracker.has_content is False), re-raise.
                 if not tracker.has_content:
                     raise
             finally:
                 await resp.aclose()
                 if local_client:
                     await client.aclose()
+
+        async def _iter_text() -> AsyncGenerator[str, None]:
+            async for event_type, parsed in _iter_events():
+                text = _extract_text_from_event(parsed)
+                if isinstance(text, str) and text:
+                    yield text
         
         if stream:
-            return None, tracker.track(_iter_text()), tracker
+            # If raw_payload is used, we might want the raw event stream
+            if raw_payload:
+                return None, None, tracker, _iter_events()
+            return None, tracker.track(_iter_text()), tracker, None
         else:
             buf = []
             try:
@@ -285,7 +331,7 @@ async def send_chat_request(
             finally:
                 # Ensure cleanup if not streamed
                 pass
-            return "".join(buf), None, tracker
+            return "".join(buf), None, tracker, None
 
     except Exception:
         if local_client and client: