
fix: Implement accurate token counting using tiktoken

CassiopeiaCode 2 months ago
parent
commit
f073e032d7
3 changed files with 101 additions and 19 deletions
  1. app.py  +76 -17
  2. claude_stream.py  +23 -1
  3. requirements.txt  +2 -1

+ 76 - 17
app.py

@@ -16,6 +16,23 @@ from pydantic import BaseModel
 from dotenv import load_dotenv
 import httpx
 import aiosqlite
+import tiktoken
+
+# ------------------------------------------------------------------------------
+# Tokenizer
+# ------------------------------------------------------------------------------
+
+try:
+    # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
+    ENCODING = tiktoken.get_encoding("cl100k_base")
+except Exception:
+    ENCODING = None
+
+def count_tokens(text: str) -> int:
+    """Counts tokens with tiktoken."""
+    if not text or not ENCODING:
+        return 0
+    return len(ENCODING.encode(text))
 
 # ------------------------------------------------------------------------------
 # Bootstrap
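
A minimal sketch of the helper's behavior, runnable on its own (the sample strings are illustrative; counts assume the cl100k_base vocabulary loaded successfully):

    import tiktoken

    ENCODING = tiktoken.get_encoding("cl100k_base")

    def count_tokens(text: str) -> int:
        # Mirrors the helper above: empty input or a missing encoding yields 0.
        if not text or not ENCODING:
            return 0
        return len(ENCODING.encode(text))

    print(count_tokens("Hello, world!"))  # 4 tokens under cl100k_base
    print(count_tokens(""))               # 0: empty input short-circuits
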
@@ -476,7 +493,12 @@ async def require_account(authorization: Optional[str] = Header(default=None)) -
 # OpenAI-compatible Chat endpoint
 # ------------------------------------------------------------------------------
 
-def _openai_non_streaming_response(text: str, model: Optional[str]) -> Dict[str, Any]:
+def _openai_non_streaming_response(
+    text: str,
+    model: Optional[str],
+    prompt_tokens: int = 0,
+    completion_tokens: int = 0,
+) -> Dict[str, Any]:
     created = int(time.time())
     return {
         "id": f"chatcmpl-{uuid.uuid4()}",
@@ -494,9 +516,9 @@ def _openai_non_streaming_response(text: str, model: Optional[str]) -> Dict[str,
             }
         ],
         "usage": {
-            "prompt_tokens": None,
-            "completion_tokens": None,
-            "total_tokens": None,
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
         },
     }
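
With the two new parameters supplied, the usage block follows the OpenAI schema. For example (token numbers are illustrative):

    body = _openai_non_streaming_response(
        "Hi there!", "gpt-4", prompt_tokens=12, completion_tokens=34
    )
    print(body["usage"])
    # {'prompt_tokens': 12, 'completion_tokens': 34, 'total_tokens': 46}
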
 
@@ -587,7 +609,25 @@ async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(
         # Handler
         # Estimate input tokens (simple count or 0)
         # For now 0 or simple len
-        input_tokens = 0
+        # Calculate input tokens
+        text_to_count = ""
+        if req.system:
+            if isinstance(req.system, str):
+                text_to_count += req.system
+            elif isinstance(req.system, list):
+                for item in req.system:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        text_to_count += item.get("text", "")
+        
+        for msg in req.messages:
+            if isinstance(msg.content, str):
+                text_to_count += msg.content
+            elif isinstance(msg.content, list):
+                for item in msg.content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        text_to_count += item.get("text", "")
+
+        input_tokens = count_tokens(text_to_count)
         handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens)
 
         async def event_generator():
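
The block above concatenates only parts of type "text"; tool_use/tool_result blocks and non-text content contribute nothing to the count. The same flattening as a standalone sketch, using plain dicts in place of the pydantic message objects (sample payload is illustrative):

    system = [{"type": "text", "text": "You are terse."}]
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": [
            {"type": "text", "text": "Hello!"},
            {"type": "tool_use", "id": "t1", "name": "search", "input": {}},  # skipped
        ]},
    ]

    text = "".join(p["text"] for p in system if p.get("type") == "text")
    for m in messages:
        c = m["content"]
        if isinstance(c, str):
            text += c
        else:
            text += "".join(p.get("text", "") for p in c if p.get("type") == "text")

    input_tokens = count_tokens(text)  # count_tokens as defined in app.py
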
@@ -714,9 +754,21 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
 
     if not do_stream:
         try:
+            # Calculate prompt tokens
+            prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
+            prompt_tokens = count_tokens(prompt_text)
+
             text, _, tracker = await _send_upstream(stream=False)
             await _update_stats(account["id"], bool(text))
-            return JSONResponse(content=_openai_non_streaming_response(text or "", model))
+            
+            completion_tokens = count_tokens(text or "")
+            
+            return JSONResponse(content=_openai_non_streaming_response(
+                text or "",
+                model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens
+            ))
         except Exception as e:
             await _update_stats(account["id"], False)
             raise
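
On this path prompt_tokens covers only messages whose content is a plain string; list-style multi-part content is skipped by the isinstance filter. A client-side check of the result (the route and auth header are assumptions, since the decorator sits outside this hunk):

    import httpx

    resp = httpx.post(
        "http://localhost:8000/v1/chat/completions",  # assumed route
        headers={"Authorization": "Bearer <token>"},
        json={"model": "gpt-4", "stream": False,
              "messages": [{"role": "user", "content": "Hi"}]},
    )
    print(resp.json()["usage"])
    # e.g. {'prompt_tokens': 1, 'completion_tokens': 9, 'total_tokens': 10}
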
@@ -727,15 +779,17 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
         
         it = None
         try:
+            # Calculate prompt tokens
+            prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
+            prompt_tokens = count_tokens(prompt_text)
+
             _, it, tracker = await _send_upstream(stream=True)
             assert it is not None
-            first_piece = await it.__anext__()
-            if not first_piece:
-                await _update_stats(account["id"], False)
-                raise HTTPException(status_code=502, detail="No content from upstream")
             
             async def event_gen() -> AsyncGenerator[str, None]:
+                completion_text = ""
                 try:
+                    # Send role first
                     yield _sse_format({
                         "id": stream_id,
                         "object": "chat.completion.chunk",
@@ -743,15 +797,11 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
                         "model": model_used,
                         "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                     })
-                    yield _sse_format({
-                        "id": stream_id,
-                        "object": "chat.completion.chunk",
-                        "created": created,
-                        "model": model_used,
-                        "choices": [{"index": 0, "delta": {"content": first_piece}, "finish_reason": None}],
-                    })
+                    
+                    # Stream content
                     async for piece in it:
                         if piece:
+                            completion_text += piece
                             yield _sse_format({
                                 "id": stream_id,
                                 "object": "chat.completion.chunk",
@@ -759,13 +809,22 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
                                 "model": model_used,
                                 "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
                             })
+                    
+                    # Send stop and usage
+                    completion_tokens = count_tokens(completion_text)
                     yield _sse_format({
                         "id": stream_id,
                         "object": "chat.completion.chunk",
                         "created": created,
                         "model": model_used,
                         "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                        "usage": {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": prompt_tokens + completion_tokens,
+                        }
                     })
+                    
                     yield "data: [DONE]\n\n"
                     await _update_stats(account["id"], True)
                 except GeneratorExit:

+ 23 - 1
claude_stream.py

@@ -3,9 +3,30 @@ import logging
 import importlib.util
 from pathlib import Path
 from typing import AsyncGenerator, Optional, Dict, Any, List, Set
+import tiktoken
 
 logger = logging.getLogger(__name__)
 
+# ------------------------------------------------------------------------------
+# Tokenizer
+# ------------------------------------------------------------------------------
+
+try:
+    # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
+    ENCODING = tiktoken.get_encoding("cl100k_base")
+except Exception:
+    ENCODING = None
+
+def count_tokens(text: str) -> int:
+    """Counts tokens with tiktoken."""
+    if not text or not ENCODING:
+        return 0
+    return len(ENCODING.encode(text))
+
+# ------------------------------------------------------------------------------
+# Dynamic Loader
+# ------------------------------------------------------------------------------
+
 def _load_claude_parser():
     """Dynamically load claude_parser module."""
     base_dir = Path(__file__).resolve().parent
@@ -159,6 +180,7 @@ class ClaudeStreamHandler:
         full_text = "".join(self.response_buffer)
         full_tool_input = "".join(self.all_tool_inputs)
         # Simple approximation: 4 chars per token
-        output_tokens = max(1, (len(full_text) + len(full_tool_input)) // 4)
+        # output_tokens = max(1, (len(full_text) + len(full_tool_input)) // 4)
+        output_tokens = count_tokens(full_text) + count_tokens(full_tool_input)
 
         yield build_message_stop(self.input_tokens, output_tokens, "end_turn")
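
Note that cl100k_base is an OpenAI vocabulary, not Anthropic's, so the counts reported for Claude-style output are still approximations; they are just much closer than the 4-characters-per-token rule they replace. A quick comparison sketch:

    text = "Streaming responses often contain code, which tokenizes densely."
    heuristic = max(1, len(text) // 4)  # old estimate: 64 chars -> 16
    accurate = count_tokens(text)       # tiktoken count; typically lower for prose
    print(heuristic, accurate)
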

+ 2 - 1
requirements.txt

@@ -3,4 +3,5 @@ uvicorn[standard]==0.32.0
 pydantic==2.9.2
 python-dotenv==1.0.1
 httpx==0.28.1
-aiosqlite==0.20.0
+aiosqlite==0.20.0
+tiktoken
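
tiktoken ships prebuilt wheels for common platforms, and the first get_encoding call downloads and caches the BPE file, so it needs network access once. A quick smoke test after installing:

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.decode(enc.encode("hello world")) == "hello world"
    print(len(enc.encode("hello world")))  # 2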