app.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399
  1. import os
  2. import json
  3. import traceback
  4. import uuid
  5. import time
  6. import asyncio
  7. import importlib.util
  8. import random
  9. import secrets
  10. from datetime import datetime, timedelta
  11. from pathlib import Path
  12. from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
  13. from fastapi import FastAPI, Depends, HTTPException, Header
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
  16. from pydantic import BaseModel
  17. from dotenv import load_dotenv
  18. import httpx
  19. import tiktoken
  20. from db import init_db, close_db, row_to_dict
  21. # ------------------------------------------------------------------------------
  22. # Tokenizer
  23. # ------------------------------------------------------------------------------
  # cl100k_base is the encoding used by gpt-4, gpt-3.5-turbo and
  # text-embedding-ada-002.  Loading is best-effort: on any failure ENCODING
  # is left as None so the module can still be imported.
  try:
      ENCODING = tiktoken.get_encoding("cl100k_base")
  except Exception:
      ENCODING = None
  def count_tokens(text: str, apply_multiplier: bool = False) -> int:
      """Return the number of tiktoken tokens in *text*.

      Returns 0 when *text* is empty or when no encoder is available.  When
      *apply_multiplier* is true, the raw count is scaled by
      ``TOKEN_COUNT_MULTIPLIER`` and truncated to an int.
      """
      if not (text and ENCODING):
          return 0
      n_tokens = len(ENCODING.encode(text))
      return int(n_tokens * TOKEN_COUNT_MULTIPLIER) if apply_multiplier else n_tokens
  37. # ------------------------------------------------------------------------------
  38. # Bootstrap
  39. # ------------------------------------------------------------------------------
  # Anchor paths to this file's directory so the server behaves the same
  # regardless of the current working directory, then load the sibling .env.
  BASE_DIR = Path(__file__).resolve().parent
  load_dotenv(BASE_DIR / ".env")

  app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")

  # Wildcard CORS (any origin, method and header) for simple testing in a browser.
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
  )
  50. # ------------------------------------------------------------------------------
  51. # Dynamic import of replicate.py to avoid package __init__ needs
  52. # ------------------------------------------------------------------------------
  def _load_replicate_module():
      """Load ``replicate.py`` from BASE_DIR as the module ``v2_replicate``.

      The file is loaded by path, so its directory does not have to be a
      package.

      Returns:
          The executed module object.

      Raises:
          ImportError: if no import spec or loader can be created for the file.
      """
      mod_path = BASE_DIR / "replicate.py"
      spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
      # Validate *before* the spec is used: module_from_spec(None) would fail
      # with an opaque AttributeError, and an assert is stripped under -O.
      if spec is None or spec.loader is None:
          raise ImportError(f"Cannot load module from {mod_path}")
      module = importlib.util.module_from_spec(spec)
      spec.loader.exec_module(module)
      return module


  _replicate = _load_replicate_module()
  send_chat_request = _replicate.send_chat_request
  62. # ------------------------------------------------------------------------------
  63. # Dynamic import of Claude modules
  64. # ------------------------------------------------------------------------------
  def _exec_module_from_file(module_name: str, file_path: Path):
      """Load and execute the Python file at *file_path* as *module_name*.

      Raises:
          ImportError: if no import spec or loader can be created for the file.
      """
      spec = importlib.util.spec_from_file_location(module_name, str(file_path))
      if spec is None or spec.loader is None:
          raise ImportError(f"Cannot load module {module_name} from {file_path}")
      module = importlib.util.module_from_spec(spec)
      spec.loader.exec_module(module)
      return module


  def _load_claude_modules():
      """Load the Claude helper modules that live next to this file.

      Loads ``claude_types.py``, ``claude_converter.py`` and
      ``claude_stream.py`` by path, in that order.

      Returns:
          Tuple ``(claude_types, claude_converter, claude_stream)`` of modules.

      Raises:
          ImportError: if a spec/loader cannot be created for one of the files.
          Any exception raised while executing one of the modules propagates.
      """
      import sys

      mod_types = _exec_module_from_file("v2_claude_types", BASE_DIR / "claude_types.py")

      # The converter is executed after claude_types has been registered under
      # the dotted name "v2.claude_types", so that name resolves in sys.modules.
      # NOTE(review): the original comments say claude_converter.py uses
      # `from .claude_types import ...`; whether this alias alone satisfies a
      # relative import depends on that file -- confirm.
      sys.modules["v2.claude_types"] = mod_types
      mod_conv = _exec_module_from_file("v2_claude_converter", BASE_DIR / "claude_converter.py")

      mod_stream = _exec_module_from_file("v2_claude_stream", BASE_DIR / "claude_stream.py")
      return mod_types, mod_conv, mod_stream
  # Bind the Claude helpers used by the /v1/messages endpoint.  If anything in
  # the loading chain fails, report it and install placeholders so importing
  # this module still succeeds.
  try:
      _claude_types, _claude_converter, _claude_stream = _load_claude_modules()
      ClaudeRequest = _claude_types.ClaudeRequest
      convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
      ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
  except Exception as load_error:
      print(f"Failed to load Claude modules: {load_error}")
      traceback.print_exc()

      class ClaudeRequest(BaseModel):
          # Dummy stand-in that avoids a NameError at startup when loading fails.
          pass

      convert_claude_to_amazonq_request = None
      ClaudeStreamHandler = None
  109. # ------------------------------------------------------------------------------
  110. # Global HTTP Client
  111. # ------------------------------------------------------------------------------
  112. GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
  def _get_proxies() -> Optional[Dict[str, str]]:
      """Build a proxy mapping from the ``HTTP_PROXY`` environment variable.

      Returns a dict that maps both ``"http"`` and ``"https"`` to the stripped
      proxy URL, or None when the variable is unset or blank.
      """
      proxy_url = os.getenv("HTTP_PROXY", "").strip()
      return {"http": proxy_url, "https": proxy_url} if proxy_url else None
  async def _init_global_client():
      """Create the shared ``httpx.AsyncClient`` and store it in GLOBAL_CLIENT.

      When a proxy is configured (see ``_get_proxies``) both the http:// and
      https:// schemes are routed through it.
      """
      global GLOBAL_CLIENT

      proxy_map = _get_proxies() or {}
      proxy_url = proxy_map.get("https") or proxy_map.get("http")
      mounts = None
      if proxy_url:
          mounts = {
              scheme: httpx.AsyncHTTPTransport(proxy=proxy_url)
              for scheme in ("https://", "http://")
          }

      # Connection-pool limits for concurrent streaming requests.
      # NOTE(review): an earlier comment spoke of raising max_connections to
      # 500, but the configured value is 60.
      limits = httpx.Limits(
          max_keepalive_connections=60,  # connections kept alive in the pool
          max_connections=60,            # upper bound on total connections
          keepalive_expiry=30.0,         # idle connections are released after 30 s
      )

      # Streaming responses need a long read timeout; the other phases fail fast.
      timeout = httpx.Timeout(
          connect=1.0,   # connection establishment
          read=300.0,    # reading the (possibly streamed) response
          write=1.0,     # sending the request
          pool=1.0,      # waiting for a free connection from the pool
      )

      GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
  async def _close_global_client():
      """Close the shared HTTP client, if there is one, and reset GLOBAL_CLIENT to None."""
      global GLOBAL_CLIENT
      if not GLOBAL_CLIENT:
          return
      await GLOBAL_CLIENT.aclose()
      GLOBAL_CLIENT = None
  151. # ------------------------------------------------------------------------------
  152. # Database helpers
  153. # ------------------------------------------------------------------------------
  154. # Database backend instance (initialized on startup)
  # Database backend handle; stays None until _ensure_db() has run.
  _db = None


  async def _ensure_db():
      """Initialise the database backend and publish it as the module-level ``_db``."""
      global _db
      backend = await init_db()
      _db = backend
  def _row_to_dict(r: Dict[str, Any]) -> Dict[str, Any]:
      """Return the database row *r* converted by ``db.row_to_dict``."""
      converted = row_to_dict(r)
      return converted
  163. # _ensure_db() will be called in startup event
  164. # ------------------------------------------------------------------------------
  165. # Background token refresh thread
  166. # ------------------------------------------------------------------------------
  def _is_token_refresh_due(last_refresh: Optional[str], now: float, max_age: float = 1500.0) -> bool:
      """Return True when a token last refreshed at *last_refresh* is stale.

      *last_refresh* is the stored ``last_refresh_time`` string in
      ``%Y-%m-%dT%H:%M:%S`` format.  This file writes that value with
      ``time.gmtime()``, i.e. in UTC, so it is converted back with
      ``calendar.timegm`` (``time.mktime`` would treat it as local time and
      skew the age by the host's UTC offset).

      A missing, ``"never"`` or unparsable value always counts as due.
      """
      import calendar

      if not last_refresh or last_refresh == "never":
          return True
      try:
          parsed = time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S")
          last_time = calendar.timegm(parsed)
      except (TypeError, ValueError, OverflowError):
          # Malformed or unparsable timestamp; force a refresh.
          return True
      return now - last_time > max_age


  async def _refresh_stale_tokens():
      """Background loop that refreshes stale access tokens.

      Every 5 minutes the enabled accounts are read (restricted to the lazy
      pool plus its refresh offset when LAZY_ACCOUNT_POOL_ENABLED is set) and
      every account whose last refresh is older than 25 minutes, or unknown,
      is refreshed.  Failures are printed and never stop the loop.
      """
      while True:
          try:
              await asyncio.sleep(300)  # 5 minutes
              if _db is None:
                  print("[Error] Database not initialized, skipping token refresh cycle.")
                  continue

              now = time.time()
              if LAZY_ACCOUNT_POOL_ENABLED:
                  limit = LAZY_ACCOUNT_POOL_SIZE + LAZY_ACCOUNT_POOL_REFRESH_OFFSET
                  order_direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
                  query = f"SELECT id, last_refresh_time FROM accounts WHERE enabled=1 ORDER BY {LAZY_ACCOUNT_POOL_ORDER_BY} {order_direction} LIMIT {limit}"
                  rows = await _db.fetchall(query)
              else:
                  rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")

              for row in rows:
                  acc_id, last_refresh = row['id'], row['last_refresh_time']
                  if not _is_token_refresh_due(last_refresh, now):
                      continue
                  try:
                      await refresh_access_token_in_db(acc_id)
                  except Exception:
                      # Per-account failures are logged and skipped; the
                      # refresh routine records timestamp/status itself.
                      traceback.print_exc()
          except Exception:
              # Keep the background loop alive whatever a cycle raises.
              traceback.print_exc()
  205. # ------------------------------------------------------------------------------
  206. # Env and API Key authorization (keys are independent of AWS accounts)
  207. # ------------------------------------------------------------------------------
  def _parse_allowed_keys_env() -> List[str]:
      """Parse the ``OPENAI_KEYS`` whitelist from the environment.

      OPENAI_KEYS is a comma-separated list of API keys used for authorization
      only, e.g. ``OPENAI_KEYS="key1,key2,key3"``.

      - When the list is non-empty, an incoming ``Authorization: Bearer {key}``
        must be one of them.
      - When empty or unset, authorization is effectively disabled (dev mode).

      Entries are stripped of surrounding whitespace; blank entries are dropped.
      """
      raw = os.getenv("OPENAI_KEYS", "") or ""
      return [part.strip() for part in raw.split(",") if part.strip()]


  ALLOWED_API_KEYS: List[str] = _parse_allowed_keys_env()
  # Error count at which an account gets disabled, and the factor applied to
  # token counts when a multiplier is requested.
  MAX_ERROR_COUNT: int = int(os.getenv("MAX_ERROR_COUNT", "100"))
  TOKEN_COUNT_MULTIPLIER: float = float(os.getenv("TOKEN_COUNT_MULTIPLIER", "1.0"))

  # Lazy account pool settings.  Boolean flags accept "true", "1" or "yes"
  # (case-insensitive); anything else means False.
  LAZY_ACCOUNT_POOL_ENABLED: bool = (
      os.getenv("LAZY_ACCOUNT_POOL_ENABLED", "false").lower() in ("true", "1", "yes")
  )
  LAZY_ACCOUNT_POOL_SIZE: int = int(os.getenv("LAZY_ACCOUNT_POOL_SIZE", "20"))
  LAZY_ACCOUNT_POOL_REFRESH_OFFSET: int = int(os.getenv("LAZY_ACCOUNT_POOL_REFRESH_OFFSET", "10"))
  LAZY_ACCOUNT_POOL_ORDER_BY: str = os.getenv("LAZY_ACCOUNT_POOL_ORDER_BY", "created_at")
  LAZY_ACCOUNT_POOL_ORDER_DESC: bool = (
      os.getenv("LAZY_ACCOUNT_POOL_ORDER_DESC", "false").lower() in ("true", "1", "yes")
  )

  # The ORDER BY column is interpolated into SQL text, so only known column
  # names are accepted; anything else falls back to "created_at".
  if LAZY_ACCOUNT_POOL_ORDER_BY not in ["created_at", "id", "success_count"]:
      LAZY_ACCOUNT_POOL_ORDER_BY = "created_at"
  def _is_console_enabled() -> bool:
      """Report whether the admin console is enabled.

      Reads ``ENABLE_CONSOLE`` (default ``"true"``); the console is disabled
      only when the stripped, lower-cased value is one of ``false``, ``0``,
      ``no`` or ``disabled``.
      """
      setting = os.getenv("ENABLE_CONSOLE", "true").strip().lower()
      disabled = setting in ("false", "0", "no", "disabled")
      return not disabled


  CONSOLE_ENABLED: bool = _is_console_enabled()

  # Admin authentication configuration.
  # NOTE(review): the fallback password is the literal "admin"; set
  # ADMIN_PASSWORD explicitly in any real deployment.
  ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD", "admin")
  def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
      """Extract the credential from an Authorization header value.

      Returns None for a missing/empty header.  A leading ``"Bearer "`` prefix
      (case-sensitive) is removed; otherwise the whole value is used.  The
      result is stripped of surrounding whitespace.
      """
      if not token_header:
          return None
      prefix = "Bearer "
      if token_header.startswith(prefix):
          credential = token_header[len(prefix):]
      else:
          credential = token_header
      return credential.strip()
  async def _list_enabled_accounts(limit: Optional[int] = None) -> List[Dict[str, Any]]:
      """Return enabled accounts as dicts, optionally capped at *limit* rows.

      With the lazy account pool enabled the rows are ordered by the
      configured column and direction; otherwise by ``created_at`` descending.
      A falsy *limit* (None or 0) means no LIMIT clause.
      """
      if LAZY_ACCOUNT_POOL_ENABLED:
          direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
          order_clause = f"{LAZY_ACCOUNT_POOL_ORDER_BY} {direction}"
      else:
          order_clause = "created_at DESC"

      query = f"SELECT * FROM accounts WHERE enabled=1 ORDER BY {order_clause}"
      if limit:
          query += f" LIMIT {limit}"

      rows = await _db.fetchall(query)
      return [_row_to_dict(r) for r in rows]
  async def _list_disabled_accounts() -> List[Dict[str, Any]]:
      """Return every disabled account as a dict, newest ``created_at`` first."""
      disabled_rows = await _db.fetchall(
          "SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC"
      )
      accounts = []
      for record in disabled_rows:
          accounts.append(_row_to_dict(record))
      return accounts
  async def verify_account(account: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
      """Verify that an account is usable.

      Refreshes the account's access token, sends a minimal "hello" chat
      request upstream and waits for the first streamed event.

      Args:
          account: account dict; only ``account['id']`` is read.

      Returns:
          ``(True, None)`` when the probe succeeds;
          ``(False, "AccessDenied")`` when the error text contains
          "AccessDenied" or "403"; ``(False, None)`` for any other failure.
      """
      try:
          account = await refresh_access_token_in_db(account['id'])
          test_request = {
              "conversationState": {
                  "currentMessage": {"userInputMessage": {"content": "hello"}},
                  "chatTriggerType": "MANUAL"
              }
          }
          _, _, _, event_gen = await send_chat_request(
              access_token=account['accessToken'],
              messages=[],
              stream=True,
              raw_payload=test_request
          )
          if event_gen:
              try:
                  # One event is enough to prove the account works.
                  async for _ in event_gen:
                      break
              finally:
                  # Breaking out of `async for` does not close the stream; close
                  # it explicitly so it is not left open until garbage collection.
                  aclose = getattr(event_gen, "aclose", None)
                  if aclose is not None:
                      try:
                          await aclose()
                      except Exception:
                          # Best-effort cleanup must not change the verdict.
                          traceback.print_exc()
          return True, None
      except Exception as e:
          if "AccessDenied" in str(e) or "403" in str(e):
              return False, "AccessDenied"
          return False, None
  async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
      """Authorize the caller, then pick an AWS account for the request.

      Authorization: when OPENAI_KEYS is configured, *bearer_key* must be one
      of the allowed keys.  The key does NOT map to any particular account.

      Selection: a random choice among the enabled accounts (restricted to the
      first LAZY_ACCOUNT_POOL_SIZE accounts when the lazy pool is enabled).

      Raises:
          HTTPException(401): invalid/missing key, or no enabled account.
      """
      key_rejected = not bearer_key or bearer_key not in ALLOWED_API_KEYS
      if ALLOWED_API_KEYS and key_rejected:
          raise HTTPException(status_code=401, detail="Invalid or missing API key")

      pool_limit = LAZY_ACCOUNT_POOL_SIZE if LAZY_ACCOUNT_POOL_ENABLED else None
      candidates = await _list_enabled_accounts(limit=pool_limit)
      if not candidates:
          raise HTTPException(status_code=401, detail="No enabled account available")
      return random.choice(candidates)
  302. # ------------------------------------------------------------------------------
  303. # Pydantic Schemas
  304. # ------------------------------------------------------------------------------
  class AccountCreate(BaseModel):
      # Fields describing an account to create.  clientId and clientSecret are
      # mandatory; every other field is optional.
      label: Optional[str] = None
      clientId: str
      clientSecret: str
      refreshToken: Optional[str] = None
      accessToken: Optional[str] = None
      other: Optional[Dict[str, Any]] = None  # free-form extra data
      enabled: Optional[bool] = True  # new accounts are enabled by default
  class BatchAccountCreate(BaseModel):
      # Wrapper carrying a list of AccountCreate entries.
      accounts: List[AccountCreate]
  class AccountUpdate(BaseModel):
      # Same fields as AccountCreate, but every one is optional and defaults
      # to None.
      label: Optional[str] = None
      clientId: Optional[str] = None
      clientSecret: Optional[str] = None
      refreshToken: Optional[str] = None
      accessToken: Optional[str] = None
      other: Optional[Dict[str, Any]] = None
      enabled: Optional[bool] = None
  class ChatMessage(BaseModel):
      # One chat message: a role string plus content of any type.
      role: str
      content: Any
  class ChatCompletionRequest(BaseModel):
      # Chat completion request body: optional model name, the message list,
      # and a stream flag that defaults to False.
      model: Optional[str] = None
      messages: List[ChatMessage]
      stream: Optional[bool] = False
  330. # ------------------------------------------------------------------------------
  331. # Token refresh (OIDC)
  332. # ------------------------------------------------------------------------------
  # OIDC endpoint used for token refresh.
  OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  TOKEN_URL = f"{OIDC_BASE}/token"


  def _oidc_headers() -> Dict[str, str]:
      """Return the HTTP headers sent with an OIDC token request.

      All values are fixed except ``amz-sdk-invocation-id``, which is a fresh
      random UUID on every call.
      """
      headers = {
          "content-type": "application/json",
          "user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
          "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
          "amz-sdk-request": "attempt=1; max=3",
      }
      headers["amz-sdk-invocation-id"] = str(uuid.uuid4())
      return headers
  async def _record_refresh_failure(account_id: str) -> None:
      """Mark the account's latest refresh as failed and count it as an error."""
      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      await _db.execute(
          """
          UPDATE accounts
          SET last_refresh_time=?, last_refresh_status=?, updated_at=?
          WHERE id=?
          """,
          (now, "failed", now, account_id),
      )
      # Record the failed refresh in the account's error statistics.
      await _update_stats(account_id, False)


  async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
      """Refresh an account's access token via OIDC and persist the result.

      Args:
          account_id: primary key of the account row.

      Returns:
          The account row, re-read after the update, as a dict.

      Raises:
          HTTPException(404): the account does not exist.
          HTTPException(400): clientId, clientSecret or refreshToken is missing.
          HTTPException(502): the token request failed at HTTP level, or the
              response carried no ``accessToken``.
          Exception: any other error from the request/parsing is re-raised.

      Every failure after the request is attempted is recorded on the account
      (``last_refresh_status="failed"`` plus an error-count increment).
      """
      row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      if not row:
          raise HTTPException(status_code=404, detail="Account not found")
      acc = _row_to_dict(row)

      if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
          raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")

      payload = {
          "grantType": "refresh_token",
          "clientId": acc["clientId"],
          "clientSecret": acc["clientSecret"],
          "refreshToken": acc["refreshToken"],
      }

      try:
          # Prefer the shared client; fall back to a temporary one if the
          # global client has not been initialised.
          client = GLOBAL_CLIENT
          if client:
              r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
          else:
              async with httpx.AsyncClient(timeout=60.0) as temp_client:
                  r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
          r.raise_for_status()
          data = r.json()
          new_access = data.get("accessToken")
          # Keep the current refresh token when the response omits a new one.
          new_refresh = data.get("refreshToken", acc.get("refreshToken"))
      except httpx.HTTPError as e:
          await _record_refresh_failure(account_id)
          raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}") from e
      except Exception:
          # Ensure last_refresh_time is recorded even on unexpected errors.
          await _record_refresh_failure(account_id)
          raise

      if not new_access:
          # A 2xx response without a token must not overwrite the stored token
          # with NULL, nor be recorded as a successful refresh.
          await _record_refresh_failure(account_id)
          raise HTTPException(status_code=502, detail="Token refresh failed: response did not contain an accessToken")

      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      await _db.execute(
          """
          UPDATE accounts
          SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
          WHERE id=?
          """,
          (new_access, new_refresh, now, "success", now, account_id),
      )

      row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      return _row_to_dict(row2)
  async def get_account(account_id: str) -> Dict[str, Any]:
      """Fetch one account by id and return it as a dict.

      Raises:
          HTTPException(404): no account has that id.
      """
      record = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      if record:
          return _row_to_dict(record)
      raise HTTPException(status_code=404, detail="Account not found")
  async def _update_stats(account_id: str, success: bool) -> None:
      """Update an account's success/error counters.

      On success the success counter is incremented and the error counter is
      reset to 0.  On failure the error counter is incremented and, once it
      reaches MAX_ERROR_COUNT, the account is disabled as well.  A failure for
      an unknown account id is a no-op.
      """
      stamp_format = "%Y-%m-%dT%H:%M:%S"

      if success:
          await _db.execute(
              "UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
              (time.strftime(stamp_format, time.gmtime()), account_id),
          )
          return

      row = await _db.fetchone("SELECT error_count FROM accounts WHERE id=?", (account_id,))
      if not row:
          return

      new_count = (row['error_count'] or 0) + 1
      if new_count >= MAX_ERROR_COUNT:
          statement = "UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?"
      else:
          statement = "UPDATE accounts SET error_count=?, updated_at=? WHERE id=?"
      await _db.execute(
          statement,
          (new_count, time.strftime(stamp_format, time.gmtime()), account_id),
      )
  431. # ------------------------------------------------------------------------------
  432. # Dependencies
  433. # ------------------------------------------------------------------------------
  async def require_account(
      authorization: Optional[str] = Header(default=None),
      x_api_key: Optional[str] = Header(default=None)
  ) -> Dict[str, Any]:
      """Dependency that authorizes the request and yields the account to use.

      The API key is taken from the Authorization header when one is present,
      otherwise from the x-api-key header, and is then passed to
      ``resolve_account_for_key``.
      """
      if authorization:
          key = _extract_bearer(authorization)
      else:
          key = x_api_key
      return await resolve_account_for_key(key)
  def verify_admin_password(authorization: Optional[str] = Header(None)) -> bool:
      """Verify the admin password for console access.

      Expects ``Authorization: Bearer <password>``.

      Returns:
          True when the supplied password matches ADMIN_PASSWORD.

      Raises:
          HTTPException(401): the header is missing or not a Bearer header
              (code ``UNAUTHORIZED``), or the password is wrong (code
              ``INVALID_PASSWORD``).
      """
      if not authorization or not authorization.startswith("Bearer "):
          raise HTTPException(
              status_code=401,
              detail={"error": "Unauthorized access", "code": "UNAUTHORIZED"}
          )

      password = authorization[7:]  # Remove "Bearer " prefix

      # Constant-time comparison so response timing does not leak how much of
      # the password matched.  Compared as bytes because compare_digest only
      # accepts ASCII-only str values.
      password_ok = secrets.compare_digest(
          password.encode("utf-8"), ADMIN_PASSWORD.encode("utf-8")
      )
      if not password_ok:
          raise HTTPException(
              status_code=401,
              detail={"error": "Invalid password", "code": "INVALID_PASSWORD"}
          )

      return True
  454. # ------------------------------------------------------------------------------
  455. # OpenAI-compatible Chat endpoint
  456. # ------------------------------------------------------------------------------
  def _openai_non_streaming_response(
      text: str,
      model: Optional[str],
      prompt_tokens: int = 0,
      completion_tokens: int = 0,
  ) -> Dict[str, Any]:
      """Build an OpenAI-style ``chat.completion`` response body.

      The result holds a single assistant choice containing *text* with
      ``finish_reason="stop"``.  ``model`` falls back to ``"unknown"`` when
      falsy, and ``total_tokens`` is the sum of the two token counts.
      """
      assistant_message = {
          "role": "assistant",
          "content": text,
      }
      only_choice = {
          "index": 0,
          "message": assistant_message,
          "finish_reason": "stop",
      }
      usage = {
          "prompt_tokens": prompt_tokens,
          "completion_tokens": completion_tokens,
          "total_tokens": prompt_tokens + completion_tokens,
      }
      created = int(time.time())
      return {
          "id": f"chatcmpl-{uuid.uuid4()}",
          "object": "chat.completion",
          "created": created,
          "model": model or "unknown",
          "choices": [only_choice],
          "usage": usage,
      }
  def _sse_format(obj: Dict[str, Any]) -> str:
      """Serialize *obj* as one server-sent-event ``data:`` frame.

      Non-ASCII characters are emitted as-is (not escaped); the frame ends
      with a blank line.
      """
      encoded = json.dumps(obj, ensure_ascii=False)
      return "data: " + encoded + "\n\n"
  487. @app.post("/v1/messages")
  488. async def claude_messages(req: ClaudeRequest, account: Dict[str, Any] = Depends(require_account)):
  489. """
  490. Claude-compatible messages endpoint.
  491. """
  492. # 1. Convert request
  493. try:
  494. aq_request = convert_claude_to_amazonq_request(req)
  495. except Exception as e:
  496. traceback.print_exc()
  497. raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")
  498. # 2. Send upstream
  499. async def _send_upstream_raw() -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any, Optional[AsyncGenerator[Any, None]]]:
  500. access = account.get("accessToken")
  501. if not access:
  502. refreshed = await refresh_access_token_in_db(account["id"])
  503. access = refreshed.get("accessToken")
  504. if not access:
  505. raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
  506. # We use the modified send_chat_request which accepts raw_payload
  507. # and returns (text, text_stream, tracker, event_stream)
  508. return await send_chat_request(
  509. access_token=access,
  510. messages=[], # Not used when raw_payload is present
  511. model=req.model,
  512. stream=req.stream,
  513. client=GLOBAL_CLIENT,
  514. raw_payload=aq_request
  515. )
  516. try:
  517. _, _, tracker, event_stream = await _send_upstream_raw()
  518. if not req.stream:
  519. # Non-streaming: we need to consume the stream and build response
  520. # But wait, send_chat_request with stream=False returns text, but we need structured response
  521. # Actually, for Claude format, we might want to parse the events even for non-streaming
  522. # to get tool calls etc correctly.
  523. # However, our modified send_chat_request returns event_stream if raw_payload is used AND stream=True?
  524. # Let's check replicate.py modification.
  525. # If stream=False, it returns text. But text might not be enough for tool calls.
  526. # For simplicity, let's force stream=True internally and aggregate if req.stream is False.
  527. pass
  528. except Exception as e:
  529. await _update_stats(account["id"], False)
  530. raise
  531. # We always use streaming upstream to handle events properly
  532. try:
  533. # Force stream=True for upstream to get events
  534. # But wait, send_chat_request logic: if stream=True, returns event_stream
  535. # We need to call it with stream=True
  536. pass
  537. except:
  538. pass
  539. # Re-implementing logic to be cleaner
  540. # Always stream from upstream to get full event details
  541. event_iter = None
  542. first_event_received = False
  543. try:
  544. access = account.get("accessToken")
  545. if not access:
  546. refreshed = await refresh_access_token_in_db(account["id"])
  547. access = refreshed.get("accessToken")
  548. # We call with stream=True to get the event iterator
  549. _, _, tracker, event_iter = await send_chat_request(
  550. access_token=access,
  551. messages=[],
  552. model=req.model,
  553. stream=True,
  554. client=GLOBAL_CLIENT,
  555. raw_payload=aq_request
  556. )
  557. if not event_iter:
  558. raise HTTPException(status_code=502, detail="No event stream returned")
  559. # Handler
  560. # Calculate input tokens
  561. text_to_count = ""
  562. if req.system:
  563. if isinstance(req.system, str):
  564. text_to_count += req.system
  565. elif isinstance(req.system, list):
  566. for item in req.system:
  567. if isinstance(item, dict) and item.get("type") == "text":
  568. text_to_count += item.get("text", "")
  569. for msg in req.messages:
  570. if isinstance(msg.content, str):
  571. text_to_count += msg.content
  572. elif isinstance(msg.content, list):
  573. for item in msg.content:
  574. if isinstance(item, dict) and item.get("type") == "text":
  575. text_to_count += item.get("text", "")
  576. input_tokens = count_tokens(text_to_count, apply_multiplier=True)
  577. handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens)
  578. # Try to get the first event to ensure the connection is valid
  579. # This allows us to return proper HTTP error codes before starting the stream
  580. first_event = None
  581. try:
  582. first_event = await event_iter.__anext__()
  583. first_event_received = True
  584. except StopAsyncIteration:
  585. raise HTTPException(status_code=502, detail="Empty response from upstream")
  586. except Exception as e:
  587. # If we get an error before the first event, we can still return proper status code
  588. raise HTTPException(status_code=502, detail=f"Upstream error: {str(e)}")
  589. async def event_generator():
  590. try:
  591. # Process the first event we already fetched
  592. if first_event:
  593. event_type, payload = first_event
  594. async for sse in handler.handle_event(event_type, payload):
  595. yield sse
  596. # Process remaining events
  597. async for event_type, payload in event_iter:
  598. async for sse in handler.handle_event(event_type, payload):
  599. yield sse
  600. async for sse in handler.finish():
  601. yield sse
  602. await _update_stats(account["id"], True)
  603. except GeneratorExit:
  604. # Client disconnected - update stats but don't re-raise
  605. await _update_stats(account["id"], tracker.has_content if tracker else False)
  606. except Exception:
  607. await _update_stats(account["id"], False)
  608. raise
  609. if req.stream:
  610. return StreamingResponse(event_generator(), media_type="text/event-stream")
  611. else:
  612. # Accumulate for non-streaming
  613. # This is a bit complex because we need to reconstruct the full response object
  614. # For now, let's just support streaming as it's the main use case for Claude Code
  615. # But to be nice, let's try to support non-streaming by consuming the generator
  616. content_blocks = []
  617. usage = {"input_tokens": 0, "output_tokens": 0}
  618. stop_reason = None
  619. # We need to parse the SSE strings back to objects... inefficient but works
  620. # Or we could refactor handler to yield objects.
  621. # For now, let's just raise error for non-streaming or implement basic text
  622. # Claude Code uses streaming.
  623. # Let's implement a basic accumulator from the SSE stream
  624. final_content = []
  625. async for sse_chunk in event_generator():
  626. data_str = None
  627. # Each chunk from the generator can have multiple lines ('event:', 'data:').
  628. # We need to find the 'data:' line.
  629. for line in sse_chunk.strip().split('\n'):
  630. if line.startswith("data:"):
  631. data_str = line[6:].strip()
  632. break
  633. if not data_str or data_str == "[DONE]":
  634. continue
  635. try:
  636. data = json.loads(data_str)
  637. dtype = data.get("type")
  638. if dtype == "content_block_start":
  639. idx = data.get("index", 0)
  640. while len(final_content) <= idx:
  641. final_content.append(None)
  642. final_content[idx] = data.get("content_block")
  643. elif dtype == "content_block_delta":
  644. idx = data.get("index", 0)
  645. delta = data.get("delta", {})
  646. if final_content[idx]:
  647. if delta.get("type") == "text_delta":
  648. final_content[idx]["text"] += delta.get("text", "")
  649. elif delta.get("type") == "thinking_delta":
  650. final_content[idx].setdefault("thinking", "")
  651. final_content[idx]["thinking"] += delta.get("thinking", "")
  652. elif delta.get("type") == "input_json_delta":
  653. if "partial_json" not in final_content[idx]:
  654. final_content[idx]["partial_json"] = ""
  655. final_content[idx]["partial_json"] += delta.get("partial_json", "")
  656. elif dtype == "content_block_stop":
  657. idx = data.get("index", 0)
  658. if final_content[idx] and final_content[idx].get("type") == "tool_use":
  659. if "partial_json" in final_content[idx]:
  660. try:
  661. final_content[idx]["input"] = json.loads(final_content[idx]["partial_json"])
  662. except json.JSONDecodeError:
  663. # Keep partial if invalid
  664. final_content[idx]["input"] = {"error": "invalid json", "partial": final_content[idx]["partial_json"]}
  665. del final_content[idx]["partial_json"]
  666. elif dtype == "message_delta":
  667. usage = data.get("usage", usage)
  668. stop_reason = data.get("delta", {}).get("stop_reason")
  669. except json.JSONDecodeError:
  670. # Ignore lines that are not valid JSON
  671. pass
  672. except Exception:
  673. # Broad exception to prevent accumulator from crashing on one bad event
  674. traceback.print_exc()
  675. pass
  676. # Final assembly
  677. final_content_cleaned = []
  678. for c in final_content:
  679. if c is not None:
  680. # Remove internal state like 'partial_json' before returning
  681. c.pop("partial_json", None)
  682. final_content_cleaned.append(c)
  683. return JSONResponse(content={
  684. "id": f"msg_{uuid.uuid4()}",
  685. "type": "message",
  686. "role": "assistant",
  687. "model": req.model,
  688. "content": final_content_cleaned,
  689. "stop_reason": stop_reason,
  690. "stop_sequence": None,
  691. "usage": usage
  692. })
  693. except Exception as e:
  694. # Ensure event_iter (if created) is closed to release upstream connection
  695. try:
  696. if event_iter and hasattr(event_iter, "aclose"):
  697. await event_iter.aclose()
  698. except Exception:
  699. pass
  700. await _update_stats(account["id"], False)
  701. raise
  @app.post("/v1/messages/count_tokens")
  async def count_tokens_endpoint(req: ClaudeRequest):
      """
      Count the input tokens of a request without sending it upstream.

      Compatible with Claude API's /v1/messages/count_tokens endpoint.
      The system prompt text, the text of every message and (when present)
      the JSON dump of the tool definitions are concatenated in that order
      and handed to ``count_tokens`` with ``apply_multiplier=True``.

      Returns:
          A dict of the form ``{"input_tokens": <count>}``.
      """
      def _text_parts(content):
          # A plain string counts as-is; a list contributes only the "text"
          # of its {"type": "text"} dict items; anything else adds nothing.
          if isinstance(content, str):
              return [content]
          if isinstance(content, list):
              return [
                  part.get("text", "")
                  for part in content
                  if isinstance(part, dict) and part.get("type") == "text"
              ]
          return []

      pieces = []

      # System prompt first
      if req.system:
          pieces.extend(_text_parts(req.system))

      # Then every message, in request order
      for message in req.messages:
          pieces.extend(_text_parts(message.content))

      # Finally the tool definitions, serialised as one JSON array
      if req.tools:
          dumped_tools = [
              tool.model_dump() if hasattr(tool, 'model_dump') else tool
              for tool in req.tools
          ]
          pieces.append(json.dumps(dumped_tools, ensure_ascii=False))

      return {"input_tokens": count_tokens("".join(pieces), apply_multiplier=True)}
  @app.post("/v1/chat/completions")
  async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
      """
      OpenAI-compatible chat endpoint.

      - stream default False
      - messages are dumped to dicts and forwarded through send_chat_request
      - the account is supplied by the require_account dependency

      Non-streaming requests return a single JSON completion built by
      ``_openai_non_streaming_response``. Streaming requests return a
      ``text/event-stream`` response made of ``chat.completion.chunk``
      events: a role chunk, one chunk per upstream text piece, a final
      chunk carrying ``finish_reason="stop"`` plus usage, then ``[DONE]``.

      Account statistics are recorded through ``_update_stats`` exactly once
      per request on the non-streaming path and once per stream outcome on
      the streaming path.

      Raises:
          HTTPException(502): no access token is available after a refresh,
              or the upstream call returned no stream for a streaming request.
      """
      model = req.model
      do_stream = bool(req.stream)

      async def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any]:
          """Send the request upstream and return (text, text iterator, tracker)."""
          access = account.get("accessToken")
          if not access:
              refreshed = await refresh_access_token_in_db(account["id"])
              access = refreshed.get("accessToken")
              if not access:
                  raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
          # send_chat_request returns four values; the fourth (event_stream)
          # is not used by the OpenAI-compatible endpoint.
          result = await send_chat_request(
              access,
              [m.model_dump() for m in req.messages],
              model=model,
              stream=stream,
              client=GLOBAL_CLIENT,
          )
          return result[0], result[1], result[2]

      def _count_prompt_tokens():
          # Only messages whose content is a plain string are counted.
          prompt_text = "".join(m.content for m in req.messages if isinstance(m.content, str))
          return count_tokens(prompt_text)

      if not do_stream:
          try:
              prompt_tokens = _count_prompt_tokens()
              text, _, _tracker = await _send_upstream(stream=False)
              completion_tokens = count_tokens(text or "")
              response = JSONResponse(content=_openai_non_streaming_response(
                  text or "",
                  model,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
              ))
          except Exception:
              await _update_stats(account["id"], False)
              raise
          # Record the outcome only after everything that can fail has run, so
          # a late error cannot log both a success and a failure.
          await _update_stats(account["id"], bool(text))
          return response

      created = int(time.time())
      stream_id = f"chatcmpl-{uuid.uuid4()}"
      model_used = model or "unknown"
      it = None

      def _chunk(delta: Dict[str, Any], finish_reason: Optional[str] = None, usage: Optional[Dict[str, Any]] = None) -> str:
          """Format one chat.completion.chunk event as an SSE string."""
          body: Dict[str, Any] = {
              "id": stream_id,
              "object": "chat.completion.chunk",
              "created": created,
              "model": model_used,
              "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
          }
          if usage is not None:
              body["usage"] = usage
          return _sse_format(body)

      async def _close_upstream() -> None:
          """Best-effort close of the upstream iterator to release its connection."""
          try:
              if it and hasattr(it, "aclose"):
                  await it.aclose()
          except Exception:
              # Deliberate best-effort: a failure while closing must not mask
              # the error (or normal completion) that led here.
              pass

      try:
          prompt_tokens = _count_prompt_tokens()
          _, it, tracker = await _send_upstream(stream=True)
          if it is None:
              # Explicit check instead of `assert`, which is stripped under -O.
              raise HTTPException(status_code=502, detail="No stream returned from upstream")

          async def event_gen() -> AsyncGenerator[str, None]:
              pieces: List[str] = []
              try:
                  # Send role first
                  yield _chunk({"role": "assistant"})
                  # Stream content
                  async for piece in it:
                      if piece:
                          pieces.append(piece)
                          yield _chunk({"content": piece})
                  # Send stop and usage
                  completion_tokens = count_tokens("".join(pieces))
                  yield _chunk(
                      {},
                      finish_reason="stop",
                      usage={
                          "prompt_tokens": prompt_tokens,
                          "completion_tokens": completion_tokens,
                          "total_tokens": prompt_tokens + completion_tokens,
                      },
                  )
                  yield "data: [DONE]\n\n"
                  await _update_stats(account["id"], True)
              except GeneratorExit:
                  # Client disconnected - update stats but don't re-raise
                  await _update_stats(account["id"], tracker.has_content if tracker else False)
              except Exception:
                  await _update_stats(account["id"], tracker.has_content if tracker else False)
                  raise
              finally:
                  # Always release the upstream iterator, however the stream
                  # ended (completion, error, disconnect or cancellation).
                  await _close_upstream()

          return StreamingResponse(event_gen(), media_type="text/event-stream")
      except Exception:
          # Ensure iterator (if created) is closed to release upstream connection
          await _close_upstream()
          await _update_stats(account["id"], False)
          raise
  835. # ------------------------------------------------------------------------------
  836. # Device Authorization (URL Login, 5-minute timeout)
  837. # ------------------------------------------------------------------------------
  838. # Dynamic import of auth_flow.py (device-code login helpers)
  def _load_auth_flow_module():
      """
      Load ``auth_flow.py`` from ``BASE_DIR`` as a module named ``v2_auth_flow``.

      The file is located by path, not through the normal import system,
      and is executed before being returned.

      Returns:
          The executed module object.

      Raises:
          ImportError: if no import spec or loader can be created for the file.
      """
      mod_path = BASE_DIR / "auth_flow.py"
      spec = importlib.util.spec_from_file_location("v2_auth_flow", str(mod_path))
      # Validate before the spec is used: module_from_spec(None) would fail
      # with an unrelated error, and an `assert` disappears under `python -O`.
      if spec is None or spec.loader is None:
          raise ImportError(f"Cannot load auth flow module from {mod_path}")
      module = importlib.util.module_from_spec(spec)
      spec.loader.exec_module(module)
      return module
  # Load the device-code login helpers and expose the three entry points
  # used by the auth endpoints under short module-level names.
  _auth_flow = _load_auth_flow_module()
  register_client_min, device_authorize, poll_token_device_code = (
      _auth_flow.register_client_min,
      _auth_flow.device_authorize,
      _auth_flow.poll_token_device_code,
  )

  # In-memory auth sessions (ephemeral), keyed by auth id
  AUTH_SESSIONS: Dict[str, Dict[str, Any]] = dict()
  class AuthStartBody(BaseModel):
      """Request body for starting a device-authorization session."""

      # Optional label stored on the auth session
      label: Optional[str] = None
      # Desired enabled flag for the resulting account; auth_start treats
      # an explicit None as True
      enabled: Optional[bool] = True
  class AdminLoginRequest(BaseModel):
      """Request body of the admin login endpoint."""

      # Password supplied by the caller, checked against ADMIN_PASSWORD
      password: str
  class AdminLoginResponse(BaseModel):
      """Response body of the admin login endpoint."""

      # True when the supplied password was accepted
      success: bool
      # Human-readable outcome of the login attempt
      message: str
  async def _create_account_from_tokens(
      client_id: str,
      client_secret: str,
      access_token: str,
      refresh_token: Optional[str],
      label: Optional[str],
      enabled: bool,
  ) -> Dict[str, Any]:
      """
      Insert a new account row built from freshly obtained tokens.

      The row gets a new UUID, ``other`` set to NULL, a refresh status of
      ``"success"``, and the current UTC time as its last-refresh, created
      and updated timestamps.

      Returns:
          The stored row, re-read from the database and converted with
          ``_row_to_dict``.
      """
      timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_id = str(uuid.uuid4())
      enabled_flag = 1 if enabled else 0

      # Values in the same order as the column list of the INSERT below
      insert_params = (
          new_id,
          label,
          client_id,
          client_secret,
          refresh_token,
          access_token,
          None,        # other
          timestamp,   # last_refresh_time
          "success",   # last_refresh_status
          timestamp,   # created_at
          timestamp,   # updated_at
          enabled_flag,
      )
      await _db.execute(
          """
          INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
          VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
          """,
          insert_params,
      )

      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (new_id,))
      return _row_to_dict(stored)
  892. # 管理控制台相关端点 - 仅在启用时注册
  893. if CONSOLE_ENABLED:
  894. # ------------------------------------------------------------------------------
  895. # Admin Authentication Endpoints
  896. # ------------------------------------------------------------------------------
  @app.post("/api/login", response_model=AdminLoginResponse)
  async def admin_login(request: AdminLoginRequest) -> AdminLoginResponse:
      """
      Admin login endpoint - password only.

      Compares the submitted password with ``ADMIN_PASSWORD`` and reports the
      outcome in the response body; a wrong password yields ``success=False``
      and not an HTTP error.

      Returns:
          AdminLoginResponse with ``success`` and a human-readable ``message``.
      """
      # Local import: only this handler needs hmac.
      import hmac

      expected = ADMIN_PASSWORD
      # Constant-time comparison so response timing does not reveal how much
      # of a guess matched. Both sides are compared as bytes because
      # compare_digest only accepts ASCII-only str; "surrogatepass" keeps
      # unusual JSON-decoded input from raising during encoding. A non-string
      # ADMIN_PASSWORD never matches, as with the previous `==` check.
      password_ok = isinstance(expected, str) and hmac.compare_digest(
          request.password.encode("utf-8", "surrogatepass"),
          expected.encode("utf-8", "surrogatepass"),
      )
      if password_ok:
          return AdminLoginResponse(success=True, message="Login successful")
      return AdminLoginResponse(success=False, message="Invalid password")
  @app.get("/login", response_class=FileResponse)
  def login_page():
      """
      Serve the login page from ``frontend/login.html`` under ``BASE_DIR``.

      Raises:
          HTTPException(404): the HTML file does not exist.
      """
      html_file = BASE_DIR / "frontend" / "login.html"
      if html_file.exists():
          return FileResponse(str(html_file))
      raise HTTPException(status_code=404, detail="frontend/login.html not found")
  917. # ------------------------------------------------------------------------------
  918. # Device Authorization Endpoints
  919. # ------------------------------------------------------------------------------
  @app.post("/v2/auth/start")
  async def auth_start(body: AuthStartBody, _: bool = Depends(verify_admin_password)):
      """
      Start device authorization and return the verification URL for user login.

      Registers a client, requests a device authorization, and stores the
      resulting session in ``AUTH_SESSIONS`` under a new UUID with status
      ``"pending"``. Session lifetime is capped at 5 minutes on claim.

      Returns:
          A dict with ``authId``, ``verificationUriComplete``, ``userCode``,
          ``expiresIn`` and ``interval``.

      Raises:
          HTTPException(502): registration or device authorization raised
              an ``httpx.HTTPError``.
      """
      try:
          client_id, client_secret = await register_client_min()
          device = await device_authorize(client_id, client_secret)
      except httpx.HTTPError as exc:
          raise HTTPException(status_code=502, detail=f"OIDC error: {str(exc)}")

      session_id = str(uuid.uuid4())
      session = {
          "clientId": client_id,
          "clientSecret": client_secret,
          "deviceCode": device.get("deviceCode"),
          # Defaults apply when the device response omits these fields
          "interval": int(device.get("interval", 1)),
          "expiresIn": int(device.get("expiresIn", 600)),
          "verificationUriComplete": device.get("verificationUriComplete"),
          "userCode": device.get("userCode"),
          "startTime": int(time.time()),
          "label": body.label,
          # An explicit None from the client means "enabled"
          "enabled": True if body.enabled is None else bool(body.enabled),
          "status": "pending",
          "error": None,
          "accountId": None,
      }
      AUTH_SESSIONS[session_id] = session

      # Only the fields the user needs to complete the login are returned
      exposed_keys = ("verificationUriComplete", "userCode", "expiresIn", "interval")
      reply = {"authId": session_id}
      reply.update({key: session[key] for key in exposed_keys})
      return reply
  @app.get("/v2/auth/status/{auth_id}")
  async def auth_status(auth_id: str, _: bool = Depends(verify_admin_password)):
      """
      Report the current state of a device-authorization session.

      ``remaining`` is the number of whole seconds left until the session
      deadline: its start time plus the smaller of ``expiresIn`` and 300
      seconds, never below zero.

      Raises:
          HTTPException(404): no session exists for ``auth_id``.
      """
      session = AUTH_SESSIONS.get(auth_id)
      if not session:
          raise HTTPException(status_code=404, detail="Auth session not found")

      current_ts = int(time.time())
      lifetime = min(int(session.get("expiresIn", 600)), 300)
      seconds_left = max(0, session["startTime"] + lifetime - current_ts)

      report = {key: session.get(key) for key in ("status", "error", "accountId")}
      return {
          "status": report["status"],
          "remaining": seconds_left,
          "error": report["error"],
          "accountId": report["accountId"],
      }
  @app.post("/v2/auth/claim/{auth_id}")
  async def auth_claim(auth_id: str, _: bool = Depends(verify_admin_password)):
      """
      Block up to 5 minutes to exchange the device code for tokens after the
      user completed login.

      On success, creates an account (enabled according to the session's
      ``enabled`` flag) and returns it. A session that already reached a
      terminal state (``completed``, ``timeout`` or ``error``) is not polled
      again; its recorded status, account id and error are returned instead.

      Raises:
          HTTPException(404): no session exists for ``auth_id``.
          HTTPException(408): polling timed out.
          HTTPException(502): polling raised an ``httpx.HTTPError`` or the
              token response carried no ``accessToken``.
      """
      sess = AUTH_SESSIONS.get(auth_id)
      if not sess:
          raise HTTPException(status_code=404, detail="Auth session not found")
      if sess.get("status") in ("completed", "timeout", "error"):
          return {
              "status": sess["status"],
              "accountId": sess.get("accountId"),
              "error": sess.get("error"),
          }
      try:
          toks = await poll_token_device_code(
              sess["clientId"],
              sess["clientSecret"],
              sess["deviceCode"],
              sess["interval"],
              sess["expiresIn"],
              max_timeout_sec=300,  # 5 minutes
          )
          access_token = toks.get("accessToken")
          refresh_token = toks.get("refreshToken")
          if not access_token:
              # Record the failure on the session as well; otherwise it would
              # stay "pending" and the status endpoint could never report it.
              detail = "No accessToken returned from OIDC"
              sess["status"] = "error"
              sess["error"] = detail
              raise HTTPException(status_code=502, detail=detail)
          acc = await _create_account_from_tokens(
              sess["clientId"],
              sess["clientSecret"],
              access_token,
              refresh_token,
              sess.get("label"),
              sess.get("enabled", True),
          )
          sess["status"] = "completed"
          sess["accountId"] = acc["id"]
          return {
              "status": "completed",
              "account": acc,
          }
      except TimeoutError:
          sess["status"] = "timeout"
          raise HTTPException(status_code=408, detail="Authorization timeout (5 minutes)")
      except httpx.HTTPError as e:
          sess["status"] = "error"
          sess["error"] = str(e)
          raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
  1018. # ------------------------------------------------------------------------------
  1019. # Accounts Management API
  1020. # ------------------------------------------------------------------------------
  @app.post("/v2/accounts")
  async def create_account(body: AccountCreate, _: bool = Depends(verify_admin_password)):
      """
      Create one account from the request body and return the stored row.

      The new row gets a fresh UUID, no last-refresh time, a refresh status
      of ``"never"`` and the current UTC time as created/updated timestamps.
      ``other`` is stored as JSON text when provided.
      """
      timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_id = str(uuid.uuid4())
      serialized_other = None if body.other is None else json.dumps(body.other, ensure_ascii=False)
      # A missing ``enabled`` flag counts as enabled
      enabled_flag = 0 if (body.enabled is not None and not body.enabled) else 1

      # Values in the same order as the column list of the INSERT below
      insert_params = (
          new_id,
          body.label,
          body.clientId,
          body.clientSecret,
          body.refreshToken,
          body.accessToken,
          serialized_other,
          None,       # last_refresh_time
          "never",    # last_refresh_status
          timestamp,  # created_at
          timestamp,  # updated_at
          enabled_flag,
      )
      await _db.execute(
          """
          INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
          VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
          """,
          insert_params,
      )

      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (new_id,))
      return _row_to_dict(stored)
  async def _verify_and_enable_accounts(account_ids: List[str]):
      """
      Background job: verify each account and enable the ones that pass.

      For every id the account is loaded and passed to ``verify_account``.
      A successful verification sets ``enabled=1``; a failure with a reason
      stores that reason under ``failedReason`` in the account's ``other``
      data. Any exception for one account is printed with its traceback and
      processing continues with the next id.
      """
      for account_id in account_ids:
          try:
              # The full account record must be fetched first
              account = await get_account(account_id)
              passed, reason = await verify_account(account)
              stamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())

              if passed:
                  await _db.execute(
                      "UPDATE accounts SET enabled=1, updated_at=? WHERE id=?",
                      (stamp, account_id),
                  )
                  continue

              if not reason:
                  # Failed without a reason: nothing to record
                  continue

              extras = account.get("other", {}) or {}
              extras['failedReason'] = reason
              await _db.execute(
                  "UPDATE accounts SET other=?, updated_at=? WHERE id=?",
                  (json.dumps(extras, ensure_ascii=False), stamp, account_id),
              )
          except Exception as exc:
              print(f"Error verifying account {account_id}: {exc}")
              traceback.print_exc()
  # Strong references to in-flight verification tasks. The event loop only
  # keeps weak references to tasks, so a task nobody else references may be
  # garbage-collected before it finishes.
  _FEED_VERIFY_TASKS: set = set()


  @app.post("/v2/accounts/feed")
  async def create_accounts_feed(request: BatchAccountCreate, _: bool = Depends(verify_admin_password)):
      """
      Unified feed endpoint: store a batch of accounts immediately and verify
      them asynchronously in the background.

      Every account is inserted disabled (``enabled=0``) with
      ``other["source"] = "feed"``. When at least one account was stored, a
      background task running ``_verify_and_enable_accounts`` is started for
      the new ids; the request does not wait for it.

      Returns:
          A dict with ``status`` ("processing"), a ``message`` and the list
          of created ``account_ids``.
      """
      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_account_ids = []

      for position, account_data in enumerate(request.accounts, start=1):
          acc_id = str(uuid.uuid4())
          # Work on a copy so the request model's own dict is not mutated
          other_dict = dict(account_data.other or {})
          other_dict['source'] = 'feed'
          other_str = json.dumps(other_dict, ensure_ascii=False)

          await _db.execute(
              """
              INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
              VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
              """,
              (
                  acc_id,
                  account_data.label or f"批量账号 {position}",
                  account_data.clientId,
                  account_data.clientSecret,
                  account_data.refreshToken,
                  account_data.accessToken,
                  other_str,
                  None,
                  "never",
                  now,
                  now,
                  0,  # starts disabled until verification succeeds
              ),
          )
          new_account_ids.append(acc_id)

      # Start background verification without blocking the current request
      if new_account_ids:
          task = asyncio.create_task(_verify_and_enable_accounts(new_account_ids))
          # Hold a reference until the task is done, then drop it
          _FEED_VERIFY_TASKS.add(task)
          task.add_done_callback(_FEED_VERIFY_TASKS.discard)

      return {
          "status": "processing",
          "message": f"{len(new_account_ids)} accounts received and are being verified in the background.",
          "account_ids": new_account_ids
      }
  @app.get("/v2/accounts")
  async def list_accounts(_: bool = Depends(verify_admin_password), enabled: Optional[bool] = None, sort_by: str = "created_at", sort_order: str = "desc"):
      """
      List accounts, optionally filtered by enabled state and sorted.

      ``sort_by`` is restricted to ``created_at`` and ``success_count``; any
      other value falls back to ``created_at``. ``sort_order`` is descending
      only for (case-insensitive) ``"desc"``, ascending otherwise.

      Returns:
          A dict with the ``accounts`` list and its ``count``.
      """
      # Column name and direction are interpolated into the SQL text, so both
      # are reduced to fixed known values first.
      allowed_sort_columns = ("created_at", "success_count")
      sort_field = sort_by if sort_by in allowed_sort_columns else "created_at"
      order = "DESC" if sort_order.lower() == "desc" else "ASC"

      where_clause = ""
      params = ()
      if enabled is not None:
          where_clause = " WHERE enabled=?"
          params = (1 if enabled else 0,)

      query = "SELECT * FROM accounts" + where_clause + f" ORDER BY {sort_field} {order}"
      rows = await _db.fetchall(query, params)

      accounts = list(map(_row_to_dict, rows))
      return {"accounts": accounts, "count": len(accounts)}
  @app.get("/v2/accounts/{account_id}")
  async def get_account_detail(account_id: str, _: bool = Depends(verify_admin_password)):
      """Return the account identified by ``account_id`` as loaded by ``get_account``."""
      account = await get_account(account_id)
      return account
  @app.delete("/v2/accounts/{account_id}")
  async def delete_account(account_id: str, _: bool = Depends(verify_admin_password)):
      """
      Delete the account with the given id.

      Returns:
          ``{"deleted": account_id}`` when a row was removed.

      Raises:
          HTTPException(404): the delete affected no row.
      """
      affected = await _db.execute("DELETE FROM accounts WHERE id=?", (account_id,))
      if affected != 0:
          return {"deleted": account_id}
      raise HTTPException(status_code=404, detail="Account not found")
  @app.patch("/v2/accounts/{account_id}")
  async def update_account(account_id: str, body: AccountUpdate, _: bool = Depends(verify_admin_password)):
      """
      Partially update an account and return the stored row.

      Only fields that are not None in the body are written; ``other`` is
      stored as JSON text and ``enabled`` as 1/0. When the body carries no
      field at all, the current account is returned unchanged. Otherwise
      ``updated_at`` is set to the current UTC time.

      Raises:
          HTTPException(404): the update affected no row.
      """
      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())

      # (column, value) pairs in the order they appear in the SET clause;
      # a None value means "not supplied" and is dropped below.
      candidates = [
          ("label", body.label),
          ("clientId", body.clientId),
          ("clientSecret", body.clientSecret),
          ("refreshToken", body.refreshToken),
          ("accessToken", body.accessToken),
          ("other", None if body.other is None else json.dumps(body.other, ensure_ascii=False)),
          ("enabled", None if body.enabled is None else (1 if body.enabled else 0)),
      ]
      changes = [(column, value) for column, value in candidates if value is not None]

      if not changes:
          return await get_account(account_id)

      changes.append(("updated_at", now))
      assignments = ", ".join(f"{column}=?" for column, _value in changes)
      params = tuple(value for _column, value in changes) + (account_id,)

      affected = await _db.execute(f"UPDATE accounts SET {assignments} WHERE id=?", params)
      if affected == 0:
          raise HTTPException(status_code=404, detail="Account not found")

      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      return _row_to_dict(stored)
  @app.post("/v2/accounts/{account_id}/refresh")
  async def manual_refresh(account_id: str, _: bool = Depends(verify_admin_password)):
      """Refresh the account's access token on demand and return the result of ``refresh_access_token_in_db``."""
      refreshed = await refresh_access_token_in_db(account_id)
      return refreshed
  1160. # ------------------------------------------------------------------------------
  1161. # Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
  1162. # ------------------------------------------------------------------------------
  1163. # Frontend inline HTML removed; serving ./frontend/index.html instead (see route below)
  1164. # Note: This route is NOT protected - the HTML file is served freely,
  1165. # but the frontend JavaScript checks authentication and redirects to /login if needed.
  1166. # All API endpoints remain protected.
  @app.get("/", response_class=FileResponse)
  def index():
      """
      Serve the frontend entry page from ``frontend/index.html`` under ``BASE_DIR``.

      Raises:
          HTTPException(404): the HTML file does not exist.
      """
      html_file = BASE_DIR / "frontend" / "index.html"
      if html_file.exists():
          return FileResponse(str(html_file))
      raise HTTPException(status_code=404, detail="frontend/index.html not found")
  1173. # ------------------------------------------------------------------------------
  1174. # Health
  1175. # ------------------------------------------------------------------------------
  @app.get("/healthz")
  async def health():
      """Liveness probe: always answers with a fixed ``{"status": "ok"}`` body."""
      return dict(status="ok")
  1179. # ------------------------------------------------------------------------------
  1180. # Startup / Shutdown Events
  1181. # ------------------------------------------------------------------------------
  1182. # async def _verify_disabled_accounts_loop():
  1183. # """后台验证禁用账号任务"""
  1184. # while True:
  1185. # try:
  1186. # await asyncio.sleep(1800)
  1187. # async with _conn() as conn:
  1188. # accounts = await _list_disabled_accounts(conn)
  1189. # if accounts:
  1190. # for account in accounts:
  1191. # other = account.get('other')
  1192. # if other:
  1193. # try:
  1194. # other_dict = json.loads(other) if isinstance(other, str) else other
  1195. # if other_dict.get('failedReason') == 'AccessDenied':
  1196. # continue
  1197. # except:
  1198. # pass
  1199. # try:
  1200. # verify_success, fail_reason = await verify_account(account)
  1201. # now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1202. # if verify_success:
  1203. # await conn.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, account['id']))
  1204. # elif fail_reason:
  1205. # other_dict = {}
  1206. # if account.get('other'):
  1207. # try:
  1208. # other_dict = json.loads(account['other']) if isinstance(account['other'], str) else account['other']
  1209. # except:
  1210. # pass
  1211. # other_dict['failedReason'] = fail_reason
  1212. # await conn.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, account['id']))
  1213. # await conn.commit()
  1214. # except Exception:
  1215. # pass
  1216. # except Exception:
  1217. # pass
  # Strong references to background tasks created at startup. The event loop
  # only keeps weak references to tasks, so a task nobody else references may
  # be garbage-collected before it finishes.
  _STARTUP_TASKS: set = set()


  @app.on_event("startup")
  async def startup_event():
      """
      Initialize the global client and the database, then start background tasks.

      ``_init_global_client`` and ``_ensure_db`` are awaited in that order,
      after which ``_refresh_stale_tokens`` is scheduled as a background task.
      """
      await _init_global_client()
      await _ensure_db()
      refresher = asyncio.create_task(_refresh_stale_tokens())
      # Hold a reference for as long as the task runs, then drop it
      _STARTUP_TASKS.add(refresher)
      refresher.add_done_callback(_STARTUP_TASKS.discard)
      # asyncio.create_task(_verify_disabled_accounts_loop())
  @app.on_event("shutdown")
  async def shutdown_event():
      """Shutdown hook: close the global client first, then the database."""
      for closer in (_close_global_client, close_db):
          await closer()