
fix: optimize concurrency with SQLite WAL and global HTTP client, fix stream closed error

CassiopeiaCode 2 months ago
Parent
Commit
61a15d53dd
2 changed files with 122 additions and 53 deletions
  1. app.py (+55 -22)
  2. replicate.py (+67 -31)
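
The pattern this commit introduces, in brief: one pooled httpx.AsyncClient shared for the lifetime of the process instead of a new client per request, and SQLite switched to WAL so readers are not blocked by the writer. The sketch below illustrates that combination under stated assumptions (FastAPI on_event hooks as in app.py; the database path and any names other than GLOBAL_CLIENT are hypothetical); it is not the committed code itself:

from typing import Optional

import aiosqlite
import httpx
from fastapi import FastAPI

app = FastAPI()
GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None  # shared by all request handlers

@app.on_event("startup")
async def startup_event():
    global GLOBAL_CLIENT
    # One pooled client for the whole process; pool limits sized for high concurrency.
    GLOBAL_CLIENT = httpx.AsyncClient(
        timeout=60.0,
        limits=httpx.Limits(max_keepalive_connections=100, max_connections=200),
    )
    async with aiosqlite.connect("accounts.db") as conn:  # path is an assumption
        await conn.execute("PRAGMA journal_mode=WAL;")  # persists; readers stop blocking on the writer
        await conn.commit()

@app.on_event("shutdown")
async def shutdown_event():
    if GLOBAL_CLIENT:
        await GLOBAL_CLIENT.aclose()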

app.py  (+55 -22)

@@ -51,6 +51,39 @@ def _load_replicate_module():
 _replicate = _load_replicate_module()
 send_chat_request = _replicate.send_chat_request
 
+# ------------------------------------------------------------------------------
+# Global HTTP Client
+# ------------------------------------------------------------------------------
+
+GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
+
+def _get_proxies() -> Optional[Dict[str, str]]:
+    proxy = os.getenv("HTTP_PROXY", "").strip()
+    if proxy:
+        return {"http": proxy, "https": proxy}
+    return None
+
+async def _init_global_client():
+    global GLOBAL_CLIENT
+    proxies = _get_proxies()
+    mounts = None
+    if proxies:
+        proxy_url = proxies.get("https") or proxies.get("http")
+        if proxy_url:
+            mounts = {
+                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+            }
+    # Increased limits for high concurrency
+    limits = httpx.Limits(max_keepalive_connections=100, max_connections=200)
+    GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=60.0, limits=limits)
+
+async def _close_global_client():
+    global GLOBAL_CLIENT
+    if GLOBAL_CLIENT:
+        await GLOBAL_CLIENT.aclose()
+        GLOBAL_CLIENT = None
+
 # ------------------------------------------------------------------------------
 # SQLite helpers
 # ------------------------------------------------------------------------------
@@ -58,6 +91,7 @@ send_chat_request = _replicate.send_chat_request
 async def _ensure_db():
     DB_PATH.parent.mkdir(parents=True, exist_ok=True)
     async with aiosqlite.connect(DB_PATH) as conn:
+        await conn.execute("PRAGMA journal_mode=WAL;")
         await conn.execute(
             """
             CREATE TABLE IF NOT EXISTS accounts (
@@ -236,12 +270,6 @@ class ChatCompletionRequest(BaseModel):
 OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
 TOKEN_URL = f"{OIDC_BASE}/token"
 
-def _get_proxies() -> Optional[Dict[str, str]]:
-    proxy = os.getenv("HTTP_PROXY", "").strip()
-    if proxy:
-        return {"http": proxy, "https": proxy}
-    return None
-
 def _oidc_headers() -> Dict[str, str]:
     return {
         "content-type": "application/json",
@@ -271,23 +299,23 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
         }
 
         try:
-            proxies = _get_proxies()
-            mounts = None
-            if proxies:
-                proxy_url = proxies.get("https") or proxies.get("http")
-                if proxy_url:
-                    mounts = {
-                        "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
-                        "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
-                    }
-            async with httpx.AsyncClient(mounts=mounts, timeout=60.0) as client:
+            # Use global client if available, else fallback (though global should be ready)
+            client = GLOBAL_CLIENT
+            if not client:
+                # Fallback for safety
+                async with httpx.AsyncClient(timeout=60.0) as temp_client:
+                    r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
+                    r.raise_for_status()
+                    data = r.json()
+            else:
                 r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
                 r.raise_for_status()
                 data = r.json()
-                new_access = data.get("accessToken")
-                new_refresh = data.get("refreshToken", acc.get("refreshToken"))
-                now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-                status = "success"
+
+            new_access = data.get("accessToken")
+            new_refresh = data.get("refreshToken", acc.get("refreshToken"))
+            now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+            status = "success"
         except httpx.HTTPError as e:
             now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
             status = "failed"
@@ -414,7 +442,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
             access = refreshed.get("accessToken")
             if not access:
                 raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
-        return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
+        return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
 
     if not do_stream:
         try:
@@ -772,5 +800,10 @@ async def health():
 @app.on_event("startup")
 async def startup_event():
     """Initialize database and start background tasks on startup."""
+    await _init_global_client()
     await _ensure_db()
-    asyncio.create_task(_refresh_stale_tokens())
+    asyncio.create_task(_refresh_stale_tokens())
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    await _close_global_client()
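
A quick way to confirm the WAL change took effect is to query the journal mode after _ensure_db has run. This helper is hypothetical (not part of the commit), and the database path is an assumption:

import asyncio

import aiosqlite

async def check_journal_mode(db_path: str) -> str:
    # SQLite reports the active journal mode; "wal" is expected once the
    # PRAGMA in _ensure_db has been applied (WAL persists across connections).
    async with aiosqlite.connect(db_path) as conn:
        cursor = await conn.execute("PRAGMA journal_mode;")
        row = await cursor.fetchone()
        return row[0]

if __name__ == "__main__":
    print(asyncio.run(check_journal_mode("accounts.db")))  # path is an assumption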

replicate.py  (+67 -31)

@@ -192,7 +192,14 @@ def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
     except Exception:
         pass
 
-async def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker]:
+async def send_chat_request(
+    access_token: str,
+    messages: List[Dict[str, Any]],
+    model: Optional[str] = None,
+    stream: bool = False,
+    timeout: Tuple[int,int] = (15,300),
+    client: Optional[httpx.AsyncClient] = None
+) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker]:
     url, headers_from_log, body_json = load_template()
     headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
     try:
@@ -205,31 +212,42 @@ async def send_chat_request(access_token: str, messages: List[Dict[str, Any]], m
     payload_str = json.dumps(body_json, ensure_ascii=False)
     headers = _merge_headers(headers_from_log, access_token)
     
-    # Build mounts with proxy if available
-    proxies = _get_proxies()
-    mounts = None
-    if proxies:
-        proxy_url = proxies.get("https") or proxies.get("http")
-        if proxy_url:
-            mounts = {
-                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
-                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
-            }
+    local_client = False
+    if client is None:
+        local_client = True
+        proxies = _get_proxies()
+        mounts = None
+        if proxies:
+            proxy_url = proxies.get("https") or proxies.get("http")
+            if proxy_url:
+                mounts = {
+                    "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                    "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                }
+        client = httpx.AsyncClient(mounts=mounts, timeout=httpx.Timeout(timeout[0], read=timeout[1]))
     
-    async with httpx.AsyncClient(mounts=mounts, timeout=httpx.Timeout(timeout[0], read=timeout[1])) as client:
-        async with client.stream("POST", url, headers=headers, content=payload_str) as resp:
-            if resp.status_code >= 400:
-                try:
-                    err = await resp.aread()
-                    err = err.decode("utf-8", errors="ignore")
-                except Exception:
-                    err = f"HTTP {resp.status_code}"
-                raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")
-            
-            parser = AwsEventStreamParser()
-            tracker = StreamTracker()
-            
-            async def _iter_text() -> AsyncGenerator[str, None]:
+    # Use manual request sending to control stream lifetime
+    req = client.build_request("POST", url, headers=headers, content=payload_str)
+    
+    try:
+        resp = await client.send(req, stream=True)
+        
+        if resp.status_code >= 400:
+            try:
+                await resp.aread()
+                err = resp.text
+            except Exception:
+                err = f"HTTP {resp.status_code}"
+            await resp.aclose()
+            if local_client:
+                await client.aclose()
+            raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")
+        
+        parser = AwsEventStreamParser()
+        tracker = StreamTracker()
+        
+        async def _iter_text() -> AsyncGenerator[str, None]:
+            try:
                 async for chunk in resp.aiter_bytes():
                     if not chunk:
                         continue
@@ -247,11 +265,29 @@ async def send_chat_request(access_token: str, messages: List[Dict[str, Any]], m
                                     yield txt
                             except Exception:
                                 pass
-            
-            if stream:
-                return None, tracker.track(_iter_text()), tracker
-            else:
-                buf = []
+            except Exception:
+                # If we have already yielded content, suppress the error to allow partial success.
+                # If no content has been yielded yet (tracker.has_content is False), re-raise.
+                if not tracker.has_content:
+                    raise
+            finally:
+                await resp.aclose()
+                if local_client:
+                    await client.aclose()
+        
+        if stream:
+            return None, tracker.track(_iter_text()), tracker
+        else:
+            buf = []
+            try:
                 async for t in tracker.track(_iter_text()):
                     buf.append(t)
-                return "".join(buf), None, tracker
+            finally:
+                # Nothing extra to close here: _iter_text's finally block already
+                # closes resp (and the client when it was created locally).
+                pass
+            return "".join(buf), None, tracker
+
+    except Exception:
+        if local_client and client:
+            await client.aclose()
+        raise
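
Regarding the "stream closed" fix above: the upstream response is now closed in _iter_text's finally block, i.e. only after the caller has finished iterating, rather than when send_chat_request returns. A rough caller-side sketch of how the streaming branch might be consumed (the StreamingResponse wiring and media type are assumptions, not part of this commit):

from fastapi.responses import StreamingResponse

async def stream_chat(access_token, messages, model):
    # send_chat_request returns (full_text, async_generator, tracker); in the
    # streaming branch the generator owns the upstream HTTP response and closes
    # it in its own finally block once iteration completes.
    _, gen, tracker = await send_chat_request(
        access_token, messages, model=model, stream=True, client=GLOBAL_CLIENT
    )
    return StreamingResponse(gen, media_type="text/event-stream")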