Forráskód Böngészése

feat: 完成异步高并发改造 - 替换requests为httpx,sqlite3为aiosqlite

CassiopeiaCode 2 hónapja
szülő
commit
fe324554d9
4 módosított fájl, 311 hozzáadás és 226 törlés
  1. 192 155
      app.py
  2. 62 30
      auth_flow.py
  3. 55 40
      replicate.py
  4. 2 1
      requirements.txt

+ 192 - 155
app.py

@@ -3,19 +3,19 @@ import json
 import traceback
 import uuid
 import time
-import sqlite3
+import asyncio
 import importlib.util
 import random
-import threading
 from pathlib import Path
-from typing import Dict, Optional, List, Any, Generator, Tuple
+from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
 
 from fastapi import FastAPI, Depends, HTTPException, Header
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from dotenv import load_dotenv
-import requests
+import httpx
+import aiosqlite
 
 # ------------------------------------------------------------------------------
 # Bootstrap
@@ -55,10 +55,10 @@ send_chat_request = _replicate.send_chat_request
 # SQLite helpers
 # ------------------------------------------------------------------------------
 
-def _ensure_db():
+async def _ensure_db():
     DB_PATH.parent.mkdir(parents=True, exist_ok=True)
-    with sqlite3.connect(DB_PATH) as conn:
-        conn.execute(
+    async with aiosqlite.connect(DB_PATH) as conn:
+        await conn.execute(
             """
             CREATE TABLE IF NOT EXISTS accounts (
                 id TEXT PRIMARY KEY,
@@ -77,23 +77,24 @@ def _ensure_db():
         )
         # add columns if missing
         try:
-            cols = [row[1] for row in conn.execute("PRAGMA table_info(accounts)").fetchall()]
-            if "enabled" not in cols:
-                conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
-            if "error_count" not in cols:
-                conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
-            if "success_count" not in cols:
-                conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
+            async with conn.execute("PRAGMA table_info(accounts)") as cursor:
+                rows = await cursor.fetchall()
+                cols = [row[1] for row in rows]
+                if "enabled" not in cols:
+                    await conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
+                if "error_count" not in cols:
+                    await conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
+                if "success_count" not in cols:
+                    await conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
         except Exception:
             pass
-        conn.commit()
+        await conn.commit()
 
-def _conn() -> sqlite3.Connection:
-    conn = sqlite3.connect(DB_PATH, check_same_thread=False)
-    conn.row_factory = sqlite3.Row
-    return conn
+def _conn() -> aiosqlite.Connection:
+    """Create a new database connection. Must be used with async with."""
+    return aiosqlite.connect(DB_PATH)
 
-def _row_to_dict(r: sqlite3.Row) -> Dict[str, Any]:
+def _row_to_dict(r: aiosqlite.Row) -> Dict[str, Any]:
     d = dict(r)
     if d.get("other"):
         try:
@@ -108,46 +109,46 @@ def _row_to_dict(r: sqlite3.Row) -> Dict[str, Any]:
             d["enabled"] = bool(d["enabled"])
     return d
 
-_ensure_db()
+# _ensure_db() will be called in startup event
 
 # ------------------------------------------------------------------------------
 # Background token refresh thread
 # ------------------------------------------------------------------------------
 
-def _refresh_stale_tokens():
+async def _refresh_stale_tokens():
     while True:
         try:
-            time.sleep(300)  # 5 minutes
+            await asyncio.sleep(300)  # 5 minutes
             now = time.time()
-            with _conn() as conn:
-                rows = conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1").fetchall()
-                for row in rows:
-                    acc_id, last_refresh = row[0], row[1]
-                    should_refresh = False
-                    if not last_refresh or last_refresh == "never":
-                        should_refresh = True
-                    else:
-                        try:
-                            last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
-                            if now - last_time > 1500:  # 25 minutes
-                                should_refresh = True
-                        except Exception:
-                            # Malformed or unparsable timestamp; force refresh
+            async with _conn() as conn:
+                conn.row_factory = aiosqlite.Row
+                async with conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1") as cursor:
+                    rows = await cursor.fetchall()
+                    for row in rows:
+                        acc_id, last_refresh = row[0], row[1]
+                        should_refresh = False
+                        if not last_refresh or last_refresh == "never":
                             should_refresh = True
+                        else:
+                            try:
+                                last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
+                                if now - last_time > 1500:  # 25 minutes
+                                    should_refresh = True
+                            except Exception:
+                                # Malformed or unparsable timestamp; force refresh
+                                should_refresh = True
 
-                    if should_refresh:
-                        try:
-                            refresh_access_token_in_db(acc_id)
-                        except Exception:
-                            traceback.print_exc()
-                            # Ignore per-account refresh failure; timestamp/status are recorded inside
-                            pass
+                        if should_refresh:
+                            try:
+                                await refresh_access_token_in_db(acc_id)
+                            except Exception:
+                                traceback.print_exc()
+                                # Ignore per-account refresh failure; timestamp/status are recorded inside
+                                pass
         except Exception:
             traceback.print_exc()
             pass
 
-threading.Thread(target=_refresh_stale_tokens, daemon=True).start()
-
 # ------------------------------------------------------------------------------
 # Env and API Key authorization (keys are independent of AWS accounts)
 # ------------------------------------------------------------------------------
@@ -174,11 +175,13 @@ def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
         return token_header.split(" ", 1)[1].strip()
     return token_header.strip()
 
-def _list_enabled_accounts(conn: sqlite3.Connection) -> List[Dict[str, Any]]:
-    rows = conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC").fetchall()
-    return [_row_to_dict(r) for r in rows]
+async def _list_enabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
+    conn.row_factory = aiosqlite.Row
+    async with conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC") as cursor:
+        rows = await cursor.fetchall()
+        return [_row_to_dict(r) for r in rows]
 
-def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
+async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
     """
     Authorize request by OPENAI_KEYS (if configured), then select an AWS account.
     Selection strategy: random among all enabled accounts. Authorization key does NOT map to any account.
@@ -189,8 +192,8 @@ def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
             raise HTTPException(status_code=401, detail="Invalid or missing API key")
 
     # Selection: random among enabled accounts
-    with _conn() as conn:
-        candidates = _list_enabled_accounts(conn)
+    async with _conn() as conn:
+        candidates = await _list_enabled_accounts(conn)
         if not candidates:
             raise HTTPException(status_code=401, detail="No enabled account available")
         return random.choice(candidates)
@@ -248,12 +251,14 @@ def _oidc_headers() -> Dict[str, str]:
         "amz-sdk-invocation-id": str(uuid.uuid4()),
     }
 
-def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
-    with _conn() as conn:
-        row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
-        if not row:
-            raise HTTPException(status_code=404, detail="Account not found")
-        acc = _row_to_dict(row)
+async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
+            row = await cursor.fetchone()
+            if not row:
+                raise HTTPException(status_code=404, detail="Account not found")
+            acc = _row_to_dict(row)
 
         if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
             raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
@@ -266,17 +271,27 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
         }
 
         try:
