app.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436
  1. import os
  2. import json
  3. import traceback
  4. import uuid
  5. import time
  6. import asyncio
  7. import importlib.util
  8. import random
  9. import secrets
  10. import re
  11. from datetime import datetime, timedelta
  12. from pathlib import Path
  13. from typing import Dict, Optional, List, Any, AsyncGenerator, Tuple
  14. from fastapi import FastAPI, Depends, HTTPException, Header
  15. from fastapi.middleware.cors import CORSMiddleware
  16. from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
  17. from pydantic import BaseModel
  18. from dotenv import load_dotenv
  19. import httpx
  20. import tiktoken
  21. from db import init_db, close_db, row_to_dict
  22. # ------------------------------------------------------------------------------
  23. # Tokenizer
  24. # ------------------------------------------------------------------------------
  25. try:
  26. # cl100k_base is used by gpt-4, gpt-3.5-turbo, text-embedding-ada-002
  27. ENCODING = tiktoken.get_encoding("cl100k_base")
  28. except Exception:
  29. ENCODING = None
  30. def count_tokens(text: str, apply_multiplier: bool = False) -> int:
  31. """Counts tokens with tiktoken."""
  32. if not text or not ENCODING:
  33. return 0
  34. token_count = len(ENCODING.encode(text))
  35. if apply_multiplier:
  36. token_count = int(token_count * TOKEN_COUNT_MULTIPLIER)
  37. return token_count
  38. # ------------------------------------------------------------------------------
  39. # Bootstrap
  40. # ------------------------------------------------------------------------------
  41. BASE_DIR = Path(__file__).resolve().parent
  42. load_dotenv(BASE_DIR / ".env")
  43. app = FastAPI(title="v2 OpenAI-compatible Server (Amazon Q Backend)")
  44. # CORS for simple testing in browser
  45. app.add_middleware(
  46. CORSMiddleware,
  47. allow_origins=["*"],
  48. allow_methods=["*"],
  49. allow_headers=["*"],
  50. )
  51. # ------------------------------------------------------------------------------
  52. # Dynamic import of replicate.py to avoid package __init__ needs
  53. # ------------------------------------------------------------------------------
  54. def _load_replicate_module():
  55. mod_path = BASE_DIR / "replicate.py"
  56. spec = importlib.util.spec_from_file_location("v2_replicate", str(mod_path))
  57. module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
  58. assert spec is not None and spec.loader is not None
  59. spec.loader.exec_module(module) # type: ignore[attr-defined]
  60. return module
  61. _replicate = _load_replicate_module()
  62. send_chat_request = _replicate.send_chat_request
  63. # ------------------------------------------------------------------------------
  64. # Dynamic import of Claude modules
  65. # ------------------------------------------------------------------------------
  66. def _load_claude_modules():
  67. # claude_types
  68. spec_types = importlib.util.spec_from_file_location("v2_claude_types", str(BASE_DIR / "claude_types.py"))
  69. mod_types = importlib.util.module_from_spec(spec_types)
  70. spec_types.loader.exec_module(mod_types)
  71. # claude_converter
  72. spec_conv = importlib.util.spec_from_file_location("v2_claude_converter", str(BASE_DIR / "claude_converter.py"))
  73. mod_conv = importlib.util.module_from_spec(spec_conv)
  74. # We need to inject claude_types into converter's namespace if it uses relative imports or expects them
  75. # But since we used relative import in claude_converter.py (.claude_types), we need to be careful.
  76. # Actually, since we are loading dynamically, relative imports might fail if not in sys.modules correctly.
  77. # Let's patch sys.modules temporarily or just rely on file location.
  78. # A simpler way for this single-file script style is to just load them.
  79. # However, claude_converter does `from .claude_types import ...`
  80. # To make that work, we should probably just use standard import if v2 is a package,
  81. # but v2 is just a folder.
  82. # Let's assume the user runs this with v2 in pythonpath or we just fix imports in the files.
  83. # But I wrote `from .claude_types` in the file.
  84. # Let's try to load it. If it fails, we might need to adjust.
  85. # Actually, for simplicity in this `app.py` dynamic loading context,
  86. # it is better if `claude_converter.py` used absolute import or we mock the package.
  87. # BUT, let's try to just load them and see.
  88. # To avoid relative import issues, I will inject the module into sys.modules
  89. import sys
  90. sys.modules["v2.claude_types"] = mod_types
  91. spec_conv.loader.exec_module(mod_conv)
  92. # claude_stream
  93. spec_stream = importlib.util.spec_from_file_location("v2_claude_stream", str(BASE_DIR / "claude_stream.py"))
  94. mod_stream = importlib.util.module_from_spec(spec_stream)
  95. spec_stream.loader.exec_module(mod_stream)
  96. return mod_types, mod_conv, mod_stream
  97. try:
  98. _claude_types, _claude_converter, _claude_stream = _load_claude_modules()
  99. ClaudeRequest = _claude_types.ClaudeRequest
  100. convert_claude_to_amazonq_request = _claude_converter.convert_claude_to_amazonq_request
  101. map_model_name = _claude_converter.map_model_name
  102. ClaudeStreamHandler = _claude_stream.ClaudeStreamHandler
  103. except Exception as e:
  104. print(f"Failed to load Claude modules: {e}")
  105. traceback.print_exc()
  106. # Define dummy classes to avoid NameError on startup if loading fails
  107. class ClaudeRequest(BaseModel):
  108. pass
  109. convert_claude_to_amazonq_request = None
  110. map_model_name = lambda m: m # Pass through if module fails to load
  111. ClaudeStreamHandler = None
  112. # ------------------------------------------------------------------------------
  113. # Global HTTP Client
  114. # ------------------------------------------------------------------------------
  115. GLOBAL_CLIENT: Optional[httpx.AsyncClient] = None
  116. def _get_proxies() -> Optional[Dict[str, str]]:
  117. proxy = os.getenv("HTTP_PROXY", "").strip()
  118. if proxy:
  119. return {"http": proxy, "https": proxy}
  120. return None
  121. async def _init_global_client():
  122. global GLOBAL_CLIENT
  123. proxies = _get_proxies()
  124. mounts = None
  125. if proxies:
  126. proxy_url = proxies.get("https") or proxies.get("http")
  127. if proxy_url:
  128. mounts = {
  129. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  130. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  131. }
  132. # Increased limits for high concurrency with streaming
  133. # max_connections: 总连接数上限
  134. # max_keepalive_connections: 保持活跃的连接数
  135. # keepalive_expiry: 连接保持时间
  136. limits = httpx.Limits(
  137. max_keepalive_connections=500,
  138. max_connections=500, # 提高到500以支持更高并发
  139. keepalive_expiry=1.0 # 30秒后释放空闲连接
  140. )
  141. # 为流式响应设置更长的超时
  142. timeout = httpx.Timeout(
  143. connect=2.0, # 连接超时
  144. read=300.0, # 读取超时(流式响应需要更长时间)
  145. write=2.0, # 写入超时
  146. pool=1.0 # 从连接池获取连接的超时时间(关键!)
  147. )
  148. GLOBAL_CLIENT = httpx.AsyncClient(mounts=mounts, timeout=timeout, limits=limits)
  149. async def _close_global_client():
  150. global GLOBAL_CLIENT
  151. if GLOBAL_CLIENT:
  152. await GLOBAL_CLIENT.aclose()
  153. GLOBAL_CLIENT = None
  154. # ------------------------------------------------------------------------------
  155. # Database helpers
  156. # ------------------------------------------------------------------------------
  157. # Database backend instance (initialized on startup)
  158. _db = None
  159. async def _ensure_db():
  160. """Initialize database backend."""
  161. global _db
  162. _db = await init_db()
  163. def _row_to_dict(r: Dict[str, Any]) -> Dict[str, Any]:
  164. """Convert database row to dict with JSON parsing."""
  165. return row_to_dict(r)
  166. # _ensure_db() will be called in startup event
  167. # ------------------------------------------------------------------------------
