|
|
@@ -51,6 +51,39 @@ def _load_replicate_module():
|
|
|
_replicate = _load_replicate_module()
|
|
|
send_chat_request = _replicate.send_chat_request
|
|
|
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+# Global HTTP Client
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+
|
|
|
+GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
|
|
|
+
|
|
|
+def _get_proxies() -> Optional[Dict[str, str]]:
|
|
|
+ proxy = os.getenv("HTTP_PROXY", "").strip()
|
|
|
+ if proxy:
|
|
|
+ return {"http": proxy, "https": proxy}
|
|
|
+ return None
|
|
|
+
|
|
|
+async def _init_global_client():
|
|
|
+ global GLOBAL_CLIENT
|
|
|
+ proxies = _get_proxies()
|
|
|
+ mounts = None
|
|
|
+ if proxies:
|
|
|
+ proxy_url = proxies.get("https") or proxies.get("http")
|
|
|
+ if proxy_url:
|
|
|
+ mounts = {
|
|
|
+ "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
+ "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
+ }
|
|
|
+ # Increased limits for high concurrency
|
|
|
+ limits = httpx.Limits(max_keepalive_connections=100, max_connections=200)
|
|
|
+ GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=60.0, limits=limits)
|
|
|
+
|
|
|
+async def _close_global_client():
|
|
|
+ global GLOBAL_CLIENT
|
|
|
+ if GLOBAL_CLIENT:
|
|
|
+ await GLOBAL_CLIENT.aclose()
|
|
|
+ GLOBAL_CLIENT = None
|
|
|
+
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# SQLite helpers
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -58,6 +91,7 @@ send_chat_request = _replicate.send_chat_request
|
|
|
async def _ensure_db():
|
|
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
|
|
async with aiosqlite.connect(DB_PATH) as conn:
|
|
|
+ await conn.execute("PRAGMA journal_mode=WAL;")
|
|
|
await conn.execute(
|
|
|
"""
|
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
|
|
@@ -236,12 +270,6 @@ class ChatCompletionRequest(BaseModel):
|
|
|
OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
|
|
|
TOKEN_URL = f"{OIDC_BASE}/token"
|
|
|
|
|
|
-def _get_proxies() -> Optional[Dict[str, str]]:
|
|
|
- proxy = os.getenv("HTTP_PROXY", "").strip()
|
|
|
- if proxy:
|
|
|
- return {"http": proxy, "https": proxy}
|
|
|
- return None
|
|
|
-
|
|
|
def _oidc_headers() -> Dict[str, str]:
|
|
|
return {
|
|
|
"content-type": "application/json",
|
|
|
@@ -271,23 +299,23 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
- proxies = _get_proxies()
|
|
|
- mounts = None
|
|
|
- if proxies:
|
|
|
- proxy_url = proxies.get("https") or proxies.get("http")
|
|
|
- if proxy_url:
|
|
|
- mounts = {
|
|
|
- "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
- "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
- }
|
|
|
- async with httpx.AsyncClient(mounts=mounts, timeout=60.0) as client:
|
|
|
+ # Use global client if available, else fallback (though global should be ready)
|
|
|
+ client = GLOBAL_CLIENT
|
|
|
+ if not client:
|
|
|
+ # Fallback for safety
|
|
|
+ async with httpx.AsyncClient(timeout=60.0) as temp_client:
|
|
|
+ r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
+ else:
|
|
|
r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
r.raise_for_status()
|
|
|
data = r.json()
|
|
|
- new_access = data.get("accessToken")
|
|
|
- new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
- now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
- status = "success"
|
|
|
+
|
|
|
+ new_access = data.get("accessToken")
|
|
|
+ new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
+ now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
+ status = "success"
|
|
|
except httpx.HTTPError as e:
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
status = "failed"
|
|
|
@@ -414,7 +442,7 @@ async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] =
|
|
|
access = refreshed.get("accessToken")
|
|
|
if not access:
|
|
|
raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
|
|
|
- return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
|
|
|
+ return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
|
|
|
|
|
|
if not do_stream:
|
|
|
try:
|
|
|
@@ -772,5 +800,10 @@ async def health():
|
|
|
@app.on_event("startup")
|
|
|
async def startup_event():
|
|
|
"""Initialize database and start background tasks on startup."""
|
|
|
+ await _init_global_client()
|
|
|
await _ensure_db()
|
|
|
- asyncio.create_task(_refresh_stale_tokens())
|
|
|
+ asyncio.create_task(_refresh_stale_tokens())
|
|
|
+
|
|
|
[email protected]_event("shutdown")
|
|
|
+async def shutdown_event():
|
|
|
+ await _close_global_client()
|