|
|
@@ -3,19 +3,19 @@ import json
|
|
|
import traceback
|
|
|
import uuid
|
|
|
import time
|
|
|
-import sqlite3
|
|
|
+import asyncio
|
|
|
import importlib.util
|
|
|
import random
|
|
|
-import threading
|
|
|
from pathlib import Path
|
|
|
-from typing import Dict, Optional, List, Any, Generator, Tuple
|
|
|
+from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
|
|
|
|
|
|
from fastapi import FastAPI, Depends, HTTPException, Header
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
|
|
|
from pydantic import BaseModel
|
|
|
from dotenv import load_dotenv
|
|
|
-import requests
|
|
|
+import httpx
|
|
|
+import aiosqlite
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Bootstrap
|
|
|
@@ -55,10 +55,10 @@ send_chat_request = _replicate.send_chat_request
|
|
|
# SQLite helpers
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
-def _ensure_db():
|
|
|
+async def _ensure_db():
|
|
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
|
|
- with sqlite3.connect(DB_PATH) as conn:
|
|
|
- conn.execute(
|
|
|
+ async with aiosqlite.connect(DB_PATH) as conn:
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
|
|
id TEXT PRIMARY KEY,
|
|
|
@@ -77,23 +77,24 @@ def _ensure_db():
|
|
|
)
|
|
|
# add columns if missing
|
|
|
try:
|
|
|
- cols = [row[1] for row in conn.execute("PRAGMA table_info(accounts)").fetchall()]
|
|
|
- if "enabled" not in cols:
|
|
|
- conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
|
|
|
- if "error_count" not in cols:
|
|
|
- conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
|
|
|
- if "success_count" not in cols:
|
|
|
- conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
|
|
|
+ async with conn.execute("PRAGMA table_info(accounts)") as cursor:
|
|
|
+ rows = await cursor.fetchall()
|
|
|
+ cols = [row[1] for row in rows]
|
|
|
+ if "enabled" not in cols:
|
|
|
+ await conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
|
|
|
+ if "error_count" not in cols:
|
|
|
+ await conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
|
|
|
+ if "success_count" not in cols:
|
|
|
+ await conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
|
|
|
except Exception:
|
|
|
pass
|
|
|
- conn.commit()
|
|
|
+ await conn.commit()
|
|
|
|
|
|
-def _conn() -> sqlite3.Connection:
|
|
|
- conn = sqlite3.connect(DB_PATH, check_same_thread=False)
|
|
|
- conn.row_factory = sqlite3.Row
|
|
|
- return conn
|
|
|
+def _conn() -> aiosqlite.Connection:
|
|
|
+ """Create a new database connection. Must be used with async with."""
|
|
|
+ return aiosqlite.connect(DB_PATH)
|
|
|
|
|
|
-def _row_to_dict(r: sqlite3.Row) -> Dict[str, Any]:
|
|
|
+def _row_to_dict(r: aiosqlite.Row) -> Dict[str, Any]:
|
|
|
d = dict(r)
|
|
|
if d.get("other"):
|
|
|
try:
|
|
|
@@ -108,46 +109,46 @@ def _row_to_dict(r: sqlite3.Row) -> Dict[str, Any]:
|
|
|
d["enabled"] = bool(d["enabled"])
|
|
|
return d
|
|
|
|
|
|
-_ensure_db()
|
|
|
+# _ensure_db() will be called in startup event
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Background token refresh thread
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
-def _refresh_stale_tokens():
|
|
|
+async def _refresh_stale_tokens():
|
|
|
while True:
|
|
|
try:
|
|
|
- time.sleep(300) # 5 minutes
|
|
|
+ await asyncio.sleep(300) # 5 minutes
|
|
|
now = time.time()
|
|
|
- with _conn() as conn:
|
|
|
- rows = conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1").fetchall()
|
|
|
- for row in rows:
|
|
|
- acc_id, last_refresh = row[0], row[1]
|
|
|
- should_refresh = False
|
|
|
- if not last_refresh or last_refresh == "never":
|
|
|
- should_refresh = True
|
|
|
- else:
|
|
|
- try:
|
|
|
- last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
|
|
|
- if now - last_time > 1500: # 25 minutes
|
|
|
- should_refresh = True
|
|
|
- except Exception:
|
|
|
- # Malformed or unparsable timestamp; force refresh
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ async with conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1") as cursor:
|
|
|
+ rows = await cursor.fetchall()
|
|
|
+ for row in rows:
|
|
|
+ acc_id, last_refresh = row[0], row[1]
|
|
|
+ should_refresh = False
|
|
|
+ if not last_refresh or last_refresh == "never":
|
|
|
should_refresh = True
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
|
|
|
+ if now - last_time > 1500: # 25 minutes
|
|
|
+ should_refresh = True
|
|
|
+ except Exception:
|
|
|
+ # Malformed or unparsable timestamp; force refresh
|
|
|
+ should_refresh = True
|
|
|
|
|
|
- if should_refresh:
|
|
|
- try:
|
|
|
- refresh_access_token_in_db(acc_id)
|
|
|
- except Exception:
|
|
|
- traceback.print_exc()
|
|
|
- # Ignore per-account refresh failure; timestamp/status are recorded inside
|
|
|
- pass
|
|
|
+ if should_refresh:
|
|
|
+ try:
|
|
|
+ await refresh_access_token_in_db(acc_id)
|
|
|
+ except Exception:
|
|
|
+ traceback.print_exc()
|
|
|
+ # Ignore per-account refresh failure; timestamp/status are recorded inside
|
|
|
+ pass
|
|
|
except Exception:
|
|
|
traceback.print_exc()
|
|
|
pass
|
|
|
|
|
|
-threading.Thread(target=_refresh_stale_tokens, daemon=True).start()
|
|
|
-
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Env and API Key authorization (keys are independent of AWS accounts)
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -174,11 +175,13 @@ def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
|
|
|
return token_header.split(" ", 1)[1].strip()
|
|
|
return token_header.strip()
|
|
|
|
|
|
-def _list_enabled_accounts(conn: sqlite3.Connection) -> List[Dict[str, Any]]:
|
|
|
- rows = conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC").fetchall()
|
|
|
- return [_row_to_dict(r) for r in rows]
|
|
|
+async def _list_enabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC") as cursor:
|
|
|
+ rows = await cursor.fetchall()
|
|
|
+ return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
-def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
|
|
|
+async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Authorize request by OPENAI_KEYS (if configured), then select an AWS account.
|
|
|
Selection strategy: random among all enabled accounts. Authorization key does NOT map to any account.
|
|
|
@@ -189,8 +192,8 @@ def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
|
|
|
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
|
|
|
|
|
# Selection: random among enabled accounts
|
|
|
- with _conn() as conn:
|
|
|
- candidates = _list_enabled_accounts(conn)
|
|
|
+ async with _conn() as conn:
|
|
|
+ candidates = await _list_enabled_accounts(conn)
|
|
|
if not candidates:
|
|
|
raise HTTPException(status_code=401, detail="No enabled account available")
|
|
|
return random.choice(candidates)
|
|
|
@@ -248,12 +251,14 @@ def _oidc_headers() -> Dict[str, str]:
|
|
|
"amz-sdk-invocation-id": str(uuid.uuid4()),
|
|
|
}
|
|
|
|
|
|
-def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
- with _conn() as conn:
|
|
|
- row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
|
|
|
- if not row:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- acc = _row_to_dict(row)
|
|
|
+async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ if not row:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ acc = _row_to_dict(row)
|
|
|
|
|
|
if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
|
|
|
raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
|
|
|
@@ -266,17 +271,27 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
- r = requests.post(TOKEN_URL, headers=_oidc_headers(), json=payload, timeout=(15, 60), proxies=_get_proxies())
|
|
|
- r.raise_for_status()
|
|
|
- data = r.json()
|
|
|
- new_access = data.get("accessToken")
|
|
|
- new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
- now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
- status = "success"
|
|
|
- except requests.RequestException as e:
|
|
|
+ proxies = _get_proxies()
|
|
|
+ mounts = None
|
|
|
+ if proxies:
|
|
|
+ proxy_url = proxies.get("https") or proxies.get("http")
|
|
|
+ if proxy_url:
|
|
|
+ mounts = {
|
|
|
+ "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
+ "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
|
|
|
+ }
|
|
|
+ async with httpx.AsyncClient(mounts=mounts, timeout=60.0) as client:
|
|
|
+ r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
+ new_access = data.get("accessToken")
|
|
|
+ new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
+ now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
+ status = "success"
|
|
|
+ except httpx.HTTPError as e:
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
status = "failed"
|
|
|
- conn.execute(
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
UPDATE accounts
|
|
|
SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
@@ -284,13 +299,13 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
""",
|
|
|
(now, status, now, account_id),
|
|
|
)
|
|
|
- conn.commit()
|
|
|
+ await conn.commit()
|
|
|
raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
|
|
|
except Exception as e:
|
|
|
# Ensure last_refresh_time is recorded even on unexpected errors
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
status = "failed"
|
|
|
- conn.execute(
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
UPDATE accounts
|
|
|
SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
@@ -298,10 +313,10 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
""",
|
|
|
(now, status, now, account_id),
|
|
|
)
|
|
|
- conn.commit()
|
|
|
+ await conn.commit()
|
|
|
raise
|
|
|
|
|
|
- conn.execute(
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
UPDATE accounts
|
|
|
SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
@@ -309,42 +324,46 @@ def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
""",
|
|
|
(new_access, new_refresh, now, status, now, account_id),
|
|
|
)
|
|
|
- conn.commit()
|
|
|
-
|
|
|
- row2 = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
|
|
|
- return _row_to_dict(row2)
|
|
|
-
|
|
|
-def get_account(account_id: str) -> Dict[str, Any]:
|
|
|
- with _conn() as conn:
|
|
|
- row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
|
|
|
- if not row:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- return _row_to_dict(row)
|
|
|
-
|
|
|
-def _update_stats(account_id: str, success: bool) -> None:
|
|
|
- with _conn() as conn:
|
|
|
+ await conn.commit()
|
|
|
+
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
+ row2 = await cursor.fetchone()
|
|
|
+ return _row_to_dict(row2)
|
|
|
+
|
|
|
+async def get_account(account_id: str) -> Dict[str, Any]:
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ if not row:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ return _row_to_dict(row)
|
|
|
+
|
|
|
+async def _update_stats(account_id: str, success: bool) -> None:
|
|
|
+ async with _conn() as conn:
|
|
|
if success:
|
|
|
- conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
|
|
|
+ await conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
|
|
|
(time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
else:
|
|
|
- row = conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)).fetchone()
|
|
|
- if row:
|
|
|
- new_count = (row[0] or 0) + 1
|
|
|
- if new_count >= MAX_ERROR_COUNT:
|
|
|
- conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
|
|
|
- (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
- else:
|
|
|
- conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
|
|
|
- (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
- conn.commit()
|
|
|
+ async with conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ if row:
|
|
|
+ new_count = (row[0] or 0) + 1
|
|
|
+ if new_count >= MAX_ERROR_COUNT:
|
|
|
+ await conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
|
|
|
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
+ else:
|
|
|
+ await conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
|
|
|
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
+ await conn.commit()
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Dependencies
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
-def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
|
|
|
+async def require_account(authorization: Optional[str] = Header(default=None)) -> Dict[str, Any]:
|
|
|
bearer = _extract_bearer(authorization)
|
|
|
- return resolve_account_for_key(bearer)
|
|
|
+ return await resolve_account_for_key(bearer)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# OpenAI-compatible Chat endpoint
|
|
|
@@ -378,7 +397,7 @@ def _sse_format(obj: Dict[str, Any]) -> str:
|
|
|
return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
|
-def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
|
|
|
+async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
|
|
|
"""
|
|
|
OpenAI-compatible chat endpoint.
|
|
|
- stream default False
|
|
|
@@ -388,22 +407,22 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
|
|
|
model = req.model
|
|
|
do_stream = bool(req.stream)
|
|
|
|
|
|
- def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[Generator[str, None, None]], Any]:
|
|
|
+ async def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any]:
|
|
|
access = account.get("accessToken")
|
|
|
if not access:
|
|
|
- refreshed = refresh_access_token_in_db(account["id"])
|
|
|
+ refreshed = await refresh_access_token_in_db(account["id"])
|
|
|
access = refreshed.get("accessToken")
|
|
|
if not access:
|
|
|
raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
|
|
|
- return send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
|
|
|
+ return await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream)
|
|
|
|
|
|
if not do_stream:
|
|
|
try:
|
|
|
- text, _, tracker = _send_upstream(stream=False)
|
|
|
- _update_stats(account["id"], bool(text))
|
|
|
+ text, _, tracker = await _send_upstream(stream=False)
|
|
|
+ await _update_stats(account["id"], bool(text))
|
|
|
return JSONResponse(content=_openai_non_streaming_response(text or "", model))
|
|
|
except Exception as e:
|
|
|
- _update_stats(account["id"], False)
|
|
|
+ await _update_stats(account["id"], False)
|
|
|
raise
|
|
|
else:
|
|
|
created = int(time.time())
|
|
|
@@ -411,14 +430,14 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
|
|
|
model_used = model or "unknown"
|
|
|
|
|
|
try:
|
|
|
- _, it, tracker = _send_upstream(stream=True)
|
|
|
+ _, it, tracker = await _send_upstream(stream=True)
|
|
|
assert it is not None
|
|
|
- first_piece = next(it, None)
|
|
|
+ first_piece = await it.__anext__()
|
|
|
if not first_piece:
|
|
|
- _update_stats(account["id"], False)
|
|
|
+ await _update_stats(account["id"], False)
|
|
|
raise HTTPException(status_code=502, detail="No content from upstream")
|
|
|
|
|
|
- def event_gen() -> Generator[str, None, None]:
|
|
|
+ async def event_gen() -> AsyncGenerator[str, None]:
|
|
|
try:
|
|
|
yield _sse_format({
|
|
|
"id": stream_id,
|
|
|
@@ -434,7 +453,7 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
|
|
|
"model": model_used,
|
|
|
"choices": [{"index": 0, "delta": {"content": first_piece}, "finish_reason": None}],
|
|
|
})
|
|
|
- for piece in it:
|
|
|
+ async for piece in it:
|
|
|
if piece:
|
|
|
yield _sse_format({
|
|
|
"id": stream_id,
|
|
|
@@ -451,14 +470,14 @@ def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depen
|
|
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
|
})
|
|
|
yield "data: [DONE]\n\n"
|
|
|
- _update_stats(account["id"], True)
|
|
|
+ await _update_stats(account["id"], True)
|
|
|
except Exception:
|
|
|
- _update_stats(account["id"], tracker.has_content if tracker else False)
|
|
|
+ await _update_stats(account["id"], tracker.has_content if tracker else False)
|
|
|
raise
|
|
|
|
|
|
return StreamingResponse(event_gen(), media_type="text/event-stream")
|
|
|
except Exception as e:
|
|
|
- _update_stats(account["id"], False)
|
|
|
+ await _update_stats(account["id"], False)
|
|
|
raise
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -486,7 +505,7 @@ class AuthStartBody(BaseModel):
|
|
|
label: Optional[str] = None
|
|
|
enabled: Optional[bool] = True
|
|
|
|
|
|
-def _create_account_from_tokens(
|
|
|
+async def _create_account_from_tokens(
|
|
|
client_id: str,
|
|
|
client_secret: str,
|
|
|
access_token: str,
|
|
|
@@ -496,8 +515,9 @@ def _create_account_from_tokens(
|
|
|
) -> Dict[str, Any]:
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
acc_id = str(uuid.uuid4())
|
|
|
- with _conn() as conn:
|
|
|
- conn.execute(
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
@@ -517,20 +537,21 @@ def _create_account_from_tokens(
|
|
|
1 if enabled else 0,
|
|
|
),
|
|
|
)
|
|
|
- conn.commit()
|
|
|
- row = conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)).fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ await conn.commit()
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
@app.post("/v2/auth/start")
|
|
|
-def auth_start(body: AuthStartBody):
|
|
|
+async def auth_start(body: AuthStartBody):
|
|
|
"""
|
|
|
Start device authorization and return verification URL for user login.
|
|
|
Session lifetime capped at 5 minutes on claim.
|
|
|
"""
|
|
|
try:
|
|
|
- cid, csec = register_client_min()
|
|
|
- dev = device_authorize(cid, csec)
|
|
|
- except requests.RequestException as e:
|
|
|
+ cid, csec = await register_client_min()
|
|
|
+ dev = await device_authorize(cid, csec)
|
|
|
+ except httpx.HTTPError as e:
|
|
|
raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
|
|
|
|
|
|
auth_id = str(uuid.uuid4())
|
|
|
@@ -559,7 +580,7 @@ def auth_start(body: AuthStartBody):
|
|
|
}
|
|
|
|
|
|
@app.get("/v2/auth/status/{auth_id}")
|
|
|
-def auth_status(auth_id: str):
|
|
|
+async def auth_status(auth_id: str):
|
|
|
sess = AUTH_SESSIONS.get(auth_id)
|
|
|
if not sess:
|
|
|
raise HTTPException(status_code=404, detail="Auth session not found")
|
|
|
@@ -574,7 +595,7 @@ def auth_status(auth_id: str):
|
|
|
}
|
|
|
|
|
|
@app.post("/v2/auth/claim/{auth_id}")
|
|
|
-def auth_claim(auth_id: str):
|
|
|
+async def auth_claim(auth_id: str):
|
|
|
"""
|
|
|
Block up to 5 minutes to exchange the device code for tokens after user completed login.
|
|
|
On success, creates an enabled account and returns it.
|
|
|
@@ -589,7 +610,7 @@ def auth_claim(auth_id: str):
|
|
|
"error": sess.get("error"),
|
|
|
}
|
|
|
try:
|
|
|
- toks = poll_token_device_code(
|
|
|
+ toks = await poll_token_device_code(
|
|
|
sess["clientId"],
|
|
|
sess["clientSecret"],
|
|
|
sess["deviceCode"],
|
|
|
@@ -602,7 +623,7 @@ def auth_claim(auth_id: str):
|
|
|
if not access_token:
|
|
|
raise HTTPException(status_code=502, detail="No accessToken returned from OIDC")
|
|
|
|
|
|
- acc = _create_account_from_tokens(
|
|
|
+ acc = await _create_account_from_tokens(
|
|
|
sess["clientId"],
|
|
|
sess["clientSecret"],
|
|
|
access_token,
|
|
|
@@ -619,7 +640,7 @@ def auth_claim(auth_id: str):
|
|
|
except TimeoutError:
|
|
|
sess["status"] = "timeout"
|
|
|
raise HTTPException(status_code=408, detail="Authorization timeout (5 minutes)")
|
|
|
- except requests.RequestException as e:
|
|
|
+ except httpx.HTTPError as e:
|
|
|
sess["status"] = "error"
|
|
|
sess["error"] = str(e)
|
|
|
raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
|
|
|
@@ -629,13 +650,14 @@ def auth_claim(auth_id: str):
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
@app.post("/v2/accounts")
|
|
|
-def create_account(body: AccountCreate):
|
|
|
+async def create_account(body: AccountCreate):
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
acc_id = str(uuid.uuid4())
|
|
|
other_str = json.dumps(body.other, ensure_ascii=False) if body.other is not None else None
|
|
|
enabled_val = 1 if (body.enabled is None or body.enabled) else 0
|
|
|
- with _conn() as conn:
|
|
|
- conn.execute(
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ await conn.execute(
|
|
|
"""
|
|
|
INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
@@ -655,31 +677,34 @@ def create_account(body: AccountCreate):
|
|
|
enabled_val,
|
|
|
),
|
|
|
)
|
|
|
- conn.commit()
|
|
|
- row = conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)).fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ await conn.commit()
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
@app.get("/v2/accounts")
|
|
|
-def list_accounts():
|
|
|
- with _conn() as conn:
|
|
|
- rows = conn.execute("SELECT * FROM accounts ORDER BY created_at DESC").fetchall()
|
|
|
- return [_row_to_dict(r) for r in rows]
|
|
|
+async def list_accounts():
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ async with conn.execute("SELECT * FROM accounts ORDER BY created_at DESC") as cursor:
|
|
|
+ rows = await cursor.fetchall()
|
|
|
+ return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
@app.get("/v2/accounts/{account_id}")
|
|
|
-def get_account_detail(account_id: str):
|
|
|
- return get_account(account_id)
|
|
|
+async def get_account_detail(account_id: str):
|
|
|
+ return await get_account(account_id)
|
|
|
|
|
|
@app.delete("/v2/accounts/{account_id}")
|
|
|
-def delete_account(account_id: str):
|
|
|
- with _conn() as conn:
|
|
|
- cur = conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
|
|
|
- conn.commit()
|
|
|
+async def delete_account(account_id: str):
|
|
|
+ async with _conn() as conn:
|
|
|
+ cur = await conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
|
|
|
+ await conn.commit()
|
|
|
if cur.rowcount == 0:
|
|
|
raise HTTPException(status_code=404, detail="Account not found")
|
|
|
return {"deleted": account_id}
|
|
|
|
|
|
@app.patch("/v2/accounts/{account_id}")
|
|
|
-def update_account(account_id: str, body: AccountUpdate):
|
|
|
+async def update_account(account_id: str, body: AccountUpdate):
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
fields = []
|
|
|
values: List[Any] = []
|
|
|
@@ -700,22 +725,24 @@ def update_account(account_id: str, body: AccountUpdate):
|
|
|
fields.append("enabled=?"); values.append(1 if body.enabled else 0)
|
|
|
|
|
|
if not fields:
|
|
|
- return get_account(account_id)
|
|
|
+ return await get_account(account_id)
|
|
|
|
|
|
fields.append("updated_at=?"); values.append(now)
|
|
|
values.append(account_id)
|
|
|
|
|
|
- with _conn() as conn:
|
|
|
- cur = conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
|
|
|
- conn.commit()
|
|
|
+ async with _conn() as conn:
|
|
|
+ conn.row_factory = aiosqlite.Row
|
|
|
+ cur = await conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
|
|
|
+ await conn.commit()
|
|
|
if cur.rowcount == 0:
|
|
|
raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- row = conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)).fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
+ row = await cursor.fetchone()
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
@app.post("/v2/accounts/{account_id}/refresh")
|
|
|
-def manual_refresh(account_id: str):
|
|
|
- return refresh_access_token_in_db(account_id)
|
|
|
+async def manual_refresh(account_id: str):
|
|
|
+ return await refresh_access_token_in_db(account_id)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
|
|
|
@@ -735,5 +762,15 @@ def index():
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
@app.get("/healthz")
|
|
|
-def health():
|
|
|
- return {"status": "ok"}
|
|
|
+async def health():
|
|
|
+ return {"status": "ok"}
|
|
|
+
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+# Startup / Shutdown Events
|
|
|
+# ------------------------------------------------------------------------------
|
|
|
+
|
|
|
[email protected]_event("startup")
|
|
|
+async def startup_event():
|
|
|
+ """Initialize database and start background tasks on startup."""
|
|
|
+ await _ensure_db()
|
|
|
+ asyncio.create_task(_refresh_stale_tokens())
|