# Background token refresh loop (asyncio coroutine)
  169. # ------------------------------------------------------------------------------
  170. async def _refresh_stale_tokens():
  171. while True:
  172. try:
  173. await asyncio.sleep(300) # 5 minutes
  174. if _db is None:
  175. print("[Error] Database not initialized, skipping token refresh cycle.")
  176. continue
  177. now = time.time()
  178. if LAZY_ACCOUNT_POOL_ENABLED:
  179. limit = LAZY_ACCOUNT_POOL_SIZE + LAZY_ACCOUNT_POOL_REFRESH_OFFSET
  180. order_direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
  181. query = f"SELECT id, last_refresh_time FROM accounts WHERE enabled=1 ORDER BY {LAZY_ACCOUNT_POOL_ORDER_BY} {order_direction} LIMIT {limit}"
  182. rows = await _db.fetchall(query)
  183. else:
  184. rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")
  185. for row in rows:
  186. acc_id, last_refresh = row['id'], row['last_refresh_time']
  187. should_refresh = False
  188. if not last_refresh or last_refresh == "never":
  189. should_refresh = True
  190. else:
  191. try:
  192. last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
  193. if now - last_time > 1500: # 25 minutes
  194. should_refresh = True
  195. except Exception:
  196. # Malformed or unparsable timestamp; force refresh
  197. should_refresh = True
  198. if should_refresh:
  199. try:
  200. await refresh_access_token_in_db(acc_id)
  201. except Exception:
  202. traceback.print_exc()
  203. # Ignore per-account refresh failure; timestamp/status are recorded inside
  204. pass
  205. except Exception:
  206. traceback.print_exc()
  207. pass
  208. # ------------------------------------------------------------------------------
  209. # Env and API Key authorization (keys are independent of AWS accounts)
  210. # ------------------------------------------------------------------------------
  211. def _parse_allowed_keys_env() -> List[str]:
  212. """
  213. OPENAI_KEYS is a comma-separated whitelist of API keys for authorization only.
  214. Example: OPENAI_KEYS="key1,key2,key3"
  215. - When the list is non-empty, incoming Authorization: Bearer {key} must be one of them.
  216. - When empty or unset, authorization is effectively disabled (dev mode).
  217. """
  218. s = os.getenv("OPENAI_KEYS", "") or ""
  219. keys: List[str] = []
  220. for k in [x.strip() for x in s.split(",") if x.strip()]:
  221. keys.append(k)
  222. return keys
  223. ALLOWED_API_KEYS: List[str] = _parse_allowed_keys_env()
  224. MAX_ERROR_COUNT: int = int(os.getenv("MAX_ERROR_COUNT", "100"))
  225. TOKEN_COUNT_MULTIPLIER: float = float(os.getenv("TOKEN_COUNT_MULTIPLIER", "1.0"))
  226. # Lazy Account Pool settings
  227. LAZY_ACCOUNT_POOL_ENABLED: bool = os.getenv("LAZY_ACCOUNT_POOL_ENABLED", "false").lower() in ("true", "1", "yes")
  228. LAZY_ACCOUNT_POOL_SIZE: int = int(os.getenv("LAZY_ACCOUNT_POOL_SIZE", "20"))
  229. LAZY_ACCOUNT_POOL_REFRESH_OFFSET: int = int(os.getenv("LAZY_ACCOUNT_POOL_REFRESH_OFFSET", "10"))
  230. LAZY_ACCOUNT_POOL_ORDER_BY: str = os.getenv("LAZY_ACCOUNT_POOL_ORDER_BY", "created_at")
  231. LAZY_ACCOUNT_POOL_ORDER_DESC: bool = os.getenv("LAZY_ACCOUNT_POOL_ORDER_DESC", "false").lower() in ("true", "1", "yes")
  232. # Validate LAZY_ACCOUNT_POOL_ORDER_BY to prevent SQL injection
  233. if LAZY_ACCOUNT_POOL_ORDER_BY not in ["created_at", "id", "success_count"]:
  234. LAZY_ACCOUNT_POOL_ORDER_BY = "created_at"
  235. def _is_console_enabled() -> bool:
  236. """检查是否启用管理控制台"""
  237. console_env = os.getenv("ENABLE_CONSOLE", "true").strip().lower()
  238. return console_env not in ("false", "0", "no", "disabled")
  239. CONSOLE_ENABLED: bool = _is_console_enabled()
  240. # Admin authentication configuration
  241. ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD", "admin")
  242. def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
  243. if not token_header:
  244. return None
  245. if token_header.startswith("Bearer "):
  246. return token_header.split(" ", 1)[1].strip()
  247. return token_header.strip()
  248. async def _list_enabled_accounts(limit: Optional[int] = None) -> List[Dict[str, Any]]:
  249. if LAZY_ACCOUNT_POOL_ENABLED:
  250. order_direction = "DESC" if LAZY_ACCOUNT_POOL_ORDER_DESC else "ASC"
  251. query = f"SELECT * FROM accounts WHERE enabled=1 ORDER BY {LAZY_ACCOUNT_POOL_ORDER_BY} {order_direction}"
  252. if limit:
  253. query += f" LIMIT {limit}"
  254. rows = await _db.fetchall(query)
  255. else:
  256. query = "SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC"
  257. if limit:
  258. query += f" LIMIT {limit}"
  259. rows = await _db.fetchall(query)
  260. return [_row_to_dict(r) for r in rows]
  261. async def _list_disabled_accounts() -> List[Dict[str, Any]]:
  262. rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC")
  263. return [_row_to_dict(r) for r in rows]
  264. async def verify_account(account: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
  265. """验证账号可用性"""
  266. try:
  267. account = await refresh_access_token_in_db(account['id'])
  268. test_request = {
  269. "conversationState": {
  270. "currentMessage": {"userInputMessage": {"content": "hello"}},
  271. "chatTriggerType": "MANUAL"
  272. }
  273. }
  274. _, _, tracker, event_gen = await send_chat_request(
  275. access_token=account['accessToken'],
  276. messages=[],
  277. stream=True,
  278. raw_payload=test_request
  279. )
  280. if event_gen:
  281. async for _ in event_gen:
  282. break
  283. return True, None
  284. except Exception as e:
  285. if "AccessDenied" in str(e) or "403" in str(e):
  286. return False, "AccessDenied"
  287. return False, None
  288. async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
  289. """
  290. Authorize request by OPENAI_KEYS (if configured), then select an AWS account.
  291. Selection strategy: random among all enabled accounts. Authorization key does NOT map to any account.
  292. """
  293. # Authorization
  294. if ALLOWED_API_KEYS:
  295. if not bearer_key or bearer_key not in ALLOWED_API_KEYS:
  296. raise HTTPException(status_code=401, detail="Invalid or missing API key")
  297. # Selection: random among enabled accounts
  298. if LAZY_ACCOUNT_POOL_ENABLED:
  299. candidates = await _list_enabled_accounts(limit=LAZY_ACCOUNT_POOL_SIZE)
  300. else:
  301. candidates = await _list_enabled_accounts()
  302. if not candidates:
  303. raise HTTPException(status_code=401, detail="No enabled account available")
  304. return random.choice(candidates)
  305. # ------------------------------------------------------------------------------
  306. # Pydantic Schemas
  307. # ------------------------------------------------------------------------------
  308. class AccountCreate(BaseModel):
  309. label: Optional[str] = None
  310. clientId: str
  311. clientSecret: str
  312. refreshToken: Optional[str] = None
  313. accessToken: Optional[str] = None
  314. other: Optional[Dict[str, Any]] = None
  315. enabled: Optional[bool] = True
  316. class BatchAccountCreate(BaseModel):
  317. accounts: List[AccountCreate]
  318. class AccountUpdate(BaseModel):
  319. label: Optional[str] = None
  320. clientId: Optional[str] = None
  321. clientSecret: Optional[str] = None
  322. refreshToken: Optional[str] = None
  323. accessToken: Optional[str] = None
  324. other: Optional[Dict[str, Any]] = None
  325. enabled: Optional[bool] = None
  326. class ChatMessage(BaseModel):
  327. role: str
  328. content: Any
  329. class ChatCompletionRequest(BaseModel):
  330. model: Optional[str] = None
  331. messages: List[ChatMessage]
  332. stream: Optional[bool] = False
  333. # ------------------------------------------------------------------------------
  334. # Token refresh (OIDC)
  335. # ------------------------------------------------------------------------------
  336. OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  337. TOKEN_URL = f"{OIDC_BASE}/token"
  338. def _oidc_headers() -> Dict[str, str]:
  339. return {
  340. "content-type": "application/json",
  341. "user-agent": "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0",
  342. "x-amz-user-agent": "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI",
  343. "amz-sdk-request": "attempt=1; max=3",
  344. "amz-sdk-invocation-id": str(uuid.uuid4()),
  345. }
  346. async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
  347. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  348. if not row:
  349. raise HTTPException(status_code=404, detail="Account not found")
  350. acc = _row_to_dict(row)
  351. if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
  352. raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
  353. payload = {
  354. "grantType": "refresh_token",
  355. "clientId": acc["clientId"],
  356. "clientSecret": acc["clientSecret"],
  357. "refreshToken": acc["refreshToken"],
  358. }
  359. try:
  360. # Use global client if available, else fallback (though global should be ready)
  361. client = GLOBAL_CLIENT
  362. if not client:
  363. # Fallback for safety
  364. async with httpx.AsyncClient(timeout=60.0) as temp_client:
  365. r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
  366. r.raise_for_status()
  367. data = r.json()
  368. else:
  369. r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
  370. r.raise_for_status()
  371. data = r.json()
  372. new_access = data.get("accessToken")
  373. new_refresh = data.get("refreshToken", acc.get("refreshToken"))
  374. expires_in = data.get("expiresIn", 3600) # Default 1 hour if not provided
  375. expires_at = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(time.time() + expires_in))
  376. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  377. status = "success"
  378. except httpx.HTTPError as e:
  379. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  380. status = "failed"
  381. await _db.execute(
  382. """
  383. UPDATE accounts
  384. SET last_refresh_time=?, last_refresh_status=?, updated_at=?
  385. WHERE id=?
  386. """,
  387. (now, status, now, account_id),
  388. )
  389. # 记录刷新失败次数
  390. await _update_stats(account_id, False)
  391. raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
  392. except Exception as e:
  393. # Ensure last_refresh_time is recorded even on unexpected errors
  394. now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  395. status = "failed"
  396. await _db.execute(
  397. """
  398. UPDATE accounts
  399. SET last_refresh_time=?, last_refresh_status=?, updated_at=?
  400. WHERE id=?
  401. """,
  402. (now, status, now, account_id),
  403. )
  404. # 记录刷新失败次数
  405. await _update_stats(account_id, False)
  406. raise
  407. await _db.execute(
  408. """
  409. UPDATE accounts
  410. SET accessToken=?, refreshToken=?, expires_at=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
  411. WHERE id=?
  412. """,
  413. (new_access, new_refresh, expires_at, now, status, now, account_id),
  414. )
  415. row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  416. return _row_to_dict(row2)
  417. async def get_account(account_id: str) -> Dict[str, Any]:
  418. row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
  419. if not row:
  420. raise HTTPException(status_code=404, detail="Account not found")
  421. return _row_to_dict(row)
  422. async def _update_stats(account_id: str, success: bool) -> None:
  423. if success:
  424. await _db.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
  425. (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  426. else:
  427. row = await _db.fetchone("SELECT error_count FROM accounts WHERE id=?", (account_id,))
  428. if row:
  429. new_count = (row['error_count'] or 0) + 1
  430. if new_count >= MAX_ERROR_COUNT:
  431. await _db.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
  432. (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  433. else:
  434. await _db.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
  435. (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
  436. # ------------------------------------------------------------------------------
  437. # Dependencies
  438. # ------------------------------------------------------------------------------
  439. async def require_account(
  440. authorization: Optional[str] = Header(default=None),
  441. x_api_key: Optional[str] = Header(default=None)
  442. ) -> Dict[str, Any]:
  443. key = _extract_bearer(authorization) if authorization else x_api_key
  444. return await resolve_account_for_key(key)
  445. def verify_admin_password(authorization: Optional[str] = Header(None)) -> bool:
  446. """Verify admin password for console access"""
  447. if not authorization or not authorization.startswith("Bearer "):
  448. raise HTTPException(
  449. status_code=401,
  450. detail={"error": "Unauthorized access", "code": "UNAUTHORIZED"}
  451. )
  452. password = authorization[7:] # Remove "Bearer " prefix
  453. if password != ADMIN_PASSWORD:
  454. raise HTTPException(
  455. status_code=401,
  456. detail={"error": "Invalid password", "code": "INVALID_PASSWORD"}
  457. )
  458. return True
  459. # ------------------------------------------------------------------------------
  460. # OpenAI-compatible Chat endpoint
  461. # ------------------------------------------------------------------------------
  462. def _openai_non_streaming_response(
  463. text: str,
  464. model: Optional[str],
  465. prompt_tokens: int = 0,
  466. completion_tokens: int = 0,
  467. ) -> Dict[str, Any]:
  468. created = int(time.time())
  469. return {
  470. "id": f"chatcmpl-{uuid.uuid4()}",
  471. "object": "chat.completion",
  472. "created": created,
  473. "model": model or "unknown",
  474. "choices": [
  475. {
  476. "index": 0,
  477. "message": {
  478. "role": "assistant",
  479. "content": text,
  480. },
  481. "finish_reason": "stop",
  482. }
  483. ],
  484. "usage": {
  485. "prompt_tokens": prompt_tokens,
  486. "completion_tokens": completion_tokens,
  487. "total_tokens": prompt_tokens + completion_tokens,
  488. },
  489. }
  490. def _sse_format(obj: Dict[str, Any]) -> str:
  491. return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
  492. @app.post("/v1/messages")
  493. async def claude_messages(
  494. req: ClaudeRequest,
  495. account: Dict[str, Any] = Depends(require_account),
  496. x_conversation_id: Optional[str] = Header(default=None, alias="x-conversation-id")
  497. ):
  498. """
  499. Claude-compatible messages endpoint.
  500. """
  501. # 1. Convert request
  502. # Always generate a new conversation_id like amq2api does
  503. # Using the same conversation_id can cause Amazon Q to return cached/stale data
  504. try:
  505. aq_request = convert_claude_to_amazonq_request(req, conversation_id=None)
  506. except Exception as e:
  507. traceback.print_exc()
  508. raise HTTPException(status_code=400, detail=f"Request conversion failed: {str(e)}")
  509. # Post-process history to fix message ordering (prevents infinite loops)
  510. from message_processor import process_claude_history_for_amazonq
  511. conversation_state = aq_request.get("conversationState", {})
  512. history = conversation_state.get("history", [])
  513. if history:
  514. processed_history = process_claude_history_for_amazonq(history)
  515. aq_request["conversationState"]["history"] = processed_history
  516. # Remove duplicate tail userInputMessage that matches currentMessage content
  517. # This prevents the model from repeatedly responding to the same user message
  518. conversation_state = aq_request.get("conversationState", {})
  519. current_msg = conversation_state.get("currentMessage", {}).get("userInputMessage", {})
  520. current_content = (current_msg.get("content") or "").strip()
  521. history = conversation_state.get("history", [])
  522. if history and current_content:
  523. last = history[-1]
  524. if "userInputMessage" in last:
  525. last_content = (last["userInputMessage"].get("content") or "").strip()
  526. if last_content and last_content == current_content:
  527. # Remove duplicate tail userInputMessage
  528. history = history[:-1]
  529. aq_request["conversationState"]["history"] = history
  530. import logging
  531. logging.getLogger(__name__).info("Removed duplicate tail userInputMessage to prevent repeated response")
  532. conversation_state = aq_request.get("conversationState", {})
  533. conversation_id = conversation_state.get("conversationId")
  534. response_headers: Dict[str, str] = {}
  535. if conversation_id:
  536. response_headers["x-conversation-id"] = conversation_id
  537. # Always stream from upstream to get full event details
  538. event_iter = None
  539. try:
  540. access = account.get("accessToken")
  541. if not access:
  542. refreshed = await refresh_access_token_in_db(account["id"])
  543. access = refreshed.get("accessToken")
  544. # We call with stream=True to get the event iterator
  545. _, _, tracker, event_iter = await send_chat_request(
  546. access_token=access,
  547. messages=[],
  548. model=map_model_name(req.model),
  549. stream=True,
  550. client=GLOBAL_CLIENT,
  551. raw_payload=aq_request
  552. )
  553. if not event_iter:
  554. raise HTTPException(status_code=502, detail="No event stream returned")
  555. # Handler
  556. # Calculate input tokens
  557. text_to_count = ""
  558. if req.system:
  559. if isinstance(req.system, str):
  560. text_to_count += req.system
  561. elif isinstance(req.system, list):
  562. for item in req.system:
  563. if isinstance(item, dict) and item.get("type") == "text":
  564. text_to_count += item.get("text", "")
  565. for msg in req.messages:
  566. if isinstance(msg.content, str):
  567. text_to_count += msg.content
  568. elif isinstance(msg.content, list):
  569. for item in msg.content:
  570. if isinstance(item, dict) and item.get("type") == "text":
  571. text_to_count += item.get("text", "")
  572. input_tokens = count_tokens(text_to_count, apply_multiplier=True)
  573. handler = ClaudeStreamHandler(model=req.model, input_tokens=input_tokens, conversation_id=conversation_id)
  574. # Try to get the first event to ensure the connection is valid
  575. # This allows us to return proper HTTP error codes before starting the stream
  576. first_event = None
  577. try:
  578. first_event = await event_iter.__anext__()
  579. except StopAsyncIteration:
  580. raise HTTPException(status_code=502, detail="Empty response from upstream")
  581. except Exception as e:
  582. # If we get an error before the first event, we can still return proper status code
  583. err_msg = str(e)
  584. # Extract upstream status code from "Upstream error {code}: {message}"
  585. if err_msg.startswith("Upstream error "):
  586. match = re.match(r"Upstream error (\d+):", err_msg)
  587. if match:
  588. raise HTTPException(status_code=int(match.group(1)), detail=err_msg)
  589. raise HTTPException(status_code=502, detail=f"Upstream error: {err_msg}")
  590. async def event_generator():
  591. try:
  592. # Process the first event we already fetched
  593. if first_event:
  594. event_type, payload = first_event
  595. async for sse in handler.handle_event(event_type, payload):
  596. yield sse
  597. # Process remaining events
  598. async for event_type, payload in event_iter:
  599. async for sse in handler.handle_event(event_type, payload):
  600. yield sse
  601. async for sse in handler.finish():
  602. yield sse
  603. await _update_stats(account["id"], True)
  604. except GeneratorExit:
  605. # Client disconnected - update stats but don't re-raise
  606. await _update_stats(account["id"], tracker.has_content if tracker else False)
  607. except Exception:
  608. await _update_stats(account["id"], False)
  609. raise
  610. if req.stream:
  611. return StreamingResponse(
  612. event_generator(),
  613. media_type="text/event-stream",
  614. headers=response_headers or None
  615. )
  616. else:
  617. # Accumulate for non-streaming
  618. # This is a bit complex because we need to reconstruct the full response object
  619. # For now, let's just support streaming as it's the main use case for Claude Code
  620. # But to be nice, let's try to support non-streaming by consuming the generator
  621. content_blocks = []
  622. usage = {"input_tokens": 0, "output_tokens": 0}
  623. stop_reason = None
  624. # We need to parse the SSE strings back to objects... inefficient but works
  625. # Or we could refactor handler to yield objects.
  626. # For now, let's just raise error for non-streaming or implement basic text
  627. # Claude Code uses streaming.
  628. # Let's implement a basic accumulator from the SSE stream
  629. final_content = []
  630. async for sse_chunk in event_generator():
  631. data_str = None
  632. # Each chunk from the generator can have multiple lines ('event:', 'data:').
  633. # We need to find the 'data:' line.
  634. for line in sse_chunk.strip().split('\n'):
  635. if line.startswith("data:"):
  636. data_str = line[6:].strip()
  637. break
  638. if not data_str or data_str == "[DONE]":
  639. continue
  640. try:
  641. data = json.loads(data_str)
  642. dtype = data.get("type")
  643. if dtype == "content_block_start":
  644. idx = data.get("index", 0)
  645. while len(final_content) <= idx:
  646. final_content.append(None)
  647. final_content[idx] = data.get("content_block")
  648. elif dtype == "content_block_delta":
  649. idx = data.get("index", 0)
  650. delta = data.get("delta", {})
  651. if final_content[idx]:
  652. if delta.get("type") == "text_delta":
  653. final_content[idx]["text"] += delta.get("text", "")
  654. elif delta.get("type") == "thinking_delta":
  655. final_content[idx].setdefault("thinking", "")
  656. final_content[idx]["thinking"] += delta.get("thinking", "")
  657. elif delta.get("type") == "input_json_delta":
  658. if "partial_json" not in final_content[idx]:
  659. final_content[idx]["partial_json"] = ""
  660. final_content[idx]["partial_json"] += delta.get("partial_json", "")
  661. elif dtype == "content_block_stop":
  662. idx = data.get("index", 0)
  663. if final_content[idx] and final_content[idx].get("type") == "tool_use":
  664. if "partial_json" in final_content[idx]:
  665. try:
  666. final_content[idx]["input"] = json.loads(final_content[idx]["partial_json"])
  667. except json.JSONDecodeError:
  668. # Keep partial if invalid
  669. final_content[idx]["input"] = {"error": "invalid json", "partial": final_content[idx]["partial_json"]}
  670. del final_content[idx]["partial_json"]
  671. elif dtype == "message_delta":
  672. usage = data.get("usage", usage)
  673. stop_reason = data.get("delta", {}).get("stop_reason")
  674. except json.JSONDecodeError:
  675. # Ignore lines that are not valid JSON
  676. pass
  677. except Exception:
  678. # Broad exception to prevent accumulator from crashing on one bad event
  679. traceback.print_exc()
  680. pass
  681. # Final assembly
  682. final_content_cleaned = []
  683. for c in final_content:
  684. if c is not None:
  685. # Remove internal state like 'partial_json' before returning
  686. c.pop("partial_json", None)
  687. final_content_cleaned.append(c)
  688. response_body = {
  689. "id": f"msg_{uuid.uuid4()}",
  690. "type": "message",
  691. "role": "assistant",
  692. "model": req.model,
  693. "content": final_content_cleaned,
  694. "stop_reason": stop_reason,
  695. "stop_sequence": None,
  696. "usage": usage
  697. }
  698. if conversation_id:
  699. response_body["conversation_id"] = conversation_id
  700. response_body["conversationId"] = conversation_id
  701. return JSONResponse(content=response_body, headers=response_headers or None)
  702. except Exception as e:
  703. # Ensure event_iter (if created) is closed to release upstream connection
  704. try:
  705. if event_iter and hasattr(event_iter, "aclose"):
  706. await event_iter.aclose()
  707. except Exception:
  708. pass
  709. await _update_stats(account["id"], False)
  710. # Extract upstream status code from "Upstream error {code}: {message}"
  711. err_msg = str(e)
  712. if err_msg.startswith("Upstream error "):
  713. match = re.match(r"Upstream error (\d+):", err_msg)
  714. if match:
  715. raise HTTPException(status_code=int(match.group(1)), detail=err_msg)
  716. raise
  @app.post("/v1/messages/count_tokens")
  async def count_tokens_endpoint(req: ClaudeRequest):
      """
      Estimate the input token count of a request without sending it upstream.

      The counted text is the concatenation, in order, of:
      the system prompt (a plain string, or the ``text`` blocks of a list),
      every message's content (same two shapes), and - when tools are given -
      the JSON dump of the tool definitions. The result of
      ``count_tokens(..., apply_multiplier=True)`` is returned under
      ``input_tokens``.
      """
      parts = []

      def _gather(content):
          # Strings are counted verbatim; lists contribute only their text blocks.
          if isinstance(content, str):
              parts.append(content)
          elif isinstance(content, list):
              parts.extend(
                  entry.get("text", "")
                  for entry in content
                  if isinstance(entry, dict) and entry.get("type") == "text"
              )

      # System prompt
      if req.system:
          _gather(req.system)

      # Conversation messages
      for message in req.messages:
          _gather(message.content)

      # Tool definitions, serialized the same way for model objects and raw dicts
      if req.tools:
          serializable_tools = [
              tool.model_dump() if hasattr(tool, 'model_dump') else tool
              for tool in req.tools
          ]
          parts.append(json.dumps(serializable_tools, ensure_ascii=False))

      return {"input_tokens": count_tokens("".join(parts), apply_multiplier=True)}
  @app.post("/v1/chat/completions")
  async def chat_completions(req: ChatCompletionRequest, account: Dict[str, Any] = Depends(require_account)):
      """
      OpenAI-compatible chat endpoint.

      - ``req.stream`` is coerced with ``bool()``, so streaming is off unless requested.
      - The account comes from the ``require_account`` dependency; if it has no
        ``accessToken`` the token is refreshed before the upstream call.
      - Non-streaming: returns a JSON completion built by
        ``_openai_non_streaming_response``.
      - Streaming: returns ``text/event-stream`` with a role chunk, content
        chunks, a final chunk carrying ``finish_reason="stop"`` plus usage,
        and the ``[DONE]`` sentinel.

      Account statistics are updated via ``_update_stats`` exactly once per
      request outcome.

      Raises:
          HTTPException(502): no access token after refresh, or the upstream
              call returned no stream for a streaming request.
      """
      model = map_model_name(req.model)
      do_stream = bool(req.stream)
      account_id = account["id"]

      async def _send_upstream(stream: bool) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], Any]:
          """Call upstream and return (text, text iterator, tracker)."""
          access = account.get("accessToken")
          if not access:
              refreshed = await refresh_access_token_in_db(account_id)
              access = refreshed.get("accessToken")
              if not access:
                  raise HTTPException(status_code=502, detail="Access token unavailable after refresh")
          # send_chat_request returns four values; the fourth (raw event stream)
          # is not used by the OpenAI-compatible endpoint.
          result = await send_chat_request(access, [m.model_dump() for m in req.messages], model=model, stream=stream, client=GLOBAL_CLIENT)
          return result[0], result[1], result[2]

      # Prompt tokens are derived from plain-string message contents only.
      prompt_text = "".join(m.content for m in req.messages if isinstance(m.content, str))
      prompt_tokens = count_tokens(prompt_text)

      if not do_stream:
          # Only the upstream call is guarded: a failure in later local steps
          # must not record a second (failed) stat for the same request.
          try:
              text, _, _ = await _send_upstream(stream=False)
          except Exception:
              await _update_stats(account_id, False)
              raise
          await _update_stats(account_id, bool(text))
          completion_tokens = count_tokens(text or "")
          return JSONResponse(content=_openai_non_streaming_response(
              text or "",
              model,
              prompt_tokens=prompt_tokens,
              completion_tokens=completion_tokens
          ))

      created = int(time.time())
      stream_id = f"chatcmpl-{uuid.uuid4()}"
      model_used = model or "unknown"

      def _chunk(delta: Dict[str, Any], finish_reason: Optional[str], usage: Optional[Dict[str, int]] = None) -> str:
          """Format one chat.completion.chunk as an SSE frame."""
          payload: Dict[str, Any] = {
              "id": stream_id,
              "object": "chat.completion.chunk",
              "created": created,
              "model": model_used,
              "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
          }
          if usage is not None:
              payload["usage"] = usage
          return _sse_format(payload)

      it = None
      try:
          _, it, tracker = await _send_upstream(stream=True)
          if it is None:
              # Explicit check instead of `assert`, which is stripped under -O.
              raise HTTPException(status_code=502, detail="No stream returned from upstream")

          async def event_gen() -> AsyncGenerator[str, None]:
              pieces: List[str] = []
              try:
                  # Send role first
                  yield _chunk({"role": "assistant"}, None)
                  # Stream content
                  async for piece in it:
                      if piece:
                          pieces.append(piece)
                          yield _chunk({"content": piece}, None)
                  # Send stop and usage
                  completion_tokens = count_tokens("".join(pieces))
                  yield _chunk({}, "stop", usage={
                      "prompt_tokens": prompt_tokens,
                      "completion_tokens": completion_tokens,
                      "total_tokens": prompt_tokens + completion_tokens,
                  })
                  yield "data: [DONE]\n\n"
                  await _update_stats(account_id, True)
              except GeneratorExit:
                  # Client disconnected - update stats but don't re-raise
                  await _update_stats(account_id, tracker.has_content if tracker else False)
              except Exception:
                  await _update_stats(account_id, tracker.has_content if tracker else False)
                  raise
              finally:
                  # Best-effort release of the upstream iterator on every exit path.
                  if hasattr(it, "aclose"):
                      try:
                          await it.aclose()
                      except Exception:
                          pass

          return StreamingResponse(event_gen(), media_type="text/event-stream")
      except Exception:
          # Failure before the response was handed off: close the iterator
          # (if created) to release the upstream connection, then record it.
          try:
              if it is not None and hasattr(it, "aclose"):
                  await it.aclose()
          except Exception:
              pass
          await _update_stats(account_id, False)
          raise
  850. # ------------------------------------------------------------------------------
  851. # Device Authorization (URL Login, 5-minute timeout)
  852. # ------------------------------------------------------------------------------
  853. # Dynamic import of auth_flow.py (device-code login helpers)
  def _load_auth_flow_module():
      """
      Load ``auth_flow.py`` from ``BASE_DIR`` as the module ``v2_auth_flow``.

      Returns:
          The executed module object.

      Raises:
          ImportError: if no import spec or loader can be created for the path.
          Any exception raised while executing the module (e.g. a missing
          file) propagates unchanged.
      """
      mod_path = BASE_DIR / "auth_flow.py"
      spec = importlib.util.spec_from_file_location("v2_auth_flow", str(mod_path))
      # Validate before use: module_from_spec() would otherwise fail with an
      # unrelated AttributeError when the spec is missing.
      if spec is None or spec.loader is None:
          raise ImportError(f"Cannot create import spec for {mod_path}")
      module = importlib.util.module_from_spec(spec)
      spec.loader.exec_module(module)
      return module
  # Load auth_flow.py once at import time and expose its helpers under local names.
  _auth_flow = _load_auth_flow_module()
  register_client_min, device_authorize, poll_token_device_code = (
      _auth_flow.register_client_min,
      _auth_flow.device_authorize,
      _auth_flow.poll_token_device_code,
  )

  # In-memory auth sessions (ephemeral), keyed by auth id
  AUTH_SESSIONS: Dict[str, Dict[str, Any]] = {}
  class AuthStartBody(BaseModel):
      # Request body for starting a device-authorization session.
      # label: stored on the session and later used for the created account.
      label: Optional[str] = None
      # enabled: an explicit None is treated as True by the start endpoint.
      enabled: Optional[bool] = True
  class AdminLoginRequest(BaseModel):
      # Request body for the admin login endpoint: the password to check.
      password: str
  class AdminLoginResponse(BaseModel):
      # Result of an admin login attempt.
      # success: True when the supplied password was accepted.
      success: bool
      # message: human-readable outcome text.
      message: str
  async def _create_account_from_tokens(
      client_id: str,
      client_secret: str,
      access_token: str,
      refresh_token: Optional[str],
      label: Optional[str],
      enabled: bool,
  ) -> Dict[str, Any]:
      """
      Insert a new account row built from freshly obtained tokens.

      The row gets a new UUID, the current UTC time as refresh/created/updated
      timestamp, refresh status "success", no ``other`` payload and no
      ``expires_at``. The stored row is read back and returned as a dict.
      """
      timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_id = str(uuid.uuid4())
      insert_sql = """
          INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
          VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
          """
      row_values = (
          new_id,
          label,
          client_id,
          client_secret,
          refresh_token,
          access_token,
          None,                # other
          timestamp,           # last_refresh_time
          "success",           # last_refresh_status
          timestamp,           # created_at
          timestamp,           # updated_at
          1 if enabled else 0,
          None,                # expires_at - will be set on first refresh
      )
      await _db.execute(insert_sql, row_values)
      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (new_id,))
      return _row_to_dict(stored)
  # Admin console endpoints - registered only when the console is enabled
  909. if CONSOLE_ENABLED:
  910. # ------------------------------------------------------------------------------
  911. # Admin Authentication Endpoints
  912. # ------------------------------------------------------------------------------
  @app.post("/api/login", response_model=AdminLoginResponse)
  async def admin_login(request: AdminLoginRequest) -> AdminLoginResponse:
      """
      Admin login endpoint - password only.

      Compares the submitted password with ``ADMIN_PASSWORD`` and reports the
      outcome in the response body (the HTTP status is not changed on failure).
      The comparison is constant-time so response timing does not reveal how
      much of the password matched.
      """
      import hmac  # local import keeps this block self-contained

      configured = ADMIN_PASSWORD
      # A non-string configured value can never equal the submitted string,
      # which mirrors the result of a plain `==` comparison.
      # "surrogatepass" keeps odd-but-valid JSON strings from raising on encode.
      authorized = isinstance(configured, str) and hmac.compare_digest(
          request.password.encode("utf-8", "surrogatepass"),
          configured.encode("utf-8", "surrogatepass"),
      )
      if authorized:
          return AdminLoginResponse(
              success=True,
              message="Login successful"
          )
      return AdminLoginResponse(
          success=False,
          message="Invalid password"
      )
  @app.get("/login", response_class=FileResponse)
  def login_page():
      """Return frontend/login.html from BASE_DIR, or 404 when the file is absent."""
      html_file = BASE_DIR / "frontend" / "login.html"
      if html_file.exists():
          return FileResponse(str(html_file))
      raise HTTPException(status_code=404, detail="frontend/login.html not found")
  933. # ------------------------------------------------------------------------------
  934. # Device Authorization Endpoints
  935. # ------------------------------------------------------------------------------
  @app.post("/v2/auth/start")
  async def auth_start(body: AuthStartBody, _: bool = Depends(verify_admin_password)):
      """
      Begin a device-authorization login.

      Registers a client, requests a device code, stores a pending session in
      ``AUTH_SESSIONS`` under a fresh UUID and returns the verification URL,
      user code, expiry and polling interval together with that ``authId``.
      An ``httpx.HTTPError`` from either upstream call becomes a 502.
      """
      try:
          client_id, client_secret = await register_client_min()
          device = await device_authorize(client_id, client_secret)
      except httpx.HTTPError as e:
          raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")

      auth_id = str(uuid.uuid4())
      session = {
          "clientId": client_id,
          "clientSecret": client_secret,
          "deviceCode": device.get("deviceCode"),
          "interval": int(device.get("interval", 1)),
          "expiresIn": int(device.get("expiresIn", 600)),
          "verificationUriComplete": device.get("verificationUriComplete"),
          "userCode": device.get("userCode"),
          "startTime": int(time.time()),
          "label": body.label,
          # An explicit null for `enabled` means "enabled".
          "enabled": True if body.enabled is None else bool(body.enabled),
          "status": "pending",
          "error": None,
          "accountId": None,
      }
      AUTH_SESSIONS[auth_id] = session

      # Only the fields the caller needs to drive the login are echoed back.
      response = {"authId": auth_id}
      for key in ("verificationUriComplete", "userCode", "expiresIn", "interval"):
          response[key] = session[key]
      return response
  @app.get("/v2/auth/status/{auth_id}")
  async def auth_status(auth_id: str, _: bool = Depends(verify_admin_password)):
      """
      Report the state of an auth session.

      ``remaining`` is the number of seconds left until the session deadline,
      where the lifetime is the session's ``expiresIn`` capped at 300 seconds.
      Unknown session ids yield a 404.
      """
      session = AUTH_SESSIONS.get(auth_id)
      if not session:
          raise HTTPException(status_code=404, detail="Auth session not found")

      current = int(time.time())
      lifetime = min(int(session.get("expiresIn", 600)), 300)
      seconds_left = max(0, session["startTime"] + lifetime - current)

      report = {key: session.get(key) for key in ("status", "error", "accountId")}
      return {
          "status": report["status"],
          "remaining": seconds_left,
          "error": report["error"],
          "accountId": report["accountId"],
      }
  @app.post("/v2/auth/claim/{auth_id}")
  async def auth_claim(auth_id: str, _: bool = Depends(verify_admin_password)):
      """
      Block up to 5 minutes to exchange the device code for tokens after the
      user completed login, then create an account from those tokens.

      A session that already reached a terminal state ("completed", "timeout",
      "error") is answered from the stored state without polling again.

      Raises:
          HTTPException(404): unknown session id.
          HTTPException(408): polling timed out.
          HTTPException(502): OIDC HTTP error, or no accessToken was returned.
      """
      sess = AUTH_SESSIONS.get(auth_id)
      if not sess:
          raise HTTPException(status_code=404, detail="Auth session not found")
      if sess.get("status") in ("completed", "timeout", "error"):
          return {
              "status": sess["status"],
              "accountId": sess.get("accountId"),
              "error": sess.get("error"),
          }
      try:
          toks = await poll_token_device_code(
              sess["clientId"],
              sess["clientSecret"],
              sess["deviceCode"],
              sess["interval"],
              sess["expiresIn"],
              max_timeout_sec=300,  # 5 minutes
          )
          access_token = toks.get("accessToken")
          refresh_token = toks.get("refreshToken")
          if not access_token:
              # Record the failure on the session as well; otherwise it stays
              # "pending" and the next claim would poll with a spent device code.
              detail = "No accessToken returned from OIDC"
              sess["status"] = "error"
              sess["error"] = detail
              raise HTTPException(status_code=502, detail=detail)
          acc = await _create_account_from_tokens(
              sess["clientId"],
              sess["clientSecret"],
              access_token,
              refresh_token,
              sess.get("label"),
              sess.get("enabled", True),
          )
          sess["status"] = "completed"
          sess["accountId"] = acc["id"]
          return {
              "status": "completed",
              "account": acc,
          }
      except TimeoutError:
          sess["status"] = "timeout"
          raise HTTPException(status_code=408, detail="Authorization timeout (5 minutes)")
      except httpx.HTTPError as e:
          sess["status"] = "error"
          sess["error"] = str(e)
          raise HTTPException(status_code=502, detail=f"OIDC error: {str(e)}")
  1034. # ------------------------------------------------------------------------------
  1035. # Accounts Management API
  1036. # ------------------------------------------------------------------------------
  @app.post("/v2/accounts")
  async def create_account(body: AccountCreate, _: bool = Depends(verify_admin_password)):
      """
      Create an account row from the request body and return the stored row.

      ``other`` is stored as JSON text (or NULL when absent), the refresh
      status starts as "never", and a missing ``enabled`` means enabled.
      """
      timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_id = str(uuid.uuid4())
      other_json = None if body.other is None else json.dumps(body.other, ensure_ascii=False)
      enabled_flag = 1 if (body.enabled is None or body.enabled) else 0
      insert_sql = """
          INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
          VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
          """
      row_values = (
          new_id,
          body.label,
          body.clientId,
          body.clientSecret,
          body.refreshToken,
          body.accessToken,
          other_json,
          None,          # last_refresh_time
          "never",       # last_refresh_status
          timestamp,     # created_at
          timestamp,     # updated_at
          enabled_flag,
          None,          # expires_at - will be set on first refresh
      )
      await _db.execute(insert_sql, row_values)
      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (new_id,))
      return _row_to_dict(stored)
  async def _verify_and_enable_accounts(account_ids: List[str]):
      """
      Verify each account in turn and enable the ones that pass.

      A failed verification with a reason stores that reason under
      ``failedReason`` in the account's ``other`` data. Errors for one account
      are printed and do not stop processing of the remaining ids.
      """
      for account_id in account_ids:
          try:
              # The full account record is required by the verifier.
              account = await get_account(account_id)
              passed, reason = await verify_account(account)
              stamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
              if passed:
                  await _db.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (stamp, account_id))
                  continue
              if not reason:
                  continue
              meta = account.get("other", {}) or {}
              meta['failedReason'] = reason
              await _db.execute(
                  "UPDATE accounts SET other=?, updated_at=? WHERE id=?",
                  (json.dumps(meta, ensure_ascii=False), stamp, account_id),
              )
          except Exception as e:
              print(f"Error verifying account {account_id}: {e}")
              traceback.print_exc()
  # Strong references to in-flight verification tasks. The event loop keeps
  # only weak references to tasks, so an unreferenced task may be
  # garbage-collected before it finishes.
  _FEED_VERIFY_TASKS: set = set()


  @app.post("/v2/accounts/feed")
  async def create_accounts_feed(request: BatchAccountCreate, _: bool = Depends(verify_admin_password)):
      """
      Unified feed endpoint: store the submitted accounts immediately
      (disabled) and verify them asynchronously in the background.

      Every stored account gets ``source: "feed"`` added to its ``other``
      data. The response lists the ids of the created accounts; verification
      results are not part of the response.
      """
      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      new_account_ids = []
      for position, account_data in enumerate(request.accounts, start=1):
          acc_id = str(uuid.uuid4())
          # Copy so the request model's own dict is not mutated.
          other_dict = dict(account_data.other or {})
          other_dict['source'] = 'feed'
          other_str = json.dumps(other_dict, ensure_ascii=False)
          await _db.execute(
              """
              INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
              VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
              """,
              (
                  acc_id,
                  account_data.label or f"批量账号 {position}",
                  account_data.clientId,
                  account_data.clientSecret,
                  account_data.refreshToken,
                  account_data.accessToken,
                  other_str,
                  None,
                  "never",
                  now,
                  now,
                  0,  # starts disabled until verification succeeds
                  None,  # expires_at - will be set on first refresh
              ),
          )
          new_account_ids.append(acc_id)

      # Start verification in the background without blocking this request.
      if new_account_ids:
          task = asyncio.create_task(_verify_and_enable_accounts(new_account_ids))
          _FEED_VERIFY_TASKS.add(task)
          task.add_done_callback(_FEED_VERIFY_TASKS.discard)

      return {
          "status": "processing",
          "message": f"{len(new_account_ids)} accounts received and are being verified in the background.",
          "account_ids": new_account_ids
      }
  @app.get("/v2/accounts")
  async def list_accounts(_: bool = Depends(verify_admin_password), enabled: Optional[bool] = None, sort_by: str = "created_at", sort_order: str = "desc"):
      """
      List accounts, optionally filtered by enabled state.

      ``sort_by`` is restricted to "created_at" or "success_count" (anything
      else falls back to "created_at"); ``sort_order`` is descending only for
      "desc" (case-insensitive) and ascending otherwise. Because both values
      come from these fixed choices, they are safe to place in the SQL text.
      """
      sql_parts = ["SELECT * FROM accounts"]
      bind_values = []
      if enabled is not None:
          sql_parts.append(" WHERE enabled=?")
          bind_values.append(1 if enabled else 0)

      column = sort_by if sort_by in ["created_at", "success_count"] else "created_at"
      direction = "ASC" if sort_order.lower() != "desc" else "DESC"
      sql_parts.append(f" ORDER BY {column} {direction}")

      rows = await _db.fetchall("".join(sql_parts), tuple(bind_values))
      accounts = list(map(_row_to_dict, rows))
      return {"accounts": accounts, "count": len(accounts)}
  @app.get("/v2/accounts/{account_id}")
  async def get_account_detail(account_id: str, _: bool = Depends(verify_admin_password)):
      """Return the account identified by ``account_id`` via ``get_account``."""
      account = await get_account(account_id)
      return account
  @app.delete("/v2/accounts/{account_id}")
  async def delete_account(account_id: str, _: bool = Depends(verify_admin_password)):
      """Delete one account; respond 404 when the delete reports zero rows."""
      affected = await _db.execute("DELETE FROM accounts WHERE id=?", (account_id,))
      if affected != 0:
          return {"deleted": account_id}
      raise HTTPException(status_code=404, detail="Account not found")
  @app.patch("/v2/accounts/{account_id}")
  async def update_account(account_id: str, body: AccountUpdate, _: bool = Depends(verify_admin_password)):
      """
      Partially update an account.

      Only fields that are not None in the body are written; ``other`` is
      stored as JSON text and ``enabled`` as 0/1. With nothing to change the
      current account is returned unchanged. A 404 is raised when the UPDATE
      reports zero rows.
      """
      now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())

      # (column, value to store) in the fixed order the columns are written.
      candidates = (
          ("label", body.label),
          ("clientId", body.clientId),
          ("clientSecret", body.clientSecret),
          ("refreshToken", body.refreshToken),
          ("accessToken", body.accessToken),
          ("other", None if body.other is None else json.dumps(body.other, ensure_ascii=False)),
          ("enabled", None if body.enabled is None else (1 if body.enabled else 0)),
      )
      changes = [(column, value) for column, value in candidates if value is not None]
      if not changes:
          return await get_account(account_id)

      changes.append(("updated_at", now))
      assignments = ', '.join(f"{column}=?" for column, _value in changes)
      bind_values: List[Any] = [value for _column, value in changes]
      bind_values.append(account_id)

      rowcount = await _db.execute(f"UPDATE accounts SET {assignments} WHERE id=?", tuple(bind_values))
      if rowcount == 0:
          raise HTTPException(status_code=404, detail="Account not found")
      stored = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      return _row_to_dict(stored)
  @app.post("/v2/accounts/{account_id}/refresh")
  async def manual_refresh(account_id: str, _: bool = Depends(verify_admin_password)):
      """Refresh the account's access token and return the refresh result."""
      refreshed = await refresh_access_token_in_db(account_id)
      return refreshed
  @app.post("/v2/chat/test")
  async def admin_chat_test(req: ChatCompletionRequest, account_id: Optional[str] = None, _: bool = Depends(verify_admin_password)):
      """
      Admin chat test - uses admin auth, selects account by id or random.

      With ``account_id`` the account is loaded (404 if missing) and its token
      refreshed when ``expires_at`` is unset or not later than the current UTC
      time. Without it, one enabled account is picked at random (503 if there
      are none). The request is then handled by ``chat_completions``.
      """
      if not account_id:
          pool = await _list_enabled_accounts()
          if not pool:
              raise HTTPException(status_code=503, detail="No enabled account available")
          return await chat_completions(req, random.choice(pool))

      row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
      if not row:
          raise HTTPException(status_code=404, detail="Account not found")
      account = _row_to_dict(row)

      # Timestamps are compared as strings in the same "%Y-%m-%dT%H:%M:%S" layout.
      expiry = account.get("expires_at")
      current = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
      token_is_stale = (not expiry) or expiry <= current
      if token_is_stale:
          account = await refresh_access_token_in_db(account_id)
      return await chat_completions(req, account)
  1197. # ------------------------------------------------------------------------------
  1198. # Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
  1199. # ------------------------------------------------------------------------------
  1200. # Frontend inline HTML removed; serving ./frontend/index.html instead (see route below)
  1201. # Note: This route is NOT protected - the HTML file is served freely,
  1202. # but the frontend JavaScript checks authentication and redirects to /login if needed.
  1203. # All API endpoints remain protected.
  @app.get("/", response_class=FileResponse)
  def index():
      """Return frontend/index.html from BASE_DIR, or 404 when the file is absent."""
      html_file = BASE_DIR / "frontend" / "index.html"
      if html_file.exists():
          return FileResponse(str(html_file))
      raise HTTPException(status_code=404, detail="frontend/index.html not found")
  1210. # ------------------------------------------------------------------------------
  1211. # Health
  1212. # ------------------------------------------------------------------------------
  @app.get("/healthz")
  async def health():
      """Health probe: always answers with a static "ok" status."""
      payload = {"status": "ok"}
      return payload
  1216. # ------------------------------------------------------------------------------
  1217. # Startup / Shutdown Events
  1218. # ------------------------------------------------------------------------------
  1219. # async def _verify_disabled_accounts_loop():
  1220. # """后台验证禁用账号任务"""
  1221. # while True:
  1222. # try:
  1223. # await asyncio.sleep(1800)
  1224. # async with _conn() as conn:
  1225. # accounts = await _list_disabled_accounts(conn)
  1226. # if accounts:
  1227. # for account in accounts:
  1228. # other = account.get('other')
  1229. # if other:
  1230. # try:
  1231. # other_dict = json.loads(other) if isinstance(other, str) else other
  1232. # if other_dict.get('failedReason') == 'AccessDenied':
  1233. # continue
  1234. # except:
  1235. # pass
  1236. # try:
  1237. # verify_success, fail_reason = await verify_account(account)
  1238. # now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
  1239. # if verify_success:
  1240. # await conn.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, account['id']))
  1241. # elif fail_reason:
  1242. # other_dict = {}
  1243. # if account.get('other'):
  1244. # try:
  1245. # other_dict = json.loads(account['other']) if isinstance(account['other'], str) else account['other']
  1246. # except:
  1247. # pass
  1248. # other_dict['failedReason'] = fail_reason
  1249. # await conn.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, account['id']))
  1250. # await conn.commit()
  1251. # except Exception:
  1252. # pass
  1253. # except Exception:
  1254. # pass
  # Strong references to tasks spawned at startup. The event loop keeps only
  # weak references to tasks, so an unreferenced task may be
  # garbage-collected before it finishes.
  _STARTUP_TASKS: set = set()


  @app.on_event("startup")
  async def startup_event():
      """Initialize the HTTP client and database, then start background tasks."""
      await _init_global_client()
      await _ensure_db()
      refresher = asyncio.create_task(_refresh_stale_tokens())
      _STARTUP_TASKS.add(refresher)
      refresher.add_done_callback(_STARTUP_TASKS.discard)
      # The disabled-account verification loop is currently not started.
  @app.on_event("shutdown")
  async def shutdown_event():
      """Close the global HTTP client and the database on shutdown."""
      try:
          await _close_global_client()
      finally:
          # Close the database even if closing the HTTP client failed.
          await close_db()