-            r = requests.post(TOKEN_URL, headers=_oidc_headers(), json=payload, timeout=(15, 60), proxies=_get_proxies())
-            r.raise_for_status()
-            data = r.json()
-            new_access = data.get("accessToken")
-            new_refresh = data.get("refreshToken", acc.get("refreshToken"))
-            now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-            status = "success"
-        except requests.RequestException as e:
+            proxies = _get_proxies()
+            mounts = None
+            if proxies:
+                proxy_url = proxies.get("https") or proxies.get("http")
+                if proxy_url:
+                    mounts = {
+                        "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                        "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                    }
+            async with httpx.AsyncClient(mounts=mounts, timeout=60.0) as client:
+                r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
+                r.raise_for_status()
+                data = r.json()
+                new_access = data.get("accessToken")
+                new_refresh = data.get("refreshToken", acc.get("refreshToken"))
+                now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+                status = "success"
+        except httpx.HTTPError as e:
             now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
             status = "failed"
-            conn.execute(
+            await conn.execute(
                 """
                 UPDATE accounts
                 SET last_refresh_time=?, last_refresh_status=?, updated_at=?
@@ -284,13 +299,13 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
                 """,
                 (now, status, now, account_id),
             )
-            conn.commit()
+            await conn.commit()
             raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
         except Exception as e:
             # Ensure last_refresh_time is recorded even on unexpected errors
             now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
             status = "failed"
-            conn.execute(
+            await conn.execute(
                 """
                 UPDATE accounts
                 SET last_refresh_time=?, last_refresh_status=?, updated_at=?
@@ -298,10 +313,10 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
                 """,
                 (now, status, now, account_id),
             )
-            conn.commit()
+            await conn.commit()
             raise
 
-        conn.execute(
+        await conn.execute(
             """
             UPDATE accounts
             SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
@@ -309,42 +324,46 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
             """,
             (new_access, new_refresh, now, status, now, account_id),
         )
-        conn.commit()
-
-        row2 = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
-        return _row_to_dict(row2)
-
-def get_account(account_id: str) -> Dict[str, Any]:
-    with _conn() as conn:
-        row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
-        if not row:
-            raise HTTPException(status_code=404, detail="Account not found")
-        return _row_to_dict(row)
-
-def _update_stats(account_id: str, success: bool) -> None:
-    with _conn() as conn:
+        await conn.commit()
+
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
+            row2 = await cursor.fetchone()
+            return _row_to_dict(row2)
+
+async def get_account(account_id: str) -> Dict[str, Any]:
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
+            row = await cursor.fetchone()
+            if not row:
+                raise HTTPException(status_code=404, detail="Account not found")
+            return _row_to_dict(row)
+
+async def _update_stats(account_id: str, success: bool) -> None:
+    async with _conn() as conn:
         if success:
-            conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
+            await conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
                         (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
         else:
-            row = conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)).fetchone()
-            if row:
-                new_count = (row[0] or 0) + 1
-                if new_count >= MAX_ERROR_COUNT:
-                    conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
-                               (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
-                else:
-                    conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
-                               (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
-        conn.commit()
+            async with conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)) as cursor:
+                row = await cursor.fetchone()
+                if row:
+                    new_count = (row[0] or 0) + 1
+                    if new_count >= MAX_ERROR_COUNT:
+                        await conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
+                                   (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
+                    else:
+                        await conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
+                                   (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
+        await conn.commit()
 
 # ------------------------------------------------------------------------------
 # Dependencies
 # ------------------------------------------------------------------------------
 
-def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
+async def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
     bearer = _extract_bearer(authorization)
-    return resolve_account_for_key(bearer)
+    return await resolve_account_for_key(bearer)
 
 # ------------------------------------------------------------------------------
 # OpenAI-compatible Chat endpoint
@@ -378,7 +397,7 @@ def _sse_format(obj: Dict[str, Any]) -> str:
     return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
 
 @app.post("/v1/chat/completions")
-def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
+async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
     """
     OpenAI-compatible chat endpoint.
     - stream default False
@@ -388,22 +407,22 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
     model = req.model
     do_stream = bool(req.stream)
 
-    def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[Generator[str, None, None]], Any]:
+    async def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any]:
         access = account.get("accessToken")
         if not access:
-            refreshed = refresh_access_token_in_db(account["id"])
+            refreshed = await refresh_access_token_in_db(account["id"])
             access = refreshed.get("accessToken")
             if not access:
                 raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
-        return send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
+        return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
 
     if not do_stream:
         try:
-            text, _, tracker = _send_upstream(stream=False)
-            _update_stats(account["id"], bool(text))
+            text, _, tracker = await _send_upstream(stream=False)
+            await _update_stats(account["id"], bool(text))
             return JSONResponse(content=_openai_non_streaming_response(text or "", model))
         except Exception as e:
-            _update_stats(account["id"], False)
+            await _update_stats(account["id"], False)
             raise
     else:
         created = int(time.time())
@@ -411,14 +430,14 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
         model_used = model or "unknown"
         
         try:
-            _, it, tracker = _send_upstream(stream=True)
+            _, it, tracker = await _send_upstream(stream=True)
             assert it is not None
-            first_piece = next(it, None)
+            first_piece = await it.__anext__()
             if not first_piece:
-                _update_stats(account["id"], False)
+                await _update_stats(account["id"], False)
                 raise HTTPException(status_code=502, detail="No content from upstream")
             
-            def event_gen() -> Generator[str, None, None]:
+            async def event_gen() -> AsyncGenerator[str, None]:
                 try:
                     yield _sse_format({
                         "id": stream_id,
@@ -434,7 +453,7 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
                         "model": model_used,
                         "choices": [{"index": 0, "delta": {"content": first_piece}, "finish_reason": None}],
                     })
-                    for piece in it:
+                    async for piece in it:
                         if piece:
                             yield _sse_format({
                                 "id": stream_id,
@@ -451,14 +470,14 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
                         "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                     })
                     yield "data: [DONE]\n\n"
-                    _update_stats(account["id"], True)
+                    await _update_stats(account["id"], True)
                 except Exception:
-                    _update_stats(account["id"], tracker.has_content if tracker else False)
+                    await _update_stats(account["id"], tracker.has_content if tracker else False)
                     raise
             
             return StreamingResponse(event_gen(), media_type="text/event-stream")
         except Exception as e:
-            _update_stats(account["id"], False)
+            await _update_stats(account["id"], False)
             raise
 
 # ------------------------------------------------------------------------------
@@ -486,7 +505,7 @@ class AuthStartBody(BaseModel):
     label: Optional[str] = None
     enabled: Optional[bool] = True
 
-def _create_account_from_tokens(
+async def _create_account_from_tokens(
     client_id: str,
     client_secret: str,
     access_token: str,
@@ -496,8 +515,9 @@ def _create_account_from_tokens(
 ) -> Dict[str, Any]:
     now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
     acc_id = str(uuid.uuid4())
-    with _conn() as conn:
-        conn.execute(
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        await conn.execute(
             """
             INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
@@ -517,20 +537,21 @@ def _create_account_from_tokens(
                 1 if enabled else 0,
             ),
         )
-        conn.commit()
-        row = conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)).fetchone()
-        return _row_to_dict(row)
+        await conn.commit()
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
+            row = await cursor.fetchone()
+            return _row_to_dict(row)
 
 @app.post("/v2/auth/start")
-def auth_start(body: AuthStartBody):
+async def auth_start(body: AuthStartBody):
     """
     Start device authorization and return verification URL for user login.
     Session lifetime capped at 5 minutes on claim.
     """
     try:
-        cid, csec = register_client_min()
-        dev = device_authorize(cid, csec)
-    except requests.RequestException as e:
+        cid, csec = await register_client_min()
+        dev = await device_authorize(cid, csec)
+    except httpx.HTTPError as e:
         raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
 
     auth_id = str(uuid.uuid4())
@@ -559,7 +580,7 @@ def auth_start(body: AuthStartBody):
     }
 
 @app.get("/v2/auth/status/{auth_id}")
