|
|
@@ -15,9 +15,10 @@ from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, Fil
|
|
|
from pydantic import BaseModel
|
|
|
from dotenv import load_dotenv
|
|
|
import httpx
|
|
|
-import aiosqlite
|
|
|
import tiktoken
|
|
|
|
|
|
+from db import get_database_backend, init_db, close_db, row_to_dict
|
|
|
+
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Tokenizer
|
|
|
# ------------------------------------------------------------------------------
|
|
|
@@ -39,7 +40,6 @@ def count_tokens(text: str) -> int:
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
BASE_DIR = Path(__file__).resolve().parent
|
|
|
-DB_PATH = BASE_DIR / "data.sqlite3"
|
|
|
|
|
|
load_dotenv(BASE_DIR / ".env")
|
|
|
|
|
|
@@ -170,63 +170,20 @@ async def _close_global_client():
|
|
|
GLOBAL_CLIENT = None
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
-# SQLite helpers
|
|
|
+# Database helpers
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
-async def _ensure_db():
|
|
|
- DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
|
|
- async with aiosqlite.connect(DB_PATH) as conn:
|
|
|
- await conn.execute("PRAGMA journal_mode=WAL;")
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- CREATE TABLE IF NOT EXISTS accounts (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- label TEXT,
|
|
|
- clientId TEXT,
|
|
|
- clientSecret TEXT,
|
|
|
- refreshToken TEXT,
|
|
|
- accessToken TEXT,
|
|
|
- other TEXT,
|
|
|
- last_refresh_time TEXT,
|
|
|
- last_refresh_status TEXT,
|
|
|
- created_at TEXT,
|
|
|
- updated_at TEXT
|
|
|
- )
|
|
|
- """
|
|
|
- )
|
|
|
- # add columns if missing
|
|
|
- try:
|
|
|
- async with conn.execute("PRAGMA table_info(accounts)") as cursor:
|
|
|
- rows = await cursor.fetchall()
|
|
|
- cols = [row[1] for row in rows]
|
|
|
- if "enabled" not in cols:
|
|
|
- await conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
|
|
|
- if "error_count" not in cols:
|
|
|
- await conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
|
|
|
- if "success_count" not in cols:
|
|
|
- await conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- await conn.commit()
|
|
|
+# Database backend instance (initialized on startup)
|
|
|
+_db = None
|
|
|
|
|
|
-def _conn() -> aiosqlite.Connection:
|
|
|
- """Create a new database connection. Must be used with async with."""
|
|
|
- return aiosqlite.connect(DB_PATH)
|
|
|
+async def _ensure_db():
|
|
|
+ """Initialize database backend."""
|
|
|
+ global _db
|
|
|
+ _db = await init_db()
|
|
|
|
|
|
-def _row_to_dict(r: aiosqlite.Row) -> Dict[str, Any]:
|
|
|
- d = dict(r)
|
|
|
- if d.get("other"):
|
|
|
- try:
|
|
|
- d["other"] = json.loads(d["other"])
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- # normalize enabled to bool
|
|
|
- if "enabled" in d and d["enabled"] is not None:
|
|
|
- try:
|
|
|
- d["enabled"] = bool(int(d["enabled"]))
|
|
|
- except Exception:
|
|
|
- d["enabled"] = bool(d["enabled"])
|
|
|
- return d
|
|
|
+def _row_to_dict(r: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
+ """Convert database row to dict with JSON parsing."""
|
|
|
+ return row_to_dict(r)
|
|
|
|
|
|
# _ensure_db() will be called in startup event
|
|
|
|
|
|
@@ -239,31 +196,28 @@ async def _refresh_stale_tokens():
|
|
|
try:
|
|
|
await asyncio.sleep(300) # 5 minutes
|
|
|
now = time.time()
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1") as cursor:
|
|
|
- rows = await cursor.fetchall()
|
|
|
- for row in rows:
|
|
|
- acc_id, last_refresh = row[0], row[1]
|
|
|
- should_refresh = False
|
|
|
- if not last_refresh or last_refresh == "never":
|
|
|
+ rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")
|
|
|
+ for row in rows:
|
|
|
+ acc_id, last_refresh = row['id'], row['last_refresh_time']
|
|
|
+ should_refresh = False
|
|
|
+ if not last_refresh or last_refresh == "never":
|
|
|
+ should_refresh = True
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
|
|
|
+ if now - last_time > 1500: # 25 minutes
|
|
|
should_refresh = True
|
|
|
- else:
|
|
|
- try:
|
|
|
- last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
|
|
|
- if now - last_time > 1500: # 25 minutes
|
|
|
- should_refresh = True
|
|
|
- except Exception:
|
|
|
- # Malformed or unparsable timestamp; force refresh
|
|
|
- should_refresh = True
|
|
|
-
|
|
|
- if should_refresh:
|
|
|
- try:
|
|
|
- await refresh_access_token_in_db(acc_id)
|
|
|
- except Exception:
|
|
|
- traceback.print_exc()
|
|
|
- # Ignore per-account refresh failure; timestamp/status are recorded inside
|
|
|
- pass
|
|
|
+ except Exception:
|
|
|
+ # Malformed or unparsable timestamp; force refresh
|
|
|
+ should_refresh = True
|
|
|
+
|
|
|
+ if should_refresh:
|
|
|
+ try:
|
|
|
+ await refresh_access_token_in_db(acc_id)
|
|
|
+ except Exception:
|
|
|
+ traceback.print_exc()
|
|
|
+ # Ignore per-account refresh failure; timestamp/status are recorded inside
|
|
|
+ pass
|
|
|
except Exception:
|
|
|
traceback.print_exc()
|
|
|
pass
|
|
|
@@ -301,17 +255,13 @@ def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
|
|
|
return token_header.split(" ", 1)[1].strip()
|
|
|
return token_header.strip()
|
|
|
|
|
|
-async def _list_enabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC") as cursor:
|
|
|
- rows = await cursor.fetchall()
|
|
|
- return [_row_to_dict(r) for r in rows]
|
|
|
+async def _list_enabled_accounts() -> List[Dict[str, Any]]:
|
|
|
+ rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC")
|
|
|
+ return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
-async def _list_disabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC") as cursor:
|
|
|
- rows = await cursor.fetchall()
|
|
|
- return [_row_to_dict(r) for r in rows]
|
|
|
+async def _list_disabled_accounts() -> List[Dict[str, Any]]:
|
|
|
+ rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC")
|
|
|
+ return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
async def verify_account(account: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
|
|
|
"""验证账号可用性"""
|
|
|
@@ -349,11 +299,10 @@ async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
|
|
|
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
|
|
|
|
|
# Selection: random among enabled accounts
|
|
|
- async with _conn() as conn:
|
|
|
- candidates = await _list_enabled_accounts(conn)
|
|
|
- if not candidates:
|
|
|
- raise HTTPException(status_code=401, detail="No enabled account available")
|
|
|
- return random.choice(candidates)
|
|
|
+ candidates = await _list_enabled_accounts()
|
|
|
+ if not candidates:
|
|
|
+ raise HTTPException(status_code=401, detail="No enabled account available")
|
|
|
+ return random.choice(candidates)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Pydantic Schemas
|
|
|
@@ -406,114 +355,101 @@ def _oidc_headers() -> Dict[str, str]:
|
|
|
}
|
|
|
|
|
|
async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- if not row:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- acc = _row_to_dict(row)
|
|
|
-
|
|
|
- if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
|
|
|
- raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
|
|
|
-
|
|
|
- payload = {
|
|
|
- "grantType": "refresh_token",
|
|
|
- "clientId": acc["clientId"],
|
|
|
- "clientSecret": acc["clientSecret"],
|
|
|
- "refreshToken": acc["refreshToken"],
|
|
|
- }
|
|
|
+ row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
|
|
|
+ if not row:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ acc = _row_to_dict(row)
|
|
|
+
|
|
|
+ if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
|
|
|
+ raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "grantType": "refresh_token",
|
|
|
+ "clientId": acc["clientId"],
|
|
|
+ "clientSecret": acc["clientSecret"],
|
|
|
+ "refreshToken": acc["refreshToken"],
|
|
|
+ }
|
|
|
|
|
|
- try:
|
|
|
- # Use global client if available, else fallback (though global should be ready)
|
|
|
- client = GLOBAL_CLIENT
|
|
|
- if not client:
|
|
|
- # Fallback for safety
|
|
|
- async with httpx.AsyncClient(timeout=60.0) as temp_client:
|
|
|
- r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
- r.raise_for_status()
|
|
|
- data = r.json()
|
|
|
- else:
|
|
|
- r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
+ try:
|
|
|
+ # Use global client if available, else fallback (though global should be ready)
|
|
|
+ client = GLOBAL_CLIENT
|
|
|
+ if not client:
|
|
|
+ # Fallback for safety
|
|
|
+ async with httpx.AsyncClient(timeout=60.0) as temp_client:
|
|
|
+ r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
r.raise_for_status()
|
|
|
data = r.json()
|
|
|
+ else:
|
|
|
+ r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
|
|
|
- new_access = data.get("accessToken")
|
|
|
- new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
- now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
- status = "success"
|
|
|
- except httpx.HTTPError as e:
|
|
|
- now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
- status = "failed"
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- UPDATE accounts
|
|
|
- SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
- WHERE id=?
|
|
|
- """,
|
|
|
- (now, status, now, account_id),
|
|
|
- )
|
|
|
- await conn.commit()
|
|
|
- # 记录刷新失败次数
|
|
|
- await _update_stats(account_id, False)
|
|
|
- raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
|
|
|
- except Exception as e:
|
|
|
- # Ensure last_refresh_time is recorded even on unexpected errors
|
|
|
- now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
- status = "failed"
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- UPDATE accounts
|
|
|
- SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
- WHERE id=?
|
|
|
- """,
|
|
|
- (now, status, now, account_id),
|
|
|
- )
|
|
|
- await conn.commit()
|
|
|
- # 记录刷新失败次数
|
|
|
- await _update_stats(account_id, False)
|
|
|
- raise
|
|
|
-
|
|
|
- await conn.execute(
|
|
|
+ new_access = data.get("accessToken")
|
|
|
+ new_refresh = data.get("refreshToken", acc.get("refreshToken"))
|
|
|
+ now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
+ status = "success"
|
|
|
+ except httpx.HTTPError as e:
|
|
|
+ now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
+ status = "failed"
|
|
|
+ await _db.execute(
|
|
|
+ """
|
|
|
+ UPDATE accounts
|
|
|
+ SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
+ WHERE id=?
|
|
|
+ """,
|
|
|
+ (now, status, now, account_id),
|
|
|
+ )
|
|
|
+ # 记录刷新失败次数
|
|
|
+ await _update_stats(account_id, False)
|
|
|
+ raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
|
|
|
+ except Exception as e:
|
|
|
+ # Ensure last_refresh_time is recorded even on unexpected errors
|
|
|
+ now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
+ status = "failed"
|
|
|
+ await _db.execute(
|
|
|
"""
|
|
|
UPDATE accounts
|
|
|
- SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
+ SET last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
WHERE id=?
|
|
|
""",
|
|
|
- (new_access, new_refresh, now, status, now, account_id),
|
|
|
+ (now, status, now, account_id),
|
|
|
)
|
|
|
- await conn.commit()
|
|
|
+ # 记录刷新失败次数
|
|
|
+ await _update_stats(account_id, False)
|
|
|
+ raise
|
|
|
+
|
|
|
+ await _db.execute(
|
|
|
+ """
|
|
|
+ UPDATE accounts
|
|
|
+ SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
|
|
|
+ WHERE id=?
|
|
|
+ """,
|
|
|
+ (new_access, new_refresh, now, status, now, account_id),
|
|
|
+ )
|
|
|
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
- row2 = await cursor.fetchone()
|
|
|
- return _row_to_dict(row2)
|
|
|
+ row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
|
|
|
+ return _row_to_dict(row2)
|
|
|
|
|
|
async def get_account(account_id: str) -> Dict[str, Any]:
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- if not row:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- return _row_to_dict(row)
|
|
|
+ row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
|
|
|
+ if not row:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
async def _update_stats(account_id: str, success: bool) -> None:
|
|
|
- async with _conn() as conn:
|
|
|
- if success:
|
|
|
- await conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
|
|
|
- (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
- else:
|
|
|
- async with conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- if row:
|
|
|
- new_count = (row[0] or 0) + 1
|
|
|
- if new_count >= MAX_ERROR_COUNT:
|
|
|
- await conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
|
|
|
- (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
- else:
|
|
|
- await conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
|
|
|
- (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
- await conn.commit()
|
|
|
+ if success:
|
|
|
+ await _db.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
|
|
|
+ (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
+ else:
|
|
|
+ row = await _db.fetchone("SELECT error_count FROM accounts WHERE id=?", (account_id,))
|
|
|
+ if row:
|
|
|
+ new_count = (row['error_count'] or 0) + 1
|
|
|
+ if new_count >= MAX_ERROR_COUNT:
|
|
|
+ await _db.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
|
|
|
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
+ else:
|
|
|
+ await _db.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
|
|
|
+ (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
# Dependencies
|
|
|
@@ -914,32 +850,28 @@ async def _create_account_from_tokens(
|
|
|
) -> Dict[str, Any]:
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
acc_id = str(uuid.uuid4())
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
- """,
|
|
|
- (
|
|
|
- acc_id,
|
|
|
- label,
|
|
|
- client_id,
|
|
|
- client_secret,
|
|
|
- refresh_token,
|
|
|
- access_token,
|
|
|
- None,
|
|
|
- now,
|
|
|
- "success",
|
|
|
- now,
|
|
|
- now,
|
|
|
- 1 if enabled else 0,
|
|
|
- ),
|
|
|
- )
|
|
|
- await conn.commit()
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ await _db.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ """,
|
|
|
+ (
|
|
|
+ acc_id,
|
|
|
+ label,
|
|
|
+ client_id,
|
|
|
+ client_secret,
|
|
|
+ refresh_token,
|
|
|
+ access_token,
|
|
|
+ None,
|
|
|
+ now,
|
|
|
+ "success",
|
|
|
+ now,
|
|
|
+ now,
|
|
|
+ 1 if enabled else 0,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
# 管理控制台相关端点 - 仅在启用时注册
|
|
|
if CONSOLE_ENABLED:
|
|
|
@@ -1056,32 +988,28 @@ if CONSOLE_ENABLED:
|
|
|
acc_id = str(uuid.uuid4())
|
|
|
other_str = json.dumps(body.other, ensure_ascii=False) if body.other is not None else None
|
|
|
enabled_val = 1 if (body.enabled is None or body.enabled) else 0
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
- """,
|
|
|
- (
|
|
|
- acc_id,
|
|
|
- body.label,
|
|
|
- body.clientId,
|
|
|
- body.clientSecret,
|
|
|
- body.refreshToken,
|
|
|
- body.accessToken,
|
|
|
- other_str,
|
|
|
- None,
|
|
|
- "never",
|
|
|
- now,
|
|
|
- now,
|
|
|
- enabled_val,
|
|
|
- ),
|
|
|
- )
|
|
|
- await conn.commit()
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ await _db.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ """,
|
|
|
+ (
|
|
|
+ acc_id,
|
|
|
+ body.label,
|
|
|
+ body.clientId,
|
|
|
+ body.clientSecret,
|
|
|
+ body.refreshToken,
|
|
|
+ body.accessToken,
|
|
|
+ other_str,
|
|
|
+ None,
|
|
|
+ "never",
|
|
|
+ now,
|
|
|
+ now,
|
|
|
+ enabled_val,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
|
|
|
async def _verify_and_enable_accounts(account_ids: List[str]):
|
|
|
@@ -1092,15 +1020,13 @@ if CONSOLE_ENABLED:
|
|
|
account = await get_account(acc_id)
|
|
|
verify_success, fail_reason = await verify_account(account)
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
-
|
|
|
- async with _conn() as conn:
|
|
|
- if verify_success:
|
|
|
- await conn.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, acc_id))
|
|
|
- elif fail_reason:
|
|
|
- other_dict = account.get("other", {})
|
|
|
- other_dict['failedReason'] = fail_reason
|
|
|
- await conn.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, acc_id))
|
|
|
- await conn.commit()
|
|
|
+
|
|
|
+ if verify_success:
|
|
|
+ await _db.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, acc_id))
|
|
|
+ elif fail_reason:
|
|
|
+ other_dict = account.get("other", {}) or {}
|
|
|
+ other_dict['failedReason'] = fail_reason
|
|
|
+ await _db.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, acc_id))
|
|
|
except Exception as e:
|
|
|
print(f"Error verifying account {acc_id}: {e}")
|
|
|
traceback.print_exc()
|
|
|
@@ -1112,36 +1038,34 @@ if CONSOLE_ENABLED:
|
|
|
"""
|
|
|
now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
|
|
|
new_account_ids = []
|
|
|
-
|
|
|
- async with _conn() as conn:
|
|
|
- for i, account_data in enumerate(request.accounts):
|
|
|
- acc_id = str(uuid.uuid4())
|
|
|
- other_dict = account_data.other or {}
|
|
|
- other_dict['source'] = 'feed'
|
|
|
- other_str = json.dumps(other_dict, ensure_ascii=False)
|
|
|
-
|
|
|
- await conn.execute(
|
|
|
- """
|
|
|
- INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
- """,
|
|
|
- (
|
|
|
- acc_id,
|
|
|
- account_data.label or f"批量账号 {i+1}",
|
|
|
- account_data.clientId,
|
|
|
- account_data.clientSecret,
|
|
|
- account_data.refreshToken,
|
|
|
- account_data.accessToken,
|
|
|
- other_str,
|
|
|
- None,
|
|
|
- "never",
|
|
|
- now,
|
|
|
- now,
|
|
|
- 0, # 初始为禁用状态
|
|
|
- ),
|
|
|
- )
|
|
|
- new_account_ids.append(acc_id)
|
|
|
- await conn.commit()
|
|
|
+
|
|
|
+ for i, account_data in enumerate(request.accounts):
|
|
|
+ acc_id = str(uuid.uuid4())
|
|
|
+ other_dict = account_data.other or {}
|
|
|
+ other_dict['source'] = 'feed'
|
|
|
+ other_str = json.dumps(other_dict, ensure_ascii=False)
|
|
|
+
|
|
|
+ await _db.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ """,
|
|
|
+ (
|
|
|
+ acc_id,
|
|
|
+ account_data.label or f"批量账号 {i+1}",
|
|
|
+ account_data.clientId,
|
|
|
+ account_data.clientSecret,
|
|
|
+ account_data.refreshToken,
|
|
|
+ account_data.accessToken,
|
|
|
+ other_str,
|
|
|
+ None,
|
|
|
+ "never",
|
|
|
+ now,
|
|
|
+ now,
|
|
|
+ 0, # 初始为禁用状态
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ new_account_ids.append(acc_id)
|
|
|
|
|
|
# 启动后台任务进行验证,不阻塞当前请求
|
|
|
if new_account_ids:
|
|
|
@@ -1155,11 +1079,8 @@ if CONSOLE_ENABLED:
|
|
|
|
|
|
@app.get("/v2/accounts")
|
|
|
async def list_accounts():
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- async with conn.execute("SELECT * FROM accounts ORDER BY created_at DESC") as cursor:
|
|
|
- rows = await cursor.fetchall()
|
|
|
- return [_row_to_dict(r) for r in rows]
|
|
|
+ rows = await _db.fetchall("SELECT * FROM accounts ORDER BY created_at DESC")
|
|
|
+ return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
@app.get("/v2/accounts/{account_id}")
|
|
|
async def get_account_detail(account_id: str):
|
|
|
@@ -1167,12 +1088,10 @@ if CONSOLE_ENABLED:
|
|
|
|
|
|
@app.delete("/v2/accounts/{account_id}")
|
|
|
async def delete_account(account_id: str):
|
|
|
- async with _conn() as conn:
|
|
|
- cur = await conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
|
|
|
- await conn.commit()
|
|
|
- if cur.rowcount == 0:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- return {"deleted": account_id}
|
|
|
+ rowcount = await _db.execute("DELETE FROM accounts WHERE id=?", (account_id,))
|
|
|
+ if rowcount == 0:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ return {"deleted": account_id}
|
|
|
|
|
|
@app.patch("/v2/accounts/{account_id}")
|
|
|
async def update_account(account_id: str, body: AccountUpdate):
|
|
|
@@ -1201,15 +1120,11 @@ if CONSOLE_ENABLED:
|
|
|
fields.append("updated_at=?"); values.append(now)
|
|
|
values.append(account_id)
|
|
|
|
|
|
- async with _conn() as conn:
|
|
|
- conn.row_factory = aiosqlite.Row
|
|
|
- cur = await conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
|
|
|
- await conn.commit()
|
|
|
- if cur.rowcount == 0:
|
|
|
- raise HTTPException(status_code=404, detail="Account not found")
|
|
|
- async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
|
|
|
- row = await cursor.fetchone()
|
|
|
- return _row_to_dict(row)
|
|
|
+ rowcount = await _db.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", tuple(values))
|
|
|
+ if rowcount == 0:
|
|
|
+ raise HTTPException(status_code=404, detail="Account not found")
|
|
|
+ row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
|
|
|
+ return _row_to_dict(row)
|
|
|
|
|
|
@app.post("/v2/accounts/{account_id}/refresh")
|
|
|
async def manual_refresh(account_id: str):
|
|
|
@@ -1287,4 +1202,5 @@ async def startup_event():
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
|
async def shutdown_event():
|
|
|
- await _close_global_client()
|
|
|
+ await _close_global_client()
|
|
|
+ await close_db()
|