Implement forced connection recycling to avoid dead connections

- Add get_global_client() to fetch the client instance dynamically (see the sketch after this message)
- Add a _recycle_global_client() background task that recycles the client every 3 minutes
- Use a smooth-transition strategy: create the new client first, wait 60 seconds, then close the old one
- Switch every call site to the dynamic getter so new requests use the new client
- Add the logging module to record recycle activity
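
A minimal sketch of the call-site pattern this commit adopts (standalone demo; fetch_status and main are hypothetical and the httpx setup is simplified relative to app.py): callers resolve the client through get_global_client() on every request, so rebinding the module-level GLOBAL_CLIENT during a recycle is picked up by the next request, while in-flight requests keep their old instance until it is closed.

    import asyncio
    from typing import Optional

    import httpx

    GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None

    def get_global_client() -> Optional[httpx.AsyncClient]:
        # Always read the module-level name, so callers see the client created by
        # the most recent recycle rather than a reference captured earlier.
        return GLOBAL_CLIENT

    async def fetch_status(url: str) -> int:
        # Hypothetical call site: resolve the client per request instead of caching it.
        client = get_global_client()
        if client is None:
            # Safety fallback, mirroring the commit's pattern: use a short-lived client.
            async with httpx.AsyncClient(timeout=10.0) as tmp:
                return (await tmp.get(url)).status_code
        return (await client.get(url)).status_code

    async def main() -> None:
        global GLOBAL_CLIENT
        GLOBAL_CLIENT = httpx.AsyncClient(timeout=10.0)
        try:
            print(await fetch_status("https://example.com"))
        finally:
            await GLOBAL_CLIENT.aclose()

    if __name__ == "__main__":
        asyncio.run(main())

The 60-second overlap in the real task exists because requests that already obtained the old client keep using it; closing it immediately would break their in-flight streams.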
CassiopeiaCode, 4 weeks ago
parent commit 8e0c4873bd
1 changed file with 80 additions and 4 deletions

app.py  +80 -4

@@ -8,6 +8,7 @@ import importlib.util
 import random
 import secrets
 import re
+import logging
 from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
@@ -22,6 +23,16 @@ import tiktoken
 
 from db import init_db, close_db, row_to_dict
 
+# ------------------------------------------------------------------------------
+# Logger
+# ------------------------------------------------------------------------------
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
 # ------------------------------------------------------------------------------
 # Tokenizer
 # ------------------------------------------------------------------------------
@@ -171,12 +182,76 @@ async def _init_global_client():
     )
     GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
 
 async def _close_global_client():
     global GLOBAL_CLIENT
     if GLOBAL_CLIENT:
         await GLOBAL_CLIENT.aclose()
         GLOBAL_CLIENT = None
 
+def get_global_client() -> Optional[httpx.AsyncClient]:
+    """Return the current global client instance (looked up dynamically so callers always get the latest one)."""
+    return GLOBAL_CLIENT
+
+async def _recycle_global_client():
+    """定期回收并重建全局HTTP客户端,避免死连接累积
+    
+    策略:先创建新客户端,等待1分钟后再关闭旧客户端,确保平滑过渡
+    """
+    while True:
+        try:
+            await asyncio.sleep(180)  # recycle every 3 minutes
+            logger.info("[client recycle] Recycling the global HTTP client...")
+            
+            # Keep a reference to the old client
+            global GLOBAL_CLIENT
+            old_client = GLOBAL_CLIENT
+            
+            # Create the new client
+            proxies = _get_proxies()
+            mounts = None
+            if proxies:
+                proxy_url = proxies.get("https") or proxies.get("http")
+                if proxy_url:
+                    mounts = {
+                        "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                        "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                    }
+            
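+            # The pool/timeout values below are assumed to mirror _init_global_client();
+            # keepalive_expiry is in seconds, so idle connections are dropped quickly
+            # instead of lingering and going stale.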
+            limits = httpx.Limits(
+                max_keepalive_connections=500,
+                max_connections=500,
+                keepalive_expiry=1.0
+            )
+            timeout = httpx.Timeout(
+                connect=2.0,
+                read=300.0,
+                write=2.0,
+                pool=1.0
+            )
+            
+            # Swap in the new client
+            GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
+            logger.info("[client recycle] New client created; closing the old one in 60 seconds...")
+
+            # Wait one minute so requests still using the old client can finish
+            await asyncio.sleep(60)
+            
+            # Close the old client
+            if old_client:
+                try:
+                    await old_client.aclose()
+                    logger.info("[client recycle] Old client closed; recycle complete")
+                except Exception as e:
+                    logger.warning(f"[client recycle] Failed to close the old client: {e}")
+            
+        except Exception as e:
+            logger.error(f"[client recycle] Recycle failed: {e}")
+            traceback.print_exc()
+            # Make sure a usable client still exists
+            try:
+                if GLOBAL_CLIENT is None:
+                    await _init_global_client()
+            except Exception:
+                pass
+
 # ------------------------------------------------------------------------------
 # Database helpers
 # ------------------------------------------------------------------------------
@@ -420,7 +495,7 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
 
     try:
         # Use global client if available, else fallback (though global should be ready)
-        client = GLOBAL_CLIENT
+        client = get_global_client()
         if not client:
             # Fallback for safety
             async with httpx.AsyncClient(timeout=60.0) as temp_client:
@@ -630,7 +705,7 @@ async def claude_messages(
             messages=[],
             model=map_model_name(req.model),
             stream=True,
-            client=GLOBAL_CLIENT,
+            client=get_global_client(),
             raw_payload=aq_request
         )
 
@@ -877,7 +952,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
         # Note: send_chat_request's signature changed, but we pass keyword args, so omitting raw_payload is fine.
         # Its return signature also changed: it now returns 4 values,
         # so unpack all four and drop the last one below.
-        result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
+        result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=get_global_client())
         return result[0], result[1], result[2] # Ignore the 4th value (event_stream) for OpenAI endpoint
 
     if not do_stream:
@@ -1428,6 +1503,7 @@ async def startup_event():
     await _init_global_client()
     await _ensure_db()
     asyncio.create_task(_refresh_stale_tokens())
+    asyncio.create_task(_recycle_global_client())  # start the connection-recycle task
     # asyncio.create_task(_verify_disabled_accounts_loop())
 
 @app.on_event("shutdown")