-def auth_status(auth_id: str):
+async def auth_status(auth_id: str):
     sess = AUTH_SESSIONS.get(auth_id)
     if not sess:
         raise HTTPException(status_code=404, detail="Auth session not found")
@@ -574,7 +595,7 @@ def auth_status(auth_id: str):
     }
 
 @app.post("/v2/auth/claim/{auth_id}")
-def auth_claim(auth_id: str):
+async def auth_claim(auth_id: str):
     """
     Block up to 5 minutes to exchange the device code for tokens after user completed login.
     On success, creates an enabled account and returns it.
@@ -589,7 +610,7 @@ def auth_claim(auth_id: str):
             "error": sess.get("error"),
         }
     try:
-        toks = poll_token_device_code(
+        toks = await poll_token_device_code(
             sess["clientId"],
             sess["clientSecret"],
             sess["deviceCode"],
@@ -602,7 +623,7 @@ def auth_claim(auth_id: str):
         if not access_token:
             raise HTTPException(status_code=502, detail="No accessToken returned from OIDC")
 
-        acc = _create_account_from_tokens(
+        acc = await _create_account_from_tokens(
             sess["clientId"],
             sess["clientSecret"],
             access_token,
@@ -619,7 +640,7 @@ def auth_claim(auth_id: str):
     except TimeoutError:
         sess["status"] = "timeout"
         raise HTTPException(status_code=408, detail="Authorization timeout (5 minutes)")
-    except requests.RequestException as e:
+    except httpx.HTTPError as e:
         sess["status"] = "error"
         sess["error"] = str(e)
         raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
@@ -629,13 +650,14 @@ def auth_claim(auth_id: str):
 # ------------------------------------------------------------------------------
 
 @app.post("/v2/accounts")
-def create_account(body: AccountCreate):
+async def create_account(body: AccountCreate):
     now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
     acc_id = str(uuid.uuid4())
     other_str = json.dumps(body.other, ensure_ascii=False) if body.other is not None else None
     enabled_val = 1 if (body.enabled is None or body.enabled) else 0
-    with _conn() as conn:
-        conn.execute(
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        await conn.execute(
             """
             INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
