app.py 63 KB


  1. import os
  2. import json
  3. import traceback
  4. import uuid
  5. import time
  6. import asyncio
  7. import importlib.util
  8. import random
  9. import secrets
  10. import re
  11. import logging
  12. from datetime import datetime, timedelta
  13. from pathlib import Path
  14. from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
  15. from fastapi import FastAPI, Depends, HTTPException, Header
  16. from fastapi.middleware.cors import CORSMiddleware
  17. from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
  18. from pydantic import BaseModel
  19. from dotenv import load_dotenv
  20. import httpx
  21. import tiktoken
  22. from db import init_db, close_db, row_to_dict
  23. # ------------------------------------------------------------------------------
  24. # Logger
  25. # ------------------------------------------------------------------------------
  26. logging.basicConfig(
  27. level=logging.INFO,
  28. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  29. )
  30. logger = logging.getLogger(__name__)
  31. # ------------------------------------------------------------------------------
  32. # Tokenizer
  33. # ------------------------------------------------------------------------------
  34. try:
  35. # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
  36. ENCODING = tiktoken.get_encoding("cl100k_base")
  37. except Exception:
  38. ENCODING = None
  39. def count_tokens(text: str, apply_multiplier: bool = False) -> int:
  40. """Counts tokens with tiktoken."""
  41. if not text or not ENCODING:
  42. return 0
  43. token_count = len(ENCODING.encode(text))
  44. if apply_multiplier:
  45. token_count = int(token_count * TOKEN_COUNT_MULTIPLIER)
  46. return token_count
  47. # ------------------------------------------------------------------------------
  48. # Bootstrap
  49. # ------------------------------------------------------------------------------
  50. BASE_DIR = Path(__file__).resolve().parent
  51. load_dotenv(BASE_DIR / ".env")
  52. app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")
  53. # CORS for simple testing in browser
  54. app.add_middleware(
  55. CORSMiddleware,
  56. allow_origins=["*"],
  57. allow_methods=["*"],
  58. allow_headers=["*"],
  59. )
  60. # ------------------------------------------------------------------------------
  61. # Dynamic import of replicate.py to avoid package __init__ needs
  62. # ------------------------------------------------------------------------------
  63. def _load_replicate_module():
  64. mod_path = BASE_DIR / "replicate.py"
  65. spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
  66. module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
  67. assert spec is not None and spec.loader is not None
  68. spec.loader.exec_module(module) # type: ignore[attr-defined]
  69. return module
  70. _replicate = _load_replicate_module()
  71. send_chat_request = _replicate.send_chat_request
  72. # ------------------------------------------------------------------------------
  73. # Dynamic import of Claude modules
  74. # ------------------------------------------------------------------------------
  75. def _load_claude_modules():
  76. # claude_types
  77. spec_types = importlib.util.spec_from_file_location("v2_claude_types", str(BASE_DIR / "claude_types.py"))
  78. mod_types = importlib.util.module_from_spec(spec_types)
  79. spec_types.loader.exec_module(mod_types)
  80. # claude_converter
  81. spec_conv = importlib.util.spec_from_file_location("v2_claude_converter", str(BASE_DIR / "claude_converter.py"))
  82. mod_conv = importlib.util.module_from_spec(spec_conv)
  83. # We need to inject claude_types into converter's namespace if it uses relative imports or expects them
  84. # But since we used relative import in claude_converter.py (.claude_types), we need to be careful.
  85. # Actually, since we are loading dynamically, relative imports might fail if not in sys.modules correctly.
  86. # Let's patch sys.modules temporarily or just rely on file location.
  87. # A simpler way for this single-file script style is to just load them.
  88. # However, claude_converter does `from .claude_types import ...`
  89. # To make that work, we should probably just use standard import if v2 is a package,
  90. # but v2 is just a folder.
  91. # Let's assume the user runs this with v2 in pythonpath or we just fix imports in the files.
  92. # But I wrote `from .claude_types` in the file.
  93. # Let's try to load it. If it fails, we might need to adjust.
  94. # Actually, for simplicity in this `app.py` dynamic loading context,
  95. # it is better if `claude_converter.py` used absolute import or we mock the package.
  96. # BUT, let's try to just load them and see.
  97. # To avoid relative import issues, I will inject the module into sys.modules
  98. import sys
  99. sys.modules["v2.claude_types"] = mod_types
  100. spec_conv.loader.exec_module(mod_conv)
  101. # claude_stream
  102. spec_stream = importlib.util.spec_from_file_location("v2_claude_stream", str(BASE_DIR / "claude_stream.py"))
  103. mod_stream = importlib.util.module_from_spec(spec_stream)
  104. spec_stream.loader.exec_module(mod_stream)
  105. return mod_types, mod_conv, mod_stream
  106. try:
  107. _claude_types, _claude_converter, _claude_stream = _load_claude_modules()
  108. ClaudeRequest = _claude_types.ClaudeRequest
  109. convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
  110. map_model_name = _claude_converter.map_model_name
  111. ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
  112. except Exception as e:
  113. print(f"Failed to load Claude modules: {e}")
  114. traceback.print_exc()
  115. # Define dummy classes to avoid NameError on startup if loading fails
  116. class ClaudeRequest(BaseModel):
  117. pass
  118. convert_claude_to_amazonq_request = None
  119. map_model_name = lambda m: m # Pass through if module fails to load
  120. ClaudeStreamHandler = None
  121. # ------------------------------------------------------------------------------
  122. # Global HTTP Client
  123. # ------------------------------------------------------------------------------
  124. GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
  125. def _get_proxies() -> Optional[Dict[str, str]]:
  126. proxy = os.getenv("HTTP_PROXY", "").strip()
  127. if proxy:
  128. return {"http": proxy, "https": proxy}
  129. return None
  130. async def _init_global_client():
  131. global GLOBAL_CLIENT
  132. proxies = _get_proxies()
  133. mounts = None
  134. if proxies:
  135. proxy_url = proxies.get("https") or proxies.get("http")
  136. if proxy_url:
  137. mounts = {
  138. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  139. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  140. }
  141. # Increased limits for high concurrency with streaming
  142. # max_connections: 总连接数上限
  143. # max_keepalive_connections: 保持活跃的连接数
  144. # keepalive_expiry: 连接保持时间
  145. limits = httpx.Limits(
  146. max_keepalive_connections=500,
  147. max_connections=500, # 提高到500以支持更高并发
  148. keepalive_expiry=1.0 # 30秒后释放空闲连接
  149. )
  150. # 为流式响应设置更长的超时
  151. timeout = httpx.Timeout(
  152. connect=2.0, # 连接超时
  153. read=300.0, # 读取超时(流式响应需要更长时间)
  154. write=2.0, # 写入超时
  155. pool=1.0 # 从连接池获取连接的超时时间(关键!)
  156. )
  157. GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
  158. def get_global_client() -> Optional[httpx.AsyncClient]:
  159. """获取当前的全局客户端实例(动态获取,确保总是最新的)"""
  160. return GLOBAL_CLIENT
  161. async def _close_global_client():
  162. global GLOBAL_CLIENT
  163. if GLOBAL_CLIENT:
  164. await GLOBAL_CLIENT.aclose()
  165. GLOBAL_CLIENT = None
  166. async def _recycle_global_client():
  167. """定期回收并重建全局HTTP客户端,避免死连接累积
  168. 策略:先创建新客户端,等待2分钟后再关闭旧客户端,确保平滑过渡
  169. """
  170. while True:
  171. try:
  172. await asyncio.sleep(60) # 每1分钟回收一次
  173. logger.info("[连接回收] 开始回收全局HTTP客户端...")
  174. # 保存旧客户端引用
  175. global GLOBAL_CLIENT
  176. old_client = GLOBAL_CLIENT
  177. # 创建新客户端
  178. proxies = _get_proxies()
  179. mounts = None
  180. if proxies:
  181. proxy_url = proxies.get("https") or proxies.get("http")
  182. if proxy_url:
  183. mounts = {
  184. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  185. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  186. }
  187. limits = httpx.Limits(
  188. max_keepalive_connections=500,
  189. max_connections=500,
  190. keepalive_expiry=1.0
  191. )
  192. timeout = httpx.Timeout(
  193. connect=2.0,
  194. read=300.0,
  195. write=2.0,
  196. pool=1.0
  197. )
  198. # 替换为新客户端
  199. GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
  200. logger.info("[连接回收] 新客户端已创建,等待120秒后关闭旧客户端...")
  201. # 等待2分钟,让正在使用旧客户端的请求完成
  202. await asyncio.sleep(120)
  203. # 关闭旧客户端
  204. if old_client:
  205. try:
  206. await old_client.aclose()
  207. logger.info("[连接回收] 旧客户端已关闭,回收完成")
  208. except Exception as e:
  209. logger.warning(f"[连接回收] 关闭旧客户端时出错: {e}")
  210. except Exception as e:
  211. logger.error(f"[连接回收] 回收失败: {e}")
  212. traceback.print_exc()
  213. # 确保客户端可用
  214. try:
  215. if GLOBAL_CLIENT is None:
  216. await _init_global_client()
  217. except Exception:
  218. pass
  219. # ------------------------------------------------------------------------------
  220. # Database helpers
  221. # ------------------------------------------------------------------------------
  222. # Database backend instance (initialized on startup)
  223. _db = None
  224. async def _ensure_db():
  225. """Initialize database backend."""
  226. global _db
  227. _db = await init_db()
  228. def _row_to_dict(r: Dict[str, Any]) -> Dict[str, Any]:
  229. """Convert database row to dict with JSON parsing."""
  230. return row_to_dict(r)
  231. # _ensure_db() will be called in startup event
  232. # ------------------------------------------------------------------------------
  233. # Background token refresh thread
  234. # ------------------------------------------------------------------------------
  235. async def _refresh_stale_tokens():
  236. while True:
  237. try:
  238. await asyncio.sleep(300) # 5 minutes
  239. if _db is None:
  240. print("[Error] Database not initialized, skipping token refresh cycle.")
  241. continue
  242. now = time.time()
  243. if LAZY_ACCOUNT_POOL_ENABLED:
  244. limit = LAZY_ACCOUNT_POOL_SIZE + LAZY_ACCOUNT_POOL_REFRESH_OFFSET
  245. order_direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
  246. query = f"SELECT id, last_refresh_time FROM accounts WHERE enabled=1 ORDER BY {LAZY_ACCOUNT_POOL_ORDER_BY} {order_direction} LIMIT {limit}"
  247. rows = await _db.fetchall(query)
  248. else:
  249. rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")
  250. for row in rows:
  251. acc_id, last_refresh = row['id'], row['last_refresh_time']
  252. should_refresh = False
  253. if not last_refresh or last_refresh == "never":
  254. should_refresh = True
  255. else:
  256. try:
  257. last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
  258. if now - last_time > 1500: # 25 minutes
  259. should_refresh = True
  260. except Exception:
  261. # Malformed or unparsable timestamp; force refresh
  262. should_refresh = True
  263. if should_refresh:
  264. try:
  265. await refresh_access_token_in_db(acc_id)
  266. except Exception:
  267. traceback.print_exc()
  268. # Ignore per-account refresh failure; timestamp/status are recorded inside
  269. pass
  270. except Exception:
  271. traceback.print_exc()
  272. pass
  273. # ------------------------------------------------------------------------------
  274. # Env and API Key authorization (keys are independent of AWS accounts)
  275. # ------------------------------------------------------------------------------
  276. def _parse_allowed_keys_env() -> List[str]:
  277. """
  278. OPENAI_KEYS is a comma-separated whitelist of API keys for authorization only.
  279. Example: OPENAI_KEYS="key1,key2,key3"
  280. - When the list is non-empty, incoming Authorization: Bearer {key} must be one of them.
  281. - When empty or unset, authorization is effectively disabled (dev mode).
  282. """
  283. s = os.getenv("OPENAI_KEYS", "") or ""
  284. keys: List[str] = []
  285. for k in [x.strip() for x in s.split(",") if x.strip()]:
  286. keys.append(k)
  287. return keys
  288. ALLOWED_API_KEYS: List[str] = _parse_allowed_keys_env()
  289. MAX_ERROR_COUNT: int = int(os.getenv("MAX_ERROR_COUNT", "100"))
  290. TOKEN_COUNT_MULTIPLIER: float = float(os.getenv("TOKEN_COUNT_MULTIPLIER", "1.0"))
  291. # Lazy Account Pool settings
  292. LAZY_ACCOUNT_POOL_ENABLED: bool = os.getenv("LAZY_ACCOUNT_POOL_ENABLED", "false").lower() in ("true", "1", "yes")
  293. LAZY_ACCOUNT_POOL_SIZE: int = int(os.getenv("LAZY_ACCOUNT_POOL_SIZE", "20"))
  294. LAZY_ACCOUNT_POOL_REFRESH_OFFSET: int = int(os.getenv("LAZY_ACCOUNT_POOL_REFRESH_OFFSET", "10"))
  295. LAZY_ACCOUNT_POOL_ORDER_BY: str = os.getenv("LAZY_ACCOUNT_POOL_ORDER_BY", "created_at")
  296. LAZY_ACCOUNT_POOL_ORDER_DESC: bool = os.getenv("LAZY_ACCOUNT_POOL_ORDER_DESC", "false").lower() in ("true", "1", "yes")
  297. # Validate LAZY_ACCOUNT_POOL_ORDER_BY to prevent SQL injection
  298. if LAZY_ACCOUNT_POOL_ORDER_BY not in ["created_at", "id", "success_count"]:
  299. LAZY_ACCOUNT_POOL_ORDER_BY = "created_at"
  300. def _is_console_enabled() -> bool:
  301. """检查是否启用管理控制台"""
  302. console_env = os.getenv("ENABLE_CONSOLE", "true").strip().lower()
  303. return console_env not in ("false", "0", "no", "disabled")
  304. CONSOLE_ENABLED: bool = _is_console_enabled()
  305. # Admin authentication configuration
  306. ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD", "admin")
  307. def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
  308. if not token_header:
  309. return None
  310. if token_header.startswith("Bearer "):
  311. return token_header.split(" ", 1)[1].strip()
  312. return token_header.strip()
  313. async def _list_enabled_accounts(limit: Optional[int] = None) -> List[Dict[str, Any]]:
  314. if LAZY_ACCOUNT_POOL_ENABLED:
  315. order_direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
  316. query = f"SELECT * FROM accounts WHERE enabled=1 ORDER BY {LAZY_ACCOUNT_POOL_ORDER_BY} {order_direction}"
  317. if limit:
  318. query += f" LIMIT {limit}"
  319. rows = await _db.fetchall(query)
  320. else:
  321. query = "SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC"
  322. if limit:
  323. query += f" LIMIT {limit}"
  324. rows = await _db.fetchall(query)
  325. return [_row_to_dict(r) for r in rows]
  326. async def _list_disabled_accounts() -> List[Dict[str, Any]]:
  327. rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC")
  328. return [_row_to_dict(r) for r in rows]
  329. async def verify_account(account: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
  330. """验证账号可用性"""
  331. try:
  332. account = await refresh_access_token_in_db(account['id'])
  333. test_request = {
  334. "conversationState": {
  335. "currentMessage": {"userInputMessage": {"content": "hello"}},
  336. "chatTriggerType": "MANUAL"
  337. }
  338. }
  339. _, _, tracker, event_gen = await send_chat_request(
  340. access_token=account['accessToken'],
  341. messages=[],
  342. stream=True,
  343. raw_payload=test_request
  344. )
  345. if event_gen:
  346. async for _ in event_gen:
  347. break
  348. return True, None
  349. except Exception as e:
  350. if "AccessDenied" in str(e) or "403" in str(e):
  351. return False, "AccessDenied"
  352. return False, None
  353. async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
  354. """
  355. Authorize request by OPENAI_KEYS (if configured), then select an AWS account.
  356. Selection strategy: random among all enabled accounts. Authorization key does NOT map to any account.
  357. """
  358. # Authorization
  359. if ALLOWED_API_KEYS:
  360. if not bearer_key or bearer_key not in ALLOWED_API_KEYS:
  361. raise HTTPException(status_code=401, detail="Invalid or missing API key")
  362. # Selection: random among enabled accounts
  363. if LAZY_ACCOUNT_POOL_ENABLED:
  364. candidates = await _list_enabled_accounts(limit=LAZY_ACCOUNT_POOL_SIZE)
  365. else:
  366. candidates = await _list_enabled_accounts()
  367. if not candidates:
  368. raise HTTPException(status_code=401, detail="No enabled account available")
  369. return random.choice(candidates)
# ------------------------------------------------------------------------------
# Pydantic Schemas
# ------------------------------------------------------------------------------
# Payload for creating one account. clientId/clientSecret are required;
# refresh_access_token_in_db additionally needs refreshToken to mint access tokens.
# NOTE(review): the endpoints that accept these models are outside this view.
class AccountCreate(BaseModel):
    label: Optional[str] = None  # optional human-readable name
    clientId: str
    clientSecret: str
    refreshToken: Optional[str] = None
    accessToken: Optional[str] = None
    other: Optional[Dict[str, Any]] = None  # free-form extra data; schema not shown here
    enabled: Optional[bool] = True  # new accounts default to enabled
# Payload for creating several accounts in one request.
class BatchAccountCreate(BaseModel):
    accounts: List[AccountCreate]
# Partial account update: every field is optional.
# Presumably only fields that are provided are applied; confirm in the update endpoint.
class AccountUpdate(BaseModel):
    label: Optional[str] = None
    clientId: Optional[str] = None
    clientSecret: Optional[str] = None
    refreshToken: Optional[str] = None
    accessToken: Optional[str] = None
    other: Optional[Dict[str, Any]] = None
    enabled: Optional[bool] = None
# One chat message. content is untyped: presumably a string or a list of
# content parts (OpenAI format); confirm with the chat endpoint.
class ChatMessage(BaseModel):
    role: str
    content: Any
# OpenAI-style chat completion request; stream defaults to False.
class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage]
    stream: Optional[bool] = False
  398. # ------------------------------------------------------------------------------
  399. # Token refresh (OIDC)
  400. # ------------------------------------------------------------------------------
  401. OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  402. TOKEN_URL = f"{OIDC_BASE}/token"
  403. def _oidc_headers() -> Dict[str, str]:
  404. return {
  405. "content-type": "application/json",
  406. "user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
  407. "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
  408. "amz-sdk-request": "attempt=1; max=3",
  409. "amz-sdk-invocation-id": str(uuid.uuid4()),
  410. }
  411. async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
  412. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  413. if not row:
  414. raise HTTPException(status_code=404, detail="Account not found")
  415. acc = _row_to_dict(row)
  416. if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
  417. raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
  418. payload = {
  419. "grantType": "refresh_token",
  420. "clientId": acc["clientId"],
  421. "clientSecret": acc["clientSecret"],
  422. "refreshToken": acc["refreshToken"],
  423. }
  424. try:
  425. # Use global client if available, else fallback (though global should be ready)
  426. client = get_global_client()
  427. if not client:
  428. # Fallback for safety
  429. async with httpx.AsyncClient(timeout=60.0) as temp_client:
  430. r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
  431. r.raise_for_status()
  432. data = r.json()
  433. else:
  434. r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
  435. r.raise_for_status()
  436. data = r.json()
  437. new_access = data.get("accessToken")
  438. new_refresh = data.get("refreshToken", acc.get("refreshToken"))
  439. expires_in = data.get("expiresIn", 3600) # Default 1 hour if not provided
  440. expires_at = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(time.time() + expires_in))
  441. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  442. status = "success"
  443. except httpx.HTTPError as e:
  444. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  445. status = "failed"
  446. await _db.execute(
  447. """
  448. UPDATE accounts
  449. SET last_refresh_time=?, last_refresh_status=?, updated_at=?
  450. WHERE id=?
  451. """,
  452. (now, status, now, account_id),
  453. )
  454. # 记录刷新失败次数
  455. await _update_stats(account_id, False)
  456. raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
  457. except Exception as e:
  458. # Ensure last_refresh_time is recorded even on unexpected errors
  459. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  460. status = "failed"
  461. await _db.execute(
  462. """
  463. UPDATE accounts
  464. SET last_refresh_time=?, last_refresh_status=?, updated_at=?
  465. WHERE id=?
  466. """,
  467. (now, status, now, account_id),
  468. )
  469. # 记录刷新失败次数
  470. await _update_stats(account_id, False)
  471. raise
  472. await _db.execute(
  473. """
  474. UPDATE accounts
  475. SET accessToken=?, refreshToken=?, expires_at=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
  476. WHERE id=?
  477. """,
  478. (new_access, new_refresh, expires_at, now, status, now, account_id),
  479. )
  480. row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  481. return _row_to_dict(row2)
  482. async def get_account(account_id: str) -> Dict[str, Any]:
  483. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  484. if not row:
  485. raise HTTPException(status_code=404, detail="Account not found")
  486. return _row_to_dict(row)
  487. async def _update_stats(account_id: str, success: bool) -> None:
  488. if success:
  489. await _db.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
  490. (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  491. else:
  492. row = await _db.fetchone("SELECT error_count FROM accounts WHERE id=?", (account_id,))
  493. if row:
  494. new_count = (row['error_count'] or 0) + 1
  495. if new_count >= MAX_ERROR_COUNT:
  496. await _db.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
  497. (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  498. else:
  499. await _db.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
  500. (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  501. # ------------------------------------------------------------------------------
  502. # Dependencies
  503. # ------------------------------------------------------------------------------
  504. async def require_account(
  505. authorization: Optional[str] = Header(default=None),
  506. x_api_key: Optional[str] = Header(default=None)
  507. ) -> Dict[str, Any]:
  508. key = _extract_bearer(authorization) if authorization else x_api_key
  509. return await resolve_account_for_key(key)
  510. def verify_admin_password(authorization: Optional[str] = Header(None)) -> bool:
  511. """Verify admin password for console access"""
  512. if not authorization or not authorization.startswith("Bearer "):
  513. raise HTTPException(
  514. status_code=401,
  515. detail={"error": "Unauthorized access", "code": "UNAUTHORIZED"}
  516. )
  517. password = authorization[7:] # Remove "Bearer " prefix
  518. if password != ADMIN_PASSWORD:
  519. raise HTTPException(
  520. status_code=401,
  521. detail={"error": "Invalid password", "code": "INVALID_PASSWORD"}
  522. )
  523. return True
  524. # ------------------------------------------------------------------------------
  525. # OpenAI-compatible Chat endpoint
  526. # ------------------------------------------------------------------------------
  527. def _openai_non_streaming_response(
  528. text: str,
  529. model: Optional[str],
  530. prompt_tokens: int = 0,
  531. completion_tokens: int = 0,
  532. ) -> Dict[str, Any]:
  533. created = int(time.time())
  534. return {
  535. "id": f"chatcmpl-{uuid.uuid4()}",
  536. "object": "chat.completion",
  537. "created": created,
  538. "model": model or "unknown",
  539. "choices": [
  540. {
  541. "index": 0,
  542. "message": {
  543. "role": "assistant",
  544. "content": text,
  545. },
  546. "finish_reason": "stop",
  547. }
  548. ],
  549. "usage": {
  550. "prompt_tokens": prompt_tokens,
  551. "completion_tokens": completion_tokens,
  552. "total_tokens": prompt_tokens + completion_tokens,
  553. },
  554. }
  555. def _sse_format(obj: Dict[str, Any]) -> str:
  556. return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
  557. @app.post("/v1/messages")
  558. async def claude_messages(
  559. req: ClaudeRequest,
  560. account: Dict[str, Any] = Depends(require_account),
  561. x_conversation_id: Optional[str] = Header(default=None, alias="x-conversation-id")
  562. ):
  563. """
  564. Claude-compatible messages endpoint.
  565. """
  566. # 1. Convert request
  567. # Always generate a new conversation_id like amq2api does
  568. # Using the same conversation_id can cause Amazon Q to return cached/stale data
  569. try:
  570. aq_request = convert_claude_to_amazonq_request(req, conversation_id=None)
  571. except Exception as e:
  572. traceback.print_exc()
  573. raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")
  574. # Post-process history to fix message ordering (prevents infinite loops)
  575. from message_processor import process_claude_history_for_amazonq
  576. conversation_state = aq_request.get("conversationState", {})
  577. history = conversation_state.get("history", [])
  578. if history:
  579. processed_history = process_claude_history_for_amazonq(history)
  580. aq_request["conversationState"]["history"] = processed_history
  581. # Remove duplicate tail userInputMessage that matches currentMessage content
  582. # This prevents the model from repeatedly responding to the same user message
  583. conversation_state = aq_request.get("conversationState", {})
  584. current_msg = conversation_state.get("currentMessage", {}).get("userInputMessage", {})
  585. current_content = (current_msg.get("content") or "").strip()
  586. history = conversation_state.get("history", [])
  587. if history and current_content:
  588. last = history[-1]
  589. if "userInputMessage" in last:
  590. last_content = (last["userInputMessage"].get("content") or "").strip()
  591. if last_content and last_content == current_content:
  592. # Remove duplicate tail userInputMessage
  593. history = history[:-1]
  594. aq_request["conversationState"]["history"] = history
  595. import logging
  596. logging.getLogger(__name__).info("Removed duplicate tail userInputMessage to prevent repeated response")
  597. conversation_state = aq_request.get("conversationState", {})
  598. conversation_id = conversation_state.get("conversationId")
  599. response_headers: Dict[str, str] = {}
  600. if conversation_id:
  601. response_headers["x-conversation-id"] = conversation_id
  602. # Always stream from upstream to get full event details
  603. event_iter = None
  604. try:
  605. access = account.get("accessToken")
  606. if not access:
  607. refreshed = await refresh_access_token_in_db(account["id"])
  608. access = refreshed.get("accessToken")
  609. # We call with stream=True to get the event iterator
  610. _, _, tracker, event_iter = await send_chat_request(
  611. access_token=access,
  612. messages=[],
  613. model=map_model_name(req.model),
  614. stream=True,
  615. client=get_global_client(),
  616. raw_payload=aq_request
  617. )
  618. if not event_iter:
  619. raise HTTPException(status_code=502, detail="No event stream returned")
  620. # Handler
  621. # Calculate input tokens
  622. text_to_count = ""
  623. if req.system:
  624. if isinstance(req.system, str):
  625. text_to_count += req.system
  626. elif isinstance(req.system, list):
  627. for item in req.system:
  628. if isinstance(item, dict) and item.get("type") == "text":
  629. text_to_count += item.get("text", "")
  630. for msg in req.messages:
  631. if isinstance(msg.content, str):
  632. text_to_count += msg.content
  633. elif isinstance(msg.content, list):
  634. for item in msg.content:
  635. if isinstance(item, dict) and item.get("type") == "text":
  636. text_to_count += item.get("text", "")
  637. input_tokens = count_tokens(text_to_count, apply_multiplier=True)
  638. handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens, conversation_id=conversation_id)
  639. # Try to get the first event to ensure the connection is valid
  640. # This allows us to return proper HTTP error codes before starting the stream
  641. first_event = None
  642. try:
  643. first_event = await event_iter.__anext__()
  644. except StopAsyncIteration:
  645. raise HTTPException(status_code=502, detail="Empty response from upstream")
  646. except Exception as e:
  647. # If we get an error before the first event, we can still return proper status code
  648. err_msg = str(e)
  649. # Extract upstream status code from "Upstream error {code}: {message}"
  650. if err_msg.startswith("Upstream error "):
  651. match = re.match(r"Upstream error (\d+):", err_msg)
  652. if match:
  653. raise HTTPException(status_code=int(match.group(1)), detail=err_msg)
  654. raise HTTPException(status_code=502, detail=f"Upstream error: {err_msg}")
  655. async def event_generator():
  656. try:
  657. # Process the first event we already fetched
  658. if first_event:
  659. event_type, payload = first_event
  660. async for sse in handler.handle_event(event_type, payload):
  661. yield sse
  662. # Process remaining events
  663. async for event_type, payload in event_iter:
  664. async for sse in handler.handle_event(event_type, payload):
  665. yield sse
  666. async for sse in handler.finish():
  667. yield sse
  668. await _update_stats(account["id"], True)
  669. except GeneratorExit:
  670. # Client disconnected - update stats but don't re-raise
  671. await _update_stats(account["id"], tracker.has_content if tracker else False)
  672. except Exception:
  673. await _update_stats(account["id"], False)
  674. raise
  675. if req.stream:
  676. return StreamingResponse(
  677. event_generator(),
  678. media_type="text/event-stream",
  679. headers=response_headers or None
  680. )
  681. else:
  682. # Accumulate for non-streaming
  683. # This is a bit complex because we need to reconstruct the full response object
  684. # For now, let's just support streaming as it's the main use case for Claude Code
  685. # But to be nice, let's try to support non-streaming by consuming the generator
  686. content_blocks = []
  687. usage = {"input_tokens": 0, "output_tokens": 0}
  688. stop_reason = None
  689. # We need to parse the SSE strings back to objects... inefficient but works
  690. # Or we could refactor handler to yield objects.
  691. # For now, let's just raise error for non-streaming or implement basic text
  692. # Claude Code uses streaming.
  693. # Let's implement a basic accumulator from the SSE stream
  694. final_content = []
  695. async for sse_chunk in event_generator():
  696. data_str = None
  697. # Each chunk from the generator can have multiple lines ('event:', 'data:').
  698. # We need to find the 'data:' line.
  699. for line in sse_chunk.strip().split('\n'):
  700. if line.startswith("data:"):
  701. data_str = line[6:].strip()
  702. break
  703. if not data_str or data_str == "[DONE]":
  704. continue
  705. try:
  706. data = json.loads(data_str)
  707. dtype = data.get("type")
  708. if dtype == "content_block_start":
  709. idx = data.get("index", 0)
  710. while len(final_content) <= idx:
  711. final_content.append(None)
  712. final_content[idx] = data.get("content_block")
  713. elif dtype == "content_block_delta":
  714. idx = data.get("index", 0)
  715. delta = data.get("delta", {})
  716. if final_content[idx]:
  717. if delta.get("type") == "text_delta":
  718. final_content[idx]["text"] += delta.get("text", "")
  719. elif delta.get("type") == "thinking_delta":
  720. final_content[idx].setdefault("thinking", "")
  721. final_content[idx]["thinking"] += delta.get("thinking", "")
  722. elif delta.get("type") == "input_json_delta":
  723. if "partial_json" not in final_content[idx]:
  724. final_content[idx]["partial_json"] = ""
  725. final_content[idx]["partial_json"] += delta.get("partial_json", "")
  726. elif dtype == "content_block_stop":
  727. idx = data.get("index", 0)
  728. if final_content[idx] and final_content[idx].get("type") == "tool_use":
  729. if "partial_json" in final_content[idx]:
  730. try:
  731. final_content[idx]["input"] = json.loads(final_content[idx]["partial_json"])
  732. except json.JSONDecodeError:
  733. # Keep partial if invalid
  734. final_content[idx]["input"] = {"error": "invalid json", "partial": final_content[idx]["partial_json"]}
  735. del final_content[idx]["partial_json"]
  736. elif dtype == "message_delta":
  737. usage = data.get("usage", usage)
  738. stop_reason = data.get("delta", {}).get("stop_reason")
  739. except json.JSONDecodeError:
  740. # Ignore lines that are not valid JSON
  741. pass
  742. except Exception:
  743. # Broad exception to prevent accumulator from crashing on one bad event
  744. traceback.print_exc()
  745. pass
  746. # Final assembly
  747. final_content_cleaned = []
  748. for c in final_content:
  749. if c is not None:
  750. # Remove internal state like 'partial_json' before returning
  751. c.pop("partial_json", None)
  752. final_content_cleaned.append(c)
  753. response_body = {
  754. "id": f"msg_{uuid.uuid4()}",
  755. "type": "message",
  756. "role": "assistant",
  757. "model": req.model,
  758. "content": final_content_cleaned,
  759. "stop_reason": stop_reason,
  760. "stop_sequence": None,
  761. "usage": usage
  762. }
  763. if conversation_id:
  764. response_body["conversation_id"] = conversation_id
  765. response_body["conversationId"] = conversation_id
  766. return JSONResponse(content=response_body, headers=response_headers or None)
  767. except Exception as e:
  768. # Ensure event_iter (if created) is closed to release upstream connection
  769. try:
  770. if event_iter and hasattr(event_iter, "aclose"):
  771. await event_iter.aclose()
  772. except Exception:
  773. pass
  774. await _update_stats(account["id"], False)
  775. # Extract upstream status code from "Upstream error {code}: {message}"
  776. err_msg = str(e)
  777. if err_msg.startswith("Upstream error "):
  778. match = re.match(r"Upstream error (\d+):", err_msg)
  779. if match:
  780. raise HTTPException(status_code=int(match.group(1)), detail=err_msg)
  781. raise
  782. @app.post("/v1/messages/count_tokens")
  783. async def count_tokens_endpoint(req: ClaudeRequest):
  784. """
  785. Count tokens in a message without sending it.
  786. Compatible with Claude API's /v1/messages/count_tokens endpoint.
  787. Uses tiktoken for local token counting.
  788. """
  789. text_to_count = ""
  790. # Count system prompt tokens
  791. if req.system:
  792. if isinstance(req.system, str):
  793. text_to_count += req.system
  794. elif isinstance(req.system, list):
  795. for item in req.system:
  796. if isinstance(item, dict) and item.get("type") == "text":
  797. text_to_count += item.get("text", "")
  798. # Count message tokens
  799. for msg in req.messages:
  800. if isinstance(msg.content, str):
  801. text_to_count += msg.content
  802. elif isinstance(msg.content, list):
  803. for item in msg.content:
  804. if isinstance(item, dict) and item.get("type") == "text":
  805. text_to_count += item.get("text", "")
  806. # Count tool definition tokens if present
  807. if req.tools:
  808. text_to_count += json.dumps([tool.model_dump() if hasattr(tool, 'model_dump') else tool for tool in req.tools], ensure_ascii=False)
  809. input_tokens = count_tokens(text_to_count, apply_multiplier=True)
  810. return {"input_tokens": input_tokens}
  811. @app.post("/v1/chat/completions")
  812. async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
  813. """
  814. OpenAI-compatible chat endpoint.
  815. - stream default False
  816. - messages will be converted into "{role}:\n{content}" and injected into template
  817. - account is chosen randomly among enabled accounts (API key is for authorization only)
  818. """
  819. model = map_model_name(req.model)
  820. do_stream = bool(req.stream)
  821. async def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any]:
  822. access = account.get("accessToken")
  823. if not access:
  824. refreshed = await refresh_access_token_in_db(account["id"])
  825. access = refreshed.get("accessToken")
  826. if not access:
  827. raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
  828. # Note: send_chat_request signature changed, but we use keyword args so it should be fine if we don't pass raw_payload
  829. # But wait, the return signature changed too! It now returns 4 values.
  830. # We need to unpack 4 values.
  831. result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=get_global_client())
  832. return result[0], result[1], result[2] # Ignore the 4th value (event_stream) for OpenAI endpoint
  833. if not do_stream:
  834. try:
  835. # Calculate prompt tokens
  836. prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
  837. prompt_tokens = count_tokens(prompt_text)
  838. text, _, tracker = await _send_upstream(stream=False)
  839. await _update_stats(account["id"], bool(text))
  840. completion_tokens = count_tokens(text or "")
  841. return JSONResponse(content=_openai_non_streaming_response(
  842. text or "",
  843. model,
  844. prompt_tokens=prompt_tokens,
  845. completion_tokens=completion_tokens
  846. ))
  847. except Exception as e:
  848. await _update_stats(account["id"], False)
  849. raise
  850. else:
  851. created = int(time.time())
  852. stream_id = f"chatcmpl-{uuid.uuid4()}"
  853. model_used = model or "unknown"
  854. it = None
  855. try:
  856. # Calculate prompt tokens
  857. prompt_text = "".join([m.content for m in req.messages if isinstance(m.content, str)])
  858. prompt_tokens = count_tokens(prompt_text)
  859. _, it, tracker = await _send_upstream(stream=True)
  860. assert it is not None
  861. async def event_gen() -> AsyncGenerator[str, None]:
  862. completion_text = ""
  863. try:
  864. # Send role first
  865. yield _sse_format({
  866. "id": stream_id,
  867. "object": "chat.completion.chunk",
  868. "created": created,
  869. "model": model_used,
  870. "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
  871. })
  872. # Stream content
  873. async for piece in it:
  874. if piece:
  875. completion_text += piece
  876. yield _sse_format({
  877. "id": stream_id,
  878. "object": "chat.completion.chunk",
  879. "created": created,
  880. "model": model_used,
  881. "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
  882. })
  883. # Send stop and usage
  884. completion_tokens = count_tokens(completion_text)
  885. yield _sse_format({
  886. "id": stream_id,
  887. "object": "chat.completion.chunk",
  888. "created": created,
  889. "model": model_used,
  890. "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
  891. "usage": {
  892. "prompt_tokens": prompt_tokens,
  893. "completion_tokens": completion_tokens,
  894. "total_tokens": prompt_tokens + completion_tokens,
  895. }
  896. })
  897. yield "data: [DONE]\n\n"
  898. await _update_stats(account["id"], True)
  899. except GeneratorExit:
  900. # Client disconnected - update stats but don't re-raise
  901. await _update_stats(account["id"], tracker.has_content if tracker else False)
  902. except Exception:
  903. await _update_stats(account["id"], tracker.has_content if tracker else False)
  904. raise
  905. return StreamingResponse(event_gen(), media_type="text/event-stream")
  906. except Exception as e:
  907. # Ensure iterator (if created) is closed to release upstream connection
  908. try:
  909. if it and hasattr(it, "aclose"):
  910. await it.aclose()
  911. except Exception:
  912. pass
  913. await _update_stats(account["id"], False)
  914. raise
  915. # ------------------------------------------------------------------------------
  916. # Device Authorization (URL Login, 5-minute timeout)
  917. # ------------------------------------------------------------------------------
  918. # Dynamic import of auth_flow.py (device-code login helpers)
  919. def _load_auth_flow_module():
  920. mod_path = BASE_DIR / "auth_flow.py"
  921. spec = importlib.util.spec_from_file_location("v2_auth_flow", str(mod_path))
  922. module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
  923. assert spec is not None and spec.loader is not None
  924. spec.loader.exec_module(module) # type: ignore[attr-defined]
  925. return module
  926. _auth_flow = _load_auth_flow_module()
  927. register_client_min = _auth_flow.register_client_min
  928. device_authorize = _auth_flow.device_authorize
  929. poll_token_device_code = _auth_flow.poll_token_device_code
  930. # In-memory auth sessions (ephemeral)
  931. AUTH_SESSIONS: Dict[str, Dict[str, Any]] = {}
