Browse Source

feat: Add 150K token limit check with properly formatted responses

- Check token count before processing requests in both Claude and OpenAI endpoints
- Return 200 status with a formatted assistant message when the limit is exceeded (see the response sketch after this list)
- Support both streaming and non-streaming responses
- Include actual token usage in response for transparency
- Error message: 'Context too long: X tokens exceeds the 150,000 token limit. Please compress your context and retry.'
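
For reference, a minimal sketch of the non-streaming Claude-format body the handler in this diff returns when the limit is exceeded. This is illustrative only and not part of the diff: the id, model name, and token counts below are placeholders; the real handler fills them in from the request and count_tokens().

# Illustrative sketch (assumed placeholder values), mirroring the response_body built in app.py.
example_limit_response = {
    "id": "msg_<uuid4>",                  # handler generates this with uuid.uuid4()
    "type": "message",
    "role": "assistant",
    "model": "<requested model>",         # echoed from req.model (placeholder here)
    "content": [{
        "type": "text",
        "text": "Context too long: 163842 tokens exceeds the 150,000 token "
                "limit. Please compress your context and retry.",
    }],
    "stop_reason": "end_turn",
    "stop_sequence": None,
    "usage": {"input_tokens": 163842, "output_tokens": 24},  # placeholder counts
}

if __name__ == "__main__":
    import json
    print(json.dumps(example_limit_response, indent=2))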
CassiopeiaCode 3 weeks ago
parent
commit
01dcb6dc32
1 changed file with 149 additions and 8 deletions

app.py (+149, -8)

@@ -708,6 +708,95 @@ async def claude_messages(
     """
     Claude-compatible messages endpoint.
     """
+    # 0. Check token limit (150K tokens)
+    text_to_count = ""
+    if req.system:
+        if isinstance(req.system, str):
+            text_to_count += req.system
+        elif isinstance(req.system, list):
+            for item in req.system:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    text_to_count += item.get("text", "")
+
+    for msg in req.messages:
+        if isinstance(msg.content, str):
+            text_to_count += msg.content
+        elif isinstance(msg.content, list):
+            for item in msg.content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    text_to_count += item.get("text", "")
+
+    if req.tools:
+        text_to_count += json.dumps([tool.model_dump() if hasattr(tool, 'model_dump') else tool for tool in req.tools], ensure_ascii=False)
+
+    input_tokens = count_tokens(text_to_count, apply_multiplier=True)
+    
+    # Return Claude-format response if token limit exceeded
+    if input_tokens > 150000:
+        error_message = f"Context too long: {input_tokens} tokens exceeds the 150,000 token limit. Please compress your context and retry."
+        
+        if req.stream:
+            # Streaming response
+            async def error_stream():
+                # message_start event
+                yield _sse_format({
+                    "type": "message_start",
+                    "message": {
+                        "id": f"msg_{uuid.uuid4()}",
+                        "type": "message",
+                        "role": "assistant",
+                        "model": req.model,
+                        "content": [],
+                        "stop_reason": None,
+                        "stop_sequence": None,
+                        "usage": {"input_tokens": input_tokens, "output_tokens": 0}
+                    }
+                })
+                # content_block_start event
+                yield _sse_format({
+                    "type": "content_block_start",
+                    "index": 0,
+                    "content_block": {"type": "text", "text": ""}
+                })
+                # content_block_delta event
+                yield _sse_format({
+                    "type": "content_block_delta",
+                    "index": 0,
+                    "delta": {"type": "text_delta", "text": error_message}
+                })
+                # content_block_stop event
+                yield _sse_format({
+                    "type": "content_block_stop",
+                    "index": 0
+                })
+                # message_delta event
+                yield _sse_format({
+                    "type": "message_delta",
+                    "delta": {"stop_reason": "end_turn", "stop_sequence": None},
+                    "usage": {"output_tokens": count_tokens(error_message)}
+                })
+                # message_stop event
+                yield _sse_format({"type": "message_stop"})
+            
+            return StreamingResponse(error_stream(), media_type="text/event-stream")
+        else:
+            # Non-streaming response
+            output_tokens = count_tokens(error_message)
+            response_body = {
+                "id": f"msg_{uuid.uuid4()}",
+                "type": "message",
+                "role": "assistant",
+                "model": req.model,
+                "content": [{"type": "text", "text": error_message}],
+                "stop_reason": "end_turn",
+                "stop_sequence": None,
+                "usage": {
+                    "input_tokens": input_tokens,
+                    "output_tokens": output_tokens
+                }
+            }
+            return JSONResponse(content=response_body, status_code=200)
+    
     # 1. Convert request
     # Always generate a new conversation_id like amq2api does
     # Using the same conversation_id can cause Amazon Q to return cached/stale data
@@ -997,6 +1086,64 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
     - messages will be converted into "{role}:\n{content}" and injected into template
     - account is chosen randomly among enabled accounts (API key is for authorization only)
     """
+    # Check token limit (150K tokens)
+    prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
+    prompt_tokens = count_tokens(prompt_text, apply_multiplier=True)
+    
+    # Return OpenAI-format response if token limit exceeded
+    if prompt_tokens > 150000:
+        error_message = f"Context too long: {prompt_tokens} tokens exceeds the 150,000 token limit. Please compress your context and retry."
+        
+        if req.stream:
+            # Streaming response
+            created = int(time.time())
+            stream_id = f"chatcmpl-{uuid.uuid4()}"
+            model_used = req.model or "unknown"
+            
+            async def error_stream():
+                # Send role first
+                yield _sse_format({
+                    "id": stream_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": model_used,
+                    "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+                })
+                # Send error message
+                yield _sse_format({
+                    "id": stream_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": model_used,
+                    "choices": [{"index": 0, "delta": {"content": error_message}, "finish_reason": None}],
+                })
+                # Send stop and usage
+                completion_tokens = count_tokens(error_message)
+                yield _sse_format({
+                    "id": stream_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": model_used,
+                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                    "usage": {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": completion_tokens,
+                        "total_tokens": prompt_tokens + completion_tokens,
+                    }
+                })
+                yield "data: [DONE]\n\n"
+            
+            return StreamingResponse(error_stream(), media_type="text/event-stream")
+        else:
+            # Non-streaming response
+            completion_tokens = count_tokens(error_message)
+            return JSONResponse(content=_openai_non_streaming_response(
+                error_message,
+                req.model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens
+            ), status_code=200)
+    
     model = map_model_name(req.model)
     do_stream = bool(req.stream)
 
@@ -1015,10 +1162,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
 
     if not do_stream:
         try:
-            # Calculate prompt tokens
-            prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
-            prompt_tokens = count_tokens(prompt_text)
-
+            # Token count already calculated above
             text, _, tracker = await _send_upstream(stream=False)
             await _update_stats(account["id"], bool(text))
             
@@ -1040,10 +1184,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
         
         it = None
         try:
-            # Calculate prompt tokens
-            prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
-            prompt_tokens = count_tokens(prompt_text)
-
+            # Token count already calculated above
             _, it, tracker = await _send_upstream(stream=True)
             assert it is not None