@@ -655,31 +677,34 @@ def create_account(body: AccountCreate):
                 enabled_val,
             ),
         )
-        conn.commit()
-        row = conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)).fetchone()
-        return _row_to_dict(row)
+        await conn.commit()
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
+            row = await cursor.fetchone()
+            return _row_to_dict(row)
 
 @app.get("/v2/accounts")
-def list_accounts():
-    with _conn() as conn:
-        rows = conn.execute("SELECT * FROM accounts ORDER BY created_at DESC").fetchall()
-        return [_row_to_dict(r) for r in rows]
+async def list_accounts():
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        async with conn.execute("SELECT * FROM accounts ORDER BY created_at DESC") as cursor:
+            rows = await cursor.fetchall()
+            return [_row_to_dict(r) for r in rows]
 
 @app.get("/v2/accounts/{account_id}")
-def get_account_detail(account_id: str):
-    return get_account(account_id)
+async def get_account_detail(account_id: str):
+    return await get_account(account_id)
 
 @app.delete("/v2/accounts/{account_id}")
-def delete_account(account_id: str):
-    with _conn() as conn:
-        cur = conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
-        conn.commit()
+async def delete_account(account_id: str):
+    async with _conn() as conn:
+        cur = await conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
+        await conn.commit()
         if cur.rowcount == 0:
             raise HTTPException(status_code=404, detail="Account not found")
         return {"deleted": account_id}
 
 @app.patch("/v2/accounts/{account_id}")
-def update_account(account_id: str, body: AccountUpdate):
+async def update_account(account_id: str, body: AccountUpdate):
     now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
     fields = []
     values: List[Any] = []
@@ -700,22 +725,24 @@ def update_account(account_id: str, body: AccountUpdate):
         fields.append("enabled=?"); values.append(1 if body.enabled else 0)
 
     if not fields:
-        return get_account(account_id)
+        return await get_account(account_id)
 
     fields.append("updated_at=?"); values.append(now)
     values.append(account_id)
 
-    with _conn() as conn:
-        cur = conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
-        conn.commit()
+    async with _conn() as conn:
+        conn.row_factory = aiosqlite.Row
+        cur = await conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
+        await conn.commit()
         if cur.rowcount == 0:
             raise HTTPException(status_code=404, detail="Account not found")
-        row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
-        return _row_to_dict(row)
+        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
+            row = await cursor.fetchone()
+            return _row_to_dict(row)
 
 @app.post("/v2/accounts/{account_id}/refresh")
-def manual_refresh(account_id: str):
-    return refresh_access_token_in_db(account_id)
+async def manual_refresh(account_id: str):
+    return await refresh_access_token_in_db(account_id)
 
 # ------------------------------------------------------------------------------
 # Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
@@ -735,5 +762,15 @@ def index():
 # ------------------------------------------------------------------------------
 
 @app.get("/healthz")
-def health():
-    return {"status": "ok"}
+async def health():
+    return {"status": "ok"}
+
+# ------------------------------------------------------------------------------
+# Startup / Shutdown Events
+# ------------------------------------------------------------------------------
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize database and start background tasks on startup."""
+    await _ensure_db()
+    asyncio.create_task(_refresh_stale_tokens())

+ 62 - 30
auth_flow.py

@@ -2,9 +2,10 @@ import json
 import time
 import uuid
 import os
+import asyncio
 from typing import Dict, Tuple, Optional
 
-import requests
+import httpx
 
 def _get_proxies() -> Optional[Dict[str, str]]:
     proxy = os.getenv("HTTP_PROXY", "").strip()
@@ -34,15 +35,15 @@ def make_headers() -> Dict[str, str]:
     }
 
 
-def post_json(url: str, payload: Dict) -> requests.Response:
+async def post_json(client: httpx.AsyncClient, url: str, payload: Dict) -> httpx.Response:
     # Keep JSON order and mimic body closely to v1
     payload_str = json.dumps(payload, ensure_ascii=False)
     headers = make_headers()
-    resp = requests.post(url, headers=headers, data=payload_str, timeout=(15, 60), proxies=_get_proxies())
+    resp = await client.post(url, headers=headers, content=payload_str, timeout=httpx.Timeout(15.0, read=60.0))
     return resp
 
 
-def register_client_min() -> Tuple[str, str]:
+async def register_client_min() -> Tuple[str, str]:
     """
     Register an OIDC client (minimal) and return (clientId, clientSecret).
     """
@@ -55,13 +56,23 @@ def register_client_min() -> Tuple[str, str]:
             "codewhisperer:conversations",
         ],
     }
-    r = post_json(REGISTER_URL, payload)
-    r.raise_for_status()
-    data = r.json()
-    return data["clientId"], data["clientSecret"]
+    proxies = _get_proxies()
+    mounts = None
+    if proxies:
+        proxy_url = proxies.get("https") or proxies.get("http")
+        if proxy_url:
+            mounts = {
+                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+            }
+    async with httpx.AsyncClient(mounts=mounts) as client:
+        r = await post_json(client, REGISTER_URL, payload)
+        r.raise_for_status()
+        data = r.json()
+        return data["clientId"], data["clientSecret"]
 
 
-def device_authorize(client_id: str, client_secret: str) -> Dict:
+async def device_authorize(client_id: str, client_secret: str) -> Dict:
     """
     Start device authorization. Returns dict that includes:
     - deviceCode