class AuthStartBody(BaseModel):
    # Request body for POST /v2/auth/start.
    # Optional display label for the account created when the session is claimed.
    label: Optional[str] = None
    # Whether that account is created enabled; auth_start treats None as True.
    enabled: Optional[bool] = True
class AdminLoginRequest(BaseModel):
    # Request body for POST /api/login.
    # Admin password as submitted; admin_login checks it against ADMIN_PASSWORD.
    password: str
class AdminLoginResponse(BaseModel):
    # Response body for POST /api/login.
    # True when the submitted password matched.
    success: bool
    # Human-readable outcome ("Login successful" / "Invalid password").
    message: str
  940. async def _create_account_from_tokens(
  941. client_id: str,
  942. client_secret: str,
  943. access_token: str,
  944. refresh_token: Optional[str],
  945. label: Optional[str],
  946. enabled: bool,
  947. ) -> Dict[str, Any]:
  948. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  949. acc_id = str(uuid.uuid4())
  950. await _db.execute(
  951. """
  952. INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
  953. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  954. """,
  955. (
  956. acc_id,
  957. label,
  958. client_id,
  959. client_secret,
  960. refresh_token,
  961. access_token,
  962. None,
  963. now,
  964. "success",
  965. now,
  966. now,
  967. 1 if enabled else 0,
  968. None, # expires_at - will be set on first refresh
  969. ),
  970. )
  971. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
  972. return _row_to_dict(row)
  973. # 管理控制台相关端点 - 仅在启用时注册
  974. if CONSOLE_ENABLED:
  975. # ------------------------------------------------------------------------------
  976. # Admin Authentication Endpoints
  977. # ------------------------------------------------------------------------------
  978. @app.post("/api/login", response_model=AdminLoginResponse)
  979. async def admin_login(request: AdminLoginRequest) -> AdminLoginResponse:
  980. """Admin login endpoint - password only"""
  981. if request.password == ADMIN_PASSWORD:
  982. return AdminLoginResponse(
  983. success=True,
  984. message="Login successful"
  985. )
  986. else:
  987. return AdminLoginResponse(
  988. success=False,
  989. message="Invalid password"
  990. )
  991. @app.get("/login", response_class=FileResponse)
  992. def login_page():
  993. """Serve the login page"""
  994. path = BASE_DIR / "frontend" / "login.html"
  995. if not path.exists():
  996. raise HTTPException(status_code=404, detail="frontend/login.html not found")
  997. return FileResponse(str(path))
  998. # ------------------------------------------------------------------------------
  999. # Device Authorization Endpoints
  1000. # ------------------------------------------------------------------------------
  1001. @app.post("/v2/auth/start")
  1002. async def auth_start(body: AuthStartBody, _: bool = Depends(verify_admin_password)):
  1003. """
  1004. Start device authorization and return verification URL for user login.
  1005. Session lifetime capped at 5 minutes on claim.
  1006. """
  1007. try:
  1008. cid, csec = await register_client_min()
  1009. dev = await device_authorize(cid, csec)
  1010. except httpx.HTTPError as e:
  1011. raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
  1012. auth_id = str(uuid.uuid4())
  1013. sess = {
  1014. "clientId": cid,
  1015. "clientSecret": csec,
  1016. "deviceCode": dev.get("deviceCode"),
  1017. "interval": int(dev.get("interval", 1)),
  1018. "expiresIn": int(dev.get("expiresIn", 600)),
  1019. "verificationUriComplete": dev.get("verificationUriComplete"),
  1020. "userCode": dev.get("userCode"),
  1021. "startTime": int(time.time()),
  1022. "label": body.label,
  1023. "enabled": True if body.enabled is None else bool(body.enabled),
  1024. "status": "pending",
  1025. "error": None,
  1026. "accountId": None,
  1027. }
  1028. AUTH_SESSIONS[auth_id] = sess
  1029. return {
  1030. "authId": auth_id,
  1031. "verificationUriComplete": sess["verificationUriComplete"],
  1032. "userCode": sess["userCode"],
  1033. "expiresIn": sess["expiresIn"],
  1034. "interval": sess["interval"],
  1035. }
  1036. @app.get("/v2/auth/status/{auth_id}")
  1037. async def auth_status(auth_id: str, _: bool = Depends(verify_admin_password)):
  1038. sess = AUTH_SESSIONS.get(auth_id)
  1039. if not sess:
  1040. raise HTTPException(status_code=404, detail="Auth session not found")
  1041. now_ts = int(time.time())
  1042. deadline = sess["startTime"] + min(int(sess.get("expiresIn", 600)), 300)
  1043. remaining = max(0, deadline - now_ts)
  1044. return {
  1045. "status": sess.get("status"),
  1046. "remaining": remaining,
  1047. "error": sess.get("error"),
  1048. "accountId": sess.get("accountId"),
  1049. }
  1050. @app.post("/v2/auth/claim/{auth_id}")
  1051. async def auth_claim(auth_id: str, _: bool = Depends(verify_admin_password)):
  1052. """
  1053. Block up to 5 minutes to exchange the device code for tokens after user completed login.
  1054. On success, creates an enabled account and returns it.
  1055. """
  1056. sess = AUTH_SESSIONS.get(auth_id)
  1057. if not sess:
  1058. raise HTTPException(status_code=404, detail="Auth session not found")
  1059. if sess.get("status") in ("completed", "timeout", "error"):
  1060. return {
  1061. "status": sess["status"],
  1062. "accountId": sess.get("accountId"),
  1063. "error": sess.get("error"),
  1064. }
  1065. try:
  1066. toks = await poll_token_device_code(
  1067. sess["clientId"],
  1068. sess["clientSecret"],
  1069. sess["deviceCode"],
  1070. sess["interval"],
  1071. sess["expiresIn"],
  1072. max_timeout_sec=300, # 5 minutes
  1073. )
  1074. access_token = toks.get("accessToken")
  1075. refresh_token = toks.get("refreshToken")
  1076. if not access_token:
  1077. raise HTTPException(status_code=502, detail="No accessToken returned from OIDC")
  1078. acc = await _create_account_from_tokens(
  1079. sess["clientId"],
  1080. sess["clientSecret"],
  1081. access_token,
  1082. refresh_token,
  1083. sess.get("label"),
  1084. sess.get("enabled", True),
  1085. )
  1086. sess["status"] = "completed"
  1087. sess["accountId"] = acc["id"]
  1088. return {
  1089. "status": "completed",
  1090. "account": acc,
  1091. }
  1092. except TimeoutError:
  1093. sess["status"] = "timeout"
  1094. raise HTTPException(status_code=408, detail="Authorization timeout (5 minutes)")
  1095. except httpx.HTTPError as e:
  1096. sess["status"] = "error"
  1097. sess["error"] = str(e)
  1098. raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
  1099. # ------------------------------------------------------------------------------
  1100. # Accounts Management API
  1101. # ------------------------------------------------------------------------------
  1102. @app.post("/v2/accounts")
  1103. async def create_account(body: AccountCreate, _: bool = Depends(verify_admin_password)):
  1104. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1105. acc_id = str(uuid.uuid4())
  1106. other_str = json.dumps(body.other, ensure_ascii=False) if body.other is not None else None
  1107. enabled_val = 1 if (body.enabled is None or body.enabled) else 0
  1108. await _db.execute(
  1109. """
  1110. INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
  1111. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  1112. """,
  1113. (
  1114. acc_id,
  1115. body.label,
  1116. body.clientId,
  1117. body.clientSecret,
  1118. body.refreshToken,
  1119. body.accessToken,
  1120. other_str,
  1121. None,
  1122. "never",
  1123. now,
  1124. now,
  1125. enabled_val,
  1126. None, # expires_at - will be set on first refresh
  1127. ),
  1128. )
  1129. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
  1130. return _row_to_dict(row)
  1131. async def _verify_and_enable_accounts(account_ids: List[str]):
  1132. """后台异步验证并启用账号"""
  1133. for acc_id in account_ids:
  1134. try:
  1135. # 必须先获取完整的账号信息
  1136. account = await get_account(acc_id)
  1137. verify_success, fail_reason = await verify_account(account)
  1138. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1139. if verify_success:
  1140. await _db.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, acc_id))
  1141. elif fail_reason:
  1142. other_dict = account.get("other", {}) or {}
  1143. other_dict['failedReason'] = fail_reason
  1144. await _db.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, acc_id))
  1145. except Exception as e:
  1146. print(f"Error verifying account {acc_id}: {e}")
  1147. traceback.print_exc()
  1148. @app.post("/v2/accounts/feed")
  1149. async def create_accounts_feed(request: BatchAccountCreate, _: bool = Depends(verify_admin_password)):
  1150. """
  1151. 统一的投喂接口,接收账号列表,立即存入并后台异步验证。
  1152. """
  1153. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1154. new_account_ids = []
  1155. for i, account_data in enumerate(request.accounts):
  1156. acc_id = str(uuid.uuid4())
  1157. other_dict = account_data.other or {}
  1158. other_dict['source'] = 'feed'
  1159. other_str = json.dumps(other_dict, ensure_ascii=False)
  1160. await _db.execute(
  1161. """
  1162. INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
  1163. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  1164. """,
  1165. (
  1166. acc_id,
  1167. account_data.label or f"批量账号 {i+1}",
  1168. account_data.clientId,
  1169. account_data.clientSecret,
  1170. account_data.refreshToken,
  1171. account_data.accessToken,
  1172. other_str,
  1173. None,
  1174. "never",
  1175. now,
  1176. now,
  1177. 0, # 初始为禁用状态
  1178. None, # expires_at - will be set on first refresh
  1179. ),
  1180. )
  1181. new_account_ids.append(acc_id)
  1182. # 启动后台任务进行验证,不阻塞当前请求
  1183. if new_account_ids:
  1184. asyncio.create_task(_verify_and_enable_accounts(new_account_ids))
  1185. return {
  1186. "status": "processing",
  1187. "message": f"{len(new_account_ids)} accounts received and are being verified in the background.",
  1188. "account_ids": new_account_ids
  1189. }
  1190. @app.get("/v2/accounts")
  1191. async def list_accounts(_: bool = Depends(verify_admin_password), enabled: Optional[bool] = None, sort_by: str = "created_at", sort_order: str = "desc"):
  1192. query = "SELECT * FROM accounts"
  1193. params = []
  1194. if enabled is not None:
  1195. query += " WHERE enabled=?"
  1196. params.append(1 if enabled else 0)
  1197. sort_field = "created_at" if sort_by not in ["created_at", "success_count"] else sort_by
  1198. order = "DESC" if sort_order.lower() == "desc" else "ASC"
  1199. query += f" ORDER BY {sort_field} {order}"
  1200. rows = await _db.fetchall(query, tuple(params) if params else ())
  1201. accounts = [_row_to_dict(r) for r in rows]
  1202. return {"accounts": accounts, "count": len(accounts)}
  1203. @app.get("/v2/accounts/{account_id}")
  1204. async def get_account_detail(account_id: str, _: bool = Depends(verify_admin_password)):
  1205. return await get_account(account_id)
  1206. @app.delete("/v2/accounts/{account_id}")
  1207. async def delete_account(account_id: str, _: bool = Depends(verify_admin_password)):
  1208. rowcount = await _db.execute("DELETE FROM accounts WHERE id=?", (account_id,))
  1209. if rowcount == 0:
  1210. raise HTTPException(status_code=404, detail="Account not found")
  1211. return {"deleted": account_id}
  1212. @app.patch("/v2/accounts/{account_id}")
  1213. async def update_account(account_id: str, body: AccountUpdate, _: bool = Depends(verify_admin_password)):
  1214. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1215. fields = []
  1216. values: List[Any] = []
  1217. if body.label is not None:
  1218. fields.append("label=?"); values.append(body.label)
  1219. if body.clientId is not None:
  1220. fields.append("clientId=?"); values.append(body.clientId)
  1221. if body.clientSecret is not None:
  1222. fields.append("clientSecret=?"); values.append(body.clientSecret)
  1223. if body.refreshToken is not None:
  1224. fields.append("refreshToken=?"); values.append(body.refreshToken)
  1225. if body.accessToken is not None:
  1226. fields.append("accessToken=?"); values.append(body.accessToken)
  1227. if body.other is not None:
  1228. fields.append("other=?"); values.append(json.dumps(body.other, ensure_ascii=False))
  1229. if body.enabled is not None:
  1230. fields.append("enabled=?"); values.append(1 if body.enabled else 0)
  1231. if not fields:
  1232. return await get_account(account_id)
  1233. fields.append("updated_at=?"); values.append(now)
  1234. values.append(account_id)
  1235. rowcount = await _db.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", tuple(values))
  1236. if rowcount == 0:
  1237. raise HTTPException(status_code=404, detail="Account not found")
  1238. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  1239. return _row_to_dict(row)
  1240. @app.post("/v2/accounts/{account_id}/refresh")
  1241. async def manual_refresh(account_id: str, _: bool = Depends(verify_admin_password)):
  1242. return await refresh_access_token_in_db(account_id)
  1243. @app.post("/v2/chat/test")
  1244. async def admin_chat_test(req: ChatCompletionRequest, account_id: Optional[str] = None, _: bool = Depends(verify_admin_password)):
  1245. """Admin chat test - uses admin auth, selects account by id or random."""
  1246. if account_id:
  1247. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  1248. if not row:
  1249. raise HTTPException(status_code=404, detail="Account not found")
  1250. account = _row_to_dict(row)
  1251. # Check if token is expired or missing
  1252. expires_at = account.get("expires_at")
  1253. now_str = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1254. if not expires_at or expires_at <= now_str:
  1255. account = await refresh_access_token_in_db(account_id)
  1256. else:
  1257. candidates = await _list_enabled_accounts()
  1258. if not candidates:
  1259. raise HTTPException(status_code=503, detail="No enabled account available")
  1260. account = random.choice(candidates)
  1261. return await chat_completions(req, account)
  1262. # ------------------------------------------------------------------------------
  1263. # Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
  1264. # ------------------------------------------------------------------------------
  1265. # Frontend inline HTML removed; serving ./frontend/index.html instead (see route below)
  1266. # Note: This route is NOT protected - the HTML file is served freely,
  1267. # but the frontend JavaScript checks authentication and redirects to /login if needed.
  1268. # All API endpoints remain protected.
  1269. @app.get("/", response_class=FileResponse)
  1270. def index():
  1271. path = BASE_DIR / "frontend" / "index.html"
  1272. if not path.exists():
  1273. raise HTTPException(status_code=404, detail="frontend/index.html not found")
  1274. return FileResponse(str(path))
  1275. # ------------------------------------------------------------------------------
  1276. # Health
  1277. # ------------------------------------------------------------------------------
  1278. @app.get("/healthz")
  1279. async def health():
  1280. return {"status": "ok"}
  1281. # ------------------------------------------------------------------------------
  1282. # Startup / Shutdown Events
  1283. # ------------------------------------------------------------------------------
  1284. # async def _verify_disabled_accounts_loop():
  1285. # """后台验证禁用账号任务"""
  1286. # while True:
  1287. # try:
  1288. # await asyncio.sleep(1800)
  1289. # async with _conn() as conn:
  1290. # accounts = await _list_disabled_accounts(conn)
  1291. # if accounts:
  1292. # for account in accounts:
  1293. # other = account.get('other')
  1294. # if other:
  1295. # try:
  1296. # other_dict = json.loads(other) if isinstance(other, str) else other
  1297. # if other_dict.get('failedReason') == 'AccessDenied':
  1298. # continue
  1299. # except:
  1300. # pass
  1301. # try:
  1302. # verify_success, fail_reason = await verify_account(account)
  1303. # now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1304. # if verify_success:
  1305. # await conn.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, account['id']))
  1306. # elif fail_reason:
  1307. # other_dict = {}
  1308. # if account.get('other'):
  1309. # try:
  1310. # other_dict = json.loads(account['other']) if isinstance(account['other'], str) else account['other']
  1311. # except:
  1312. # pass
  1313. # other_dict['failedReason'] = fail_reason
  1314. # await conn.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, account['id']))
  1315. # await conn.commit()
  1316. # except Exception:
  1317. # pass
  1318. # except Exception:
  1319. # pass
  1320. @app.on_event("startup")
  1321. async def startup_event():
  1322. """Initialize database and start background tasks on startup."""
  1323. await _init_global_client()
  1324. await _ensure_db()
  1325. asyncio.create_task(_refresh_stale_tokens())
  1326. asyncio.create_task(_recycle_global_client()) # 启动连接回收任务
  1327. # asyncio.create_task(_verify_disabled_accounts_loop())
  1328. @app.on_event("shutdown")
  1329. async def shutdown_event():
  1330. await _close_global_client()
  1331. await close_db()