|
|
@@ -8,6 +8,7 @@ import importlib.util
|
|
|
import random
|
|
|
import secrets
|
|
|
import re
|
|
|
+import logging
|
|
|
from datetime import datetime, timedelta
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
|
|
|
@@ -22,6 +23,16 @@ import tiktoken
|
|
|
|
|
|
from db import init_db, close_db, row_to_dict
|
|
|
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+# Logger
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+
|
|
|
# Configure logging once at import time: INFO level, each record rendered as
# "timestamp - logger name - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Tokenizer
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -171,12 +182,76 @@ async def _init_global_client():
|
|
|
)
|
|
|
GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
|
|
|
|
|
|
def get_global_client() -> Optional[httpx.AsyncClient]:
    """Return the current global HTTP client.

    The module-level ``GLOBAL_CLIENT`` is read at call time, so callers get
    whatever instance is bound at that moment rather than a reference
    captured earlier.

    Returns:
        The current ``httpx.AsyncClient``, or ``None`` when no client is set.
    """
    return GLOBAL_CLIENT


async def _close_global_client():
    """Close the global HTTP client, if one exists, and reset it to ``None``.

    This must remain a separate coroutine: its body awaits ``aclose()`` and
    declares ``global GLOBAL_CLIENT``, neither of which is valid after the
    ``return`` of the synchronous ``get_global_client``.
    """
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT:
        await GLOBAL_CLIENT.aclose()
        GLOBAL_CLIENT = None
|
|
|
|
|
|
def _build_recycled_client() -> httpx.AsyncClient:
    """Build a fresh ``httpx.AsyncClient`` with the recycler's proxy, pool and timeout settings."""
    proxies = _get_proxies()
    mounts = None
    if proxies:
        # Prefer the HTTPS proxy; the same proxy URL is mounted for both schemes.
        proxy_url = proxies.get("https") or proxies.get("http")
        if proxy_url:
            mounts = {
                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
            }

    limits = httpx.Limits(
        max_keepalive_connections=500,
        max_connections=500,
        keepalive_expiry=1.0,
    )
    timeout = httpx.Timeout(
        connect=2.0,
        read=300.0,
        write=2.0,
        pool=1.0,
    )
    return httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)


async def _recycle_global_client():
    """Periodically replace the global HTTP client to avoid accumulating dead connections.

    Runs forever as a background task. Each cycle:

    1. sleeps 180 seconds;
    2. builds a new client and rebinds ``GLOBAL_CLIENT`` to it;
    3. waits 60 seconds so requests still holding the old client can finish;
    4. closes the old client.

    A full cycle therefore takes about 240 seconds. The old client is closed
    in a ``finally`` clause, so it is released even if the task is cancelled
    during the grace period. Any other failure is logged with its traceback,
    after which the loop continues; if ``GLOBAL_CLIENT`` is ``None`` at that
    point, ``_init_global_client()`` is called to restore it.
    """
    global GLOBAL_CLIENT

    while True:
        try:
            await asyncio.sleep(180)  # wait between recycle rounds
            logger.info("[连接回收] 开始回收全局HTTP客户端...")

            # Keep a reference to the outgoing client so it can be closed later.
            old_client = GLOBAL_CLIENT

            # Swap in the new client first; new requests pick it up immediately.
            GLOBAL_CLIENT = _build_recycled_client()
            logger.info("[连接回收] 新客户端已创建,等待60秒后关闭旧客户端...")

            try:
                # Grace period for requests that are still using the old client.
                await asyncio.sleep(60)
            finally:
                # Also runs on cancellation, so the replaced client is not leaked.
                if old_client:
                    try:
                        await old_client.aclose()
                        logger.info("[连接回收] 旧客户端已关闭,回收完成")
                    except Exception as e:
                        logger.warning("[连接回收] 关闭旧客户端时出错: %s", e)

        except Exception as e:
            # logger.exception records the message and the traceback together.
            logger.exception("[连接回收] 回收失败: %s", e)
            # Make sure a client is available after a failed round.
            try:
                if GLOBAL_CLIENT is None:
                    await _init_global_client()
            except Exception:
                # Best effort: report the failure and retry on the next round.
                logger.exception("[连接回收] 重新初始化全局HTTP客户端失败")
|
|
|
+
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Database helpers
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -420,7 +495,7 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
|
|
|
try:
|
|
|
# Use global client if available, else fallback (though global should be ready)
|
|
|
- client = GLOBAL_CLIENT
|
|
|
+ client = get_global_client()
|
|
|
if not client:
|
|
|
# Fallback for safety
|
|
|
async with httpx.AsyncClient(timeout=60.0) as temp_client:
|
|
|
@@ -630,7 +705,7 @@ async def claude_messages(
|
|
|
messages=[],
|
|
|
model=map_model_name(req.model),
|
|
|
stream=True,
|
|
|
- client=GLOBAL_CLIENT,
|
|
|
+ client=get_global_client(),
|
|
|
raw_payload=aq_request
|
|
|
)
|
|
|
|
|
|
@@ -877,7 +952,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
|
|
|
# Note: send_chat_request signature changed, but we use keyword args so it should be fine if we don't pass raw_payload
|
|
|
# But wait, the return signature changed too! It now returns 4 values.
|
|
|
# We need to unpack 4 values.
|
|
|
- result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
|
|
|
+ result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=get_global_client())
|
|
|
return result[0], result[1], result[2] # Ignore the 4th value (event_stream) for OpenAI endpoint
|
|
|
|
|
|
if not do_stream:
|
|
|
@@ -1428,6 +1503,7 @@ async def startup_event():
|
|
|
await _init_global_client()
|
|
|
await _ensure_db()
|
|
|
asyncio.create_task(_refresh_stale_tokens())
|
|
|
+ asyncio.create_task(_recycle_global_client()) # 启动连接回收任务
|
|
|
# asyncio.create_task(_verify_disabled_accounts_loop())
|
|
|
|
|
|
@app.on_event("shutdown")
|