@@ -75,12 +86,22 @@ def device_authorize(client_id: str, client_secret: str) -> Dict:
         "clientSecret": client_secret,
         "startUrl": START_URL,
     }
-    r = post_json(DEVICE_AUTH_URL, payload)
-    r.raise_for_status()
-    return r.json()
+    proxies = _get_proxies()
+    mounts = None
+    if proxies:
+        proxy_url = proxies.get("https") or proxies.get("http")
+        if proxy_url:
+            mounts = {
+                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+            }
+    async with httpx.AsyncClient(mounts=mounts) as client:
+        r = await post_json(client, DEVICE_AUTH_URL, payload)
+        r.raise_for_status()
+        return r.json()
 
 
-def poll_token_device_code(
+async def poll_token_device_code(
     client_id: str,
     client_secret: str,
     device_code: str,
@@ -94,7 +115,7 @@ def poll_token_device_code(
     Returns token dict with at least 'accessToken' and optionally 'refreshToken'.
     Raises:
       - TimeoutError on timeout
-      - requests.HTTPError for non-recoverable HTTP errors
+      - httpx.HTTPError for non-recoverable HTTP errors
     """
     payload = {
         "clientId": client_id,
@@ -111,22 +132,33 @@ def poll_token_device_code(
     # Ensure interval sane
     poll_interval = max(1, int(interval or 1))
 
-    while time.time() < deadline:
-        r = post_json(TOKEN_URL, payload)
-        if r.status_code == 200:
-            return r.json()
-        if r.status_code == 400:
-            # Expect AuthorizationPendingException early on
-            try:
-                err = r.json()
-            except Exception:
-                err = {"error": r.text}
-            if str(err.get("error")) == "authorization_pending":
-                time.sleep(poll_interval)
-                continue
-            # Other 4xx are errors
+    proxies = _get_proxies()
+    mounts = None
+    if proxies:
+        proxy_url = proxies.get("https") or proxies.get("http")
+        if proxy_url:
+            mounts = {
+                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+            }
+    
+    async with httpx.AsyncClient(mounts=mounts) as client:
+        while time.time() < deadline:
+            r = await post_json(client, TOKEN_URL, payload)
+            if r.status_code == 200:
+                return r.json()
+            if r.status_code == 400:
+                # Expect AuthorizationPendingException early on
+                try:
+                    err = r.json()
+                except Exception:
+                    err = {"error": r.text}
+                if str(err.get("error")) == "authorization_pending":
+                    await asyncio.sleep(poll_interval)
+                    continue
+                # Other 4xx are errors
+                r.raise_for_status()
+            # Non-200, non-400
             r.raise_for_status()
-        # Non-200, non-400
-        r.raise_for_status()
 
     raise TimeoutError("Device authorization expired before approval (timeout reached)")

+ 55 - 40
replicate.py

@@ -2,16 +2,16 @@ import json
 import uuid
 import os
 from pathlib import Path
-from typing import Dict, Optional, Tuple, Iterator, List, Generator, Any
+from typing import Dict, Optional, Tuple, Iterator, List, AsyncGenerator, Any
 import struct
-import requests
+import httpx
 
 class StreamTracker:
     def __init__(self):
         self.has_content = False
     
-    def track(self, gen: Generator[str, None, None]) -> Generator[str, None, None]:
-        for item in gen:
+    async def track(self, gen: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
+        async for item in gen:
             if item:
                 self.has_content = True
             yield item
@@ -192,7 +192,7 @@ def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
     except Exception:
         pass
 
-def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[Generator[str, None, None]], bool]:
+async def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker]:
     url, headers_from_log, body_json = load_template()
     headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
     try:
@@ -204,39 +204,79 @@ def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model:
     inject_model(body_json, model)
     payload_str = json.dumps(body_json, ensure_ascii=False)
     headers = _merge_headers(headers_from_log, access_token)
-    session = requests.Session()
+    
+    # Build mounts with proxy if available
     proxies = _get_proxies()
-    resp = session.post(url, headers=headers, data=payload_str, stream=True, timeout=timeout, proxies=proxies)
-    if resp.status_code >= 400:
-        try:
-            err = resp.text
-        except Exception:
-            err = f"HTTP {resp.status_code}"
-        raise requests.HTTPError(f"Upstream error {resp.status_code}: {err}", response=resp)
-    parser = AwsEventStreamParser()
-    tracker = StreamTracker()
-    def _iter_text() -> Generator[str, None, None]:
-        for chunk in resp.iter_content(chunk_size=None):
-            if not chunk:
-                continue
-            events = parser.feed(chunk)
-            for _ev_headers, payload in events:
-                parsed = _try_decode_event_payload(payload)
-                if parsed is not None:
-                    text = _extract_text_from_event(parsed)
-                    if isinstance(text, str) and text:
-                        yield text
-                else:
-                    try:
-                        txt = payload.decode("utf-8", errors="ignore")
-                        if txt:
-                            yield txt
-                    except Exception:
-                        pass
-    if stream:
-        return None, tracker.track(_iter_text()), tracker
-    else:
-        buf = []
-        for t in tracker.track(_iter_text()):
-            buf.append(t)
-        return "".join(buf), None, tracker
+    mounts = None
+    if proxies:
+        proxy_url = proxies.get("https") or proxies.get("http")
+        if proxy_url:
+            mounts = {
+                "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+                "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
+            }
+
+    # The client and response are opened manually instead of with
+    # "async with": in streaming mode the returned generator is consumed
+    # after this coroutine has returned, and context managers would already
+    # have closed the connection by then (httpx raises StreamClosed).
+    client = httpx.AsyncClient(mounts=mounts, timeout=httpx.Timeout(timeout[0], read=timeout[1]))
+    resp: Optional[httpx.Response] = None
+
+    async def _close() -> None:
+        # Release the response first, then the client that owns its connection.
+        try:
+            if resp is not None:
+                await resp.aclose()
+        finally:
+            await client.aclose()
+
+    try:
+        request = client.build_request("POST", url, headers=headers, content=payload_str)
+        resp = await client.send(request, stream=True)
+        if resp.status_code >= 400:
+            try:
+                err = (await resp.aread()).decode("utf-8", errors="ignore")
+            except Exception:
+                err = f"HTTP {resp.status_code}"
+            raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")
+    except BaseException:
+        # Nothing is handed to the caller on failure (including cancellation),
+        # so the connection must be released here before re-raising.
+        await _close()
+        raise
+
+    parser = AwsEventStreamParser()
+    tracker = StreamTracker()
+
+    async def _iter_text() -> AsyncGenerator[str, None]:
+        # Decode AWS event-stream frames into text pieces as bytes arrive.
+        try:
+            async for chunk in resp.aiter_bytes():
+                if not chunk:
+                    continue
+                for _ev_headers, payload in parser.feed(chunk):
+                    parsed = _try_decode_event_payload(payload)
+                    if parsed is not None:
+                        text = _extract_text_from_event(parsed)
+                        if isinstance(text, str) and text:
+                            yield text
+                    else:
+                        try:
+                            txt = payload.decode("utf-8", errors="ignore")
+                            if txt:
+                                yield txt
+                        except Exception:
+                            pass
+        finally:
+            # Runs when the stream is exhausted, fails, or the generator is closed.
+            await _close()
+
+    if stream:
+        # Ownership of the open connection passes to the generator's finally block.
+        return None, tracker.track(_iter_text()), tracker
+
+    buf = []
+    async for t in tracker.track(_iter_text()):
+        buf.append(t)
+    return "".join(buf), None, tracker

+ 2 - 1
requirements.txt

@@ -2,4 +2,5 @@ fastapi==0.115.5
 uvicorn[standard]==0.32.0
 pydantic==2.9.2
 python-dotenv==1.0.1
-requests==2.32.3
+httpx==0.28.1
+aiosqlite==0.20.0