@@ -16,6 +16,25 @@ from pydantic import BaseModel
from dotenv import load_dotenv
import httpx
import aiosqlite
+import tiktoken
+
+# ------------------------------------------------------------------------------
+# Tokenizer
+# ------------------------------------------------------------------------------
+
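+# tiktoken downloads and caches the BPE file on first use, so initialization
+# can fail offline; fall back to None and let count_tokens degrade to 0.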
+try:
+    # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
+    ENCODING = tiktoken.get_encoding("cl100k_base")
+except Exception:
+    ENCODING = None
+
+def count_tokens(text: str) -> int:
+    """Count tokens with tiktoken; returns 0 for empty text or a missing encoding."""
+    if not text or not ENCODING:
+        return 0
+    return len(ENCODING.encode(text))

# ------------------------------------------------------------------------------
# Bootstrap
@@ -476,7 +495,12 @@ async def require_account(authorization: Optional[str] = Header(default=None)) -
# OpenAI-compatible Chat endpoint
# ------------------------------------------------------------------------------

-def _openai_non_streaming_response(text: str, model: Optional[str]) -> Dict[str, Any]:
+def _openai_non_streaming_response(
+    text: str,
+    model: Optional[str],
+    prompt_tokens: int = 0,
+    completion_tokens: int = 0,
+) -> Dict[str, Any]:
    created = int(time.time())
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
@@ -494,9 +518,9 @@ def _openai_non_streaming_response(text: str, model: Optional[str]) -> Dict[str,
            }
        ],
        "usage": {
-            "prompt_tokens": None,
-            "completion_tokens": None,
-            "total_tokens": None,
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

@@ -587,7 +611,25 @@ async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(
    # Handler
-    # Estimate input tokens (simple count or 0)
-    # For now 0 or simple len
-    input_tokens = 0
+    # Calculate input tokens
+    text_to_count = ""
+    if req.system:
+        if isinstance(req.system, str):
+            text_to_count += req.system
+        elif isinstance(req.system, list):
+            for item in req.system:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    text_to_count += item.get("text", "")
+
+    for msg in req.messages:
+        if isinstance(msg.content, str):
+            text_to_count += msg.content
+        elif isinstance(msg.content, list):
+            for item in msg.content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    text_to_count += item.get("text", "")
+
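+    # Approximation: only text blocks are counted, so image/tool blocks and
+    # per-message framing overhead are not included.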
+    input_tokens = count_tokens(text_to_count)
    handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens)

    async def event_generator():
@@ -714,9 +756,23 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =

    if not do_stream:
        try:
+            # Calculate prompt tokens
+            prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
+            prompt_tokens = count_tokens(prompt_text)
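+            # Note: only plain-string message contents are counted; list-style
+            # content parts are skipped by the join above.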
+
            text, _, tracker = await _send_upstream(stream=False)
            await _update_stats(account["id"], bool(text))
-            return JSONResponse(content=_openai_non_streaming_response(text or "", model))
+
+            completion_tokens = count_tokens(text or "")
+
+            return JSONResponse(content=_openai_non_streaming_response(
+                text or "",
+                model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens
+            ))
        except Exception as e:
            await _update_stats(account["id"], False)
            raise
@@ -727,15 +783,18 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =

    it = None
    try:
+        # Calculate prompt tokens
+        prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
+        prompt_tokens = count_tokens(prompt_text)
+
        _, it, tracker = await _send_upstream(stream=True)
        assert it is not None
-        first_piece = await it.__anext__()
-        if not first_piece:
-            await _update_stats(account["id"], False)
-            raise HTTPException(status_code=502, detail="No content from upstream")

    async def event_gen() -> AsyncGenerator[str, None]:
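+        # Accumulate streamed text so usage can be computed for the final chunk.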
+        completion_text = ""
        try:
+            # Send role first
            yield _sse_format({
                "id": stream_id,
                "object": "chat.completion.chunk",
@@ -743,15 +802,11 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
                "model": model_used,
                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
            })
-            yield _sse_format({
-                "id": stream_id,
-                "object": "chat.completion.chunk",
-                "created": created,
-                "model": model_used,
-                "choices": [{"index": 0, "delta": {"content": first_piece}, "finish_reason": None}],
-            })
+
+            # Stream content
            async for piece in it:
                if piece:
+                    completion_text += piece
                    yield _sse_format({
                        "id": stream_id,
                        "object": "chat.completion.chunk",
@@ -759,13 +814,24 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
                        "model": model_used,
                        "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
                    })
+
+            # Send stop and usage
+            completion_tokens = count_tokens(completion_text)
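+            # OpenAI emits usage in streams only when stream_options.include_usage
+            # is set; this proxy attaches it to the final chunk unconditionally.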
            yield _sse_format({
                "id": stream_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_used,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": prompt_tokens + completion_tokens,
+                }
            })
+
            yield "data: [DONE]\n\n"
            await _update_stats(account["id"], True)
        except GeneratorExit: