Ver Fonte

feat(db): Add PostgreSQL and MySQL support with SQLite fallback

- Create db.py abstraction layer supporting SQLite, PostgreSQL (asyncpg), and MySQL (aiomysql)
- Refactor app.py to use the new database module instead of direct SQLite calls
- Add DATABASE_URL environment variable for database selection
- Update all utility scripts to use the new async database interface
- Update requirements.txt with asyncpg and aiomysql dependencies
- Update README.md with new database configuration documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Zhi Yang há 2 meses atrás
pai
commit
1001430702
9 ficheiros alterados com 726 adições e 476 exclusões
  1. 5 0
      .env.example
  2. 11 4
      README.md
  3. 216 300
      app.py
  4. 371 0
      db.py
  5. 3 0
      requirements.txt
  6. 31 63
      scripts/account_stats.py
  7. 27 39
      scripts/delete_disabled_zero_success_accounts.py
  8. 40 36
      scripts/reset_accounts.py
  9. 22 34
      scripts/retry_failed_accounts.py

+ 5 - 0
.env.example

@@ -1,5 +1,10 @@
 # ============ 主服务配置 ============
 
+# 数据库连接URL(留空使用本地SQLite)
+# PostgreSQL: DATABASE_URL="postgres://user:password@host:5432/dbname?sslmode=require"
+# MySQL: DATABASE_URL="mysql://user:password@host:3306/dbname"
+DATABASE_URL=""
+
 # OpenAI 风格 API Key 白名单(仅用于授权,与账号无关)
 # 多个用逗号分隔,例如:OPENAI_KEYS="key1,key2,key3"
 # 留空则为开发模式,不校验 Authorization

+ 11 - 4
README.md

@@ -35,7 +35,7 @@
 ### 网络与安全
 - **HTTP 代理支持** - 可配置代理服务器,支持所有 HTTP 请求
 - **API Key 白名单** - 可选的访问控制,支持开发模式
-- **持久化存储** - SQLite 数据库存储账号信息
+- **持久化存储** - 支持 SQLite(默认)、PostgreSQL、MySQL 数据库
 
 ## 🚀 部署
 
@@ -90,6 +90,11 @@ cp .env.example .env
 **.env 配置说明:**
 
 ```bash
+# 数据库连接URL(留空使用本地SQLite)
+# PostgreSQL: DATABASE_URL="postgres://user:password@host:5432/dbname?sslmode=require"
+# MySQL: DATABASE_URL="mysql://user:password@host:3306/dbname"
+DATABASE_URL=""
+
 # OpenAI 风格 API Key 白名单(仅用于授权,与账号无关)
 # 多个用逗号分隔,例如:OPENAI_KEYS="key1,key2,key3"
 # 留空则为开发模式,不校验 Authorization
@@ -334,6 +339,7 @@ with client.messages.stream(
 ```
 v2/
 ├── app.py                          # FastAPI 主应用
+├── db.py                           # 数据库抽象层 (SQLite/PG/MySQL)
 ├── replicate.py                    # Amazon Q 请求复刻
 ├── auth_flow.py                    # 设备授权登录
 ├── claude_types.py                 # Claude API 类型定义
@@ -359,7 +365,7 @@ v2/
 ## 🛠️ 技术栈
 
 - **后端框架**: FastAPI + Python 3.11+
-- **数据库**: SQLite3 + aiosqlite
+- **数据库**: SQLite3 (aiosqlite) / PostgreSQL (asyncpg) / MySQL (aiomysql)
 - **HTTP 客户端**: httpx(支持异步和代理)
 - **Token 计数**: tiktoken
 - **前端**: 纯 HTML/CSS/JavaScript(无依赖)
@@ -371,6 +377,7 @@ v2/
 
 | 变量 | 说明 | 默认值 | 示例 |
 |------|------|--------|------|
+| `DATABASE_URL` | 数据库连接URL | 空(使用SQLite) | `"postgres://user:pass@host:5432/db"` |
 | `OPENAI_KEYS` | API Key 白名单(逗号分隔) | 空(开发模式) | `"key1,key2"` |
 | `MAX_ERROR_COUNT` | 错误次数阈值 | 100 | `50` |
 | `HTTP_PROXY` | HTTP代理地址 | 空 | `"http://127.0.0.1:7890"` |
@@ -478,8 +485,8 @@ server {
 
 1. **生产环境必须配置 `OPENAI_KEYS`**
 2. **使用 HTTPS 反向代理(Nginx + Let's Encrypt)**
-3. **定期备份 `data.sqlite3` 数据库**
-4. **限制数据库文件权限**(仅应用可读写)
+3. **定期备份数据库**(SQLite: `data.sqlite3`,或 PG/MySQL 数据库)
+4. **限制数据库访问权限**
 5. **配置防火墙规则,限制访问来源**
 
 ## 📄 许可证

+ 216 - 300
app.py

@@ -15,9 +15,10 @@ from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, Fil
 from pydantic import BaseModel
 from dotenv import load_dotenv
 import httpx
-import aiosqlite
 import tiktoken
 
+from db import get_database_backend, init_db, close_db, row_to_dict
+
 # ------------------------------------------------------------------------------
 # Tokenizer
 # ------------------------------------------------------------------------------
@@ -39,7 +40,6 @@ def count_tokens(text: str) -> int:
 # ------------------------------------------------------------------------------
 
 BASE_DIR = Path(__file__).resolve().parent
-DB_PATH = BASE_DIR / "data.sqlite3"
 
 load_dotenv(BASE_DIR / ".env")
 
@@ -170,63 +170,20 @@ async def _close_global_client():
         GLOBAL_CLIENT = None
 
 # ------------------------------------------------------------------------------
-# SQLite helpers
+# Database helpers
 # ------------------------------------------------------------------------------
 
-async def _ensure_db():
-    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
-    async with aiosqlite.connect(DB_PATH) as conn:
-        await conn.execute("PRAGMA journal_mode=WAL;")
-        await conn.execute(
-            """
-            CREATE TABLE IF NOT EXISTS accounts (
-                id TEXT PRIMARY KEY,
-                label TEXT,
-                clientId TEXT,
-                clientSecret TEXT,
-                refreshToken TEXT,
-                accessToken TEXT,
-                other TEXT,
-                last_refresh_time TEXT,
-                last_refresh_status TEXT,
-                created_at TEXT,
-                updated_at TEXT
-            )
-            """
-        )
-        # add columns if missing
-        try:
-            async with conn.execute("PRAGMA table_info(accounts)") as cursor:
-                rows = await cursor.fetchall()
-                cols = [row[1] for row in rows]
-                if "enabled" not in cols:
-                    await conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
-                if "error_count" not in cols:
-                    await conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
-                if "success_count" not in cols:
-                    await conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
-        except Exception:
-            pass
-        await conn.commit()
+# Database backend instance (initialized on startup)
+_db = None
 
-def _conn() -> aiosqlite.Connection:
-    """Create a new database connection. Must be used with async with."""
-    return aiosqlite.connect(DB_PATH)
+async def _ensure_db():
+    """Initialize database backend."""
+    global _db
+    _db = await init_db()
 
-def _row_to_dict(r: aiosqlite.Row) -> Dict[str, Any]:
-    d = dict(r)
-    if d.get("other"):
-        try:
-            d["other"] = json.loads(d["other"])
-        except Exception:
-            pass
-    # normalize enabled to bool
-    if "enabled" in d and d["enabled"] is not None:
-        try:
-            d["enabled"] = bool(int(d["enabled"]))
-        except Exception:
-            d["enabled"] = bool(d["enabled"])
-    return d
+def _row_to_dict(r: Dict[str, Any]) -> Dict[str, Any]:
+    """Convert database row to dict with JSON parsing."""
+    return row_to_dict(r)
 
 # _ensure_db() will be called in startup event
 
@@ -239,31 +196,28 @@ async def _refresh_stale_tokens():
         try:
             await asyncio.sleep(300)  # 5 minutes
             now = time.time()
-            async with _conn() as conn:
-                conn.row_factory = aiosqlite.Row
-                async with conn.execute("SELECT id, last_refresh_time FROM accounts WHERE enabled=1") as cursor:
-                    rows = await cursor.fetchall()
-                    for row in rows:
-                        acc_id, last_refresh = row[0], row[1]
-                        should_refresh = False
-                        if not last_refresh or last_refresh == "never":
+            rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")
+            for row in rows:
+                acc_id, last_refresh = row['id'], row['last_refresh_time']
+                should_refresh = False
+                if not last_refresh or last_refresh == "never":
+                    should_refresh = True
+                else:
+                    try:
+                        last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
+                        if now - last_time > 1500:  # 25 minutes
                             should_refresh = True
-                        else:
-                            try:
-                                last_time = time.mktime(time.strptime(last_refresh, "%Y-%m-%dT%H:%M:%S"))
-                                if now - last_time > 1500:  # 25 minutes
-                                    should_refresh = True
-                            except Exception:
-                                # Malformed or unparsable timestamp; force refresh
-                                should_refresh = True
-
-                        if should_refresh:
-                            try:
-                                await refresh_access_token_in_db(acc_id)
-                            except Exception:
-                                traceback.print_exc()
-                                # Ignore per-account refresh failure; timestamp/status are recorded inside
-                                pass
+                    except Exception:
+                        # Malformed or unparsable timestamp; force refresh
+                        should_refresh = True
+
+                if should_refresh:
+                    try:
+                        await refresh_access_token_in_db(acc_id)
+                    except Exception:
+                        traceback.print_exc()
+                        # Ignore per-account refresh failure; timestamp/status are recorded inside
+                        pass
         except Exception:
             traceback.print_exc()
             pass
@@ -301,17 +255,13 @@ def _extract_bearer(token_header: Optional[str]) -> Optional[str]:
         return token_header.split(" ", 1)[1].strip()
     return token_header.strip()
 
-async def _list_enabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
-    conn.row_factory = aiosqlite.Row
-    async with conn.execute("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC") as cursor:
-        rows = await cursor.fetchall()
-        return [_row_to_dict(r) for r in rows]
+async def _list_enabled_accounts() -> List[Dict[str, Any]]:
+    rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=1 ORDER BY created_at DESC")
+    return [_row_to_dict(r) for r in rows]
 
-async def _list_disabled_accounts(conn: aiosqlite.Connection) -> List[Dict[str, Any]]:
-    conn.row_factory = aiosqlite.Row
-    async with conn.execute("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC") as cursor:
-        rows = await cursor.fetchall()
-        return [_row_to_dict(r) for r in rows]
+async def _list_disabled_accounts() -> List[Dict[str, Any]]:
+    rows = await _db.fetchall("SELECT * FROM accounts WHERE enabled=0 ORDER BY created_at DESC")
+    return [_row_to_dict(r) for r in rows]
 
 async def verify_account(account: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
     """验证账号可用性"""
@@ -349,11 +299,10 @@ async def resolve_account_for_key(bearer_key: Optional[str]) -> Dict[str, Any]:
             raise HTTPException(status_code=401, detail="Invalid or missing API key")
 
     # Selection: random among enabled accounts
-    async with _conn() as conn:
-        candidates = await _list_enabled_accounts(conn)
-        if not candidates:
-            raise HTTPException(status_code=401, detail="No enabled account available")
-        return random.choice(candidates)
+    candidates = await _list_enabled_accounts()
+    if not candidates:
+        raise HTTPException(status_code=401, detail="No enabled account available")
+    return random.choice(candidates)
 
 # ------------------------------------------------------------------------------
 # Pydantic Schemas
@@ -406,114 +355,101 @@ def _oidc_headers() -> Dict[str, str]:
     }
 
 async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
-    async with _conn() as conn:
-        conn.row_factory = aiosqlite.Row
-        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
-            row = await cursor.fetchone()
-            if not row:
-                raise HTTPException(status_code=404, detail="Account not found")
-            acc = _row_to_dict(row)
-
-        if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
-            raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
-
-        payload = {
-            "grantType": "refresh_token",
-            "clientId": acc["clientId"],
-            "clientSecret": acc["clientSecret"],
-            "refreshToken": acc["refreshToken"],
-        }
+    row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
+    if not row:
+        raise HTTPException(status_code=404, detail="Account not found")
+    acc = _row_to_dict(row)
+
+    if not acc.get("clientId") or not acc.get("clientSecret") or not acc.get("refreshToken"):
+        raise HTTPException(status_code=400, detail="Account missing clientId/clientSecret/refreshToken for refresh")
+
+    payload = {
+        "grantType": "refresh_token",
+        "clientId": acc["clientId"],
+        "clientSecret": acc["clientSecret"],
+        "refreshToken": acc["refreshToken"],
+    }
 
-        try:
-            # Use global client if available, else fallback (though global should be ready)
-            client = GLOBAL_CLIENT
-            if not client:
-                # Fallback for safety
-                async with httpx.AsyncClient(timeout=60.0) as temp_client:
-                    r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
-                    r.raise_for_status()
-                    data = r.json()
-            else:
-                r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
+    try:
+        # Use global client if available, else fallback (though global should be ready)
+        client = GLOBAL_CLIENT
+        if not client:
+            # Fallback for safety
+            async with httpx.AsyncClient(timeout=60.0) as temp_client:
+                r = await temp_client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
                 r.raise_for_status()
                 data = r.json()
+        else:
+            r = await client.post(TOKEN_URL, headers=_oidc_headers(), json=payload)
+            r.raise_for_status()
+            data = r.json()
 
-            new_access = data.get("accessToken")
-            new_refresh = data.get("refreshToken", acc.get("refreshToken"))
-            now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-            status = "success"
-        except httpx.HTTPError as e:
-            now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-            status = "failed"
-            await conn.execute(
-                """
-                UPDATE accounts
-                SET last_refresh_time=?, last_refresh_status=?, updated_at=?
-                WHERE id=?
-                """,
-                (now, status, now, account_id),
-            )
-            await conn.commit()
-            # 记录刷新失败次数
-            await _update_stats(account_id, False)
-            raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
-        except Exception as e:
-            # Ensure last_refresh_time is recorded even on unexpected errors
-            now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-            status = "failed"
-            await conn.execute(
-                """
-                UPDATE accounts
-                SET last_refresh_time=?, last_refresh_status=?, updated_at=?
-                WHERE id=?
-                """,
-                (now, status, now, account_id),
-            )
-            await conn.commit()
-            # 记录刷新失败次数
-            await _update_stats(account_id, False)
-            raise
-
-        await conn.execute(
+        new_access = data.get("accessToken")
+        new_refresh = data.get("refreshToken", acc.get("refreshToken"))
+        now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+        status = "success"
+    except httpx.HTTPError as e:
+        now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+        status = "failed"
+        await _db.execute(
+            """
+            UPDATE accounts
+            SET last_refresh_time=?, last_refresh_status=?, updated_at=?
+            WHERE id=?
+            """,
+            (now, status, now, account_id),
+        )
+        # 记录刷新失败次数
+        await _update_stats(account_id, False)
+        raise HTTPException(status_code=502, detail=f"Token refresh failed: {str(e)}")
+    except Exception as e:
+        # Ensure last_refresh_time is recorded even on unexpected errors
+        now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+        status = "failed"
+        await _db.execute(
             """
             UPDATE accounts
-            SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
+            SET last_refresh_time=?, last_refresh_status=?, updated_at=?
             WHERE id=?
             """,
-            (new_access, new_refresh, now, status, now, account_id),
+            (now, status, now, account_id),
         )
-        await conn.commit()
+        # 记录刷新失败次数
+        await _update_stats(account_id, False)
+        raise
+
+    await _db.execute(
+        """
+        UPDATE accounts
+        SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
+        WHERE id=?
+        """,
+        (new_access, new_refresh, now, status, now, account_id),
+    )
 
-        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
-            row2 = await cursor.fetchone()
-            return _row_to_dict(row2)
+    row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
+    return _row_to_dict(row2)
 
 async def get_account(account_id: str) -> Dict[str, Any]:
-    async with _conn() as conn:
-        conn.row_factory = aiosqlite.Row
-        async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
-            row = await cursor.fetchone()
-            if not row:
-                raise HTTPException(status_code=404, detail="Account not found")
-            return _row_to_dict(row)
+    row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
+    if not row:
+        raise HTTPException(status_code=404, detail="Account not found")
+    return _row_to_dict(row)
 
 async def _update_stats(account_id: str, success: bool) -> None:
-    async with _conn() as conn:
-        if success:
-            await conn.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
-                        (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
-        else:
-            async with conn.execute("SELECT error_count FROM accounts WHERE id=?", (account_id,)) as cursor:
-                row = await cursor.fetchone()
-                if row:
-                    new_count = (row[0] or 0) + 1
-                    if new_count >= MAX_ERROR_COUNT:
-                        await conn.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
-                                   (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
-                    else:
-                        await conn.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
-                                   (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
-        await conn.commit()
+    if success:
+        await _db.execute("UPDATE accounts SET success_count=success_count+1, error_count=0, updated_at=? WHERE id=?",
+                    (time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
+    else:
+        row = await _db.fetchone("SELECT error_count FROM accounts WHERE id=?", (account_id,))
+        if row:
+            new_count = (row['error_count'] or 0) + 1
+            if new_count >= MAX_ERROR_COUNT:
+                await _db.execute("UPDATE accounts SET error_count=?, enabled=0, updated_at=? WHERE id=?",
+                           (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
+            else:
+                await _db.execute("UPDATE accounts SET error_count=?, updated_at=? WHERE id=?",
+                           (new_count, time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), account_id))
 
 # ------------------------------------------------------------------------------
 # Dependencies
@@ -914,32 +850,28 @@ async def _create_account_from_tokens(
 ) -> Dict[str, Any]:
     now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
     acc_id = str(uuid.uuid4())
-    async with _conn() as conn:
-        conn.row_factory = aiosqlite.Row
-        await conn.execute(
-            """
-            INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-            """,
-            (
-                acc_id,
-                label,
-                client_id,
-                client_secret,
-                refresh_token,
-                access_token,
-                None,
-                now,
-                "success",
-                now,
-                now,
-                1 if enabled else 0,
-            ),
-        )
-        await conn.commit()
-        async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
-            row = await cursor.fetchone()
-            return _row_to_dict(row)
+    await _db.execute(
+        """
+        INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        """,
+        (
+            acc_id,
+            label,
+            client_id,
+            client_secret,
+            refresh_token,
+            access_token,
+            None,
+            now,
+            "success",
+            now,
+            now,
+            1 if enabled else 0,
+        ),
+    )
+    row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
+    return _row_to_dict(row)
 
 # 管理控制台相关端点 - 仅在启用时注册
 if CONSOLE_ENABLED:
@@ -1056,32 +988,28 @@ if CONSOLE_ENABLED:
         acc_id = str(uuid.uuid4())
         other_str = json.dumps(body.other, ensure_ascii=False) if body.other is not None else None
         enabled_val = 1 if (body.enabled is None or body.enabled) else 0
-        async with _conn() as conn:
-            conn.row_factory = aiosqlite.Row
-            await conn.execute(
-                """
-                INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-                """,
-                (
-                    acc_id,
-                    body.label,
-                    body.clientId,
-                    body.clientSecret,
-                    body.refreshToken,
-                    body.accessToken,
-                    other_str,
-                    None,
-                    "never",
-                    now,
-                    now,
-                    enabled_val,
-                ),
-            )
-            await conn.commit()
-            async with conn.execute("SELECT * FROM accounts WHERE id=?", (acc_id,)) as cursor:
-                row = await cursor.fetchone()
-                return _row_to_dict(row)
+        await _db.execute(
+            """
+            INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            """,
+            (
+                acc_id,
+                body.label,
+                body.clientId,
+                body.clientSecret,
+                body.refreshToken,
+                body.accessToken,
+                other_str,
+                None,
+                "never",
+                now,
+                now,
+                enabled_val,
+            ),
+        )
+        row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
+        return _row_to_dict(row)
 
 
     async def _verify_and_enable_accounts(account_ids: List[str]):
@@ -1092,15 +1020,13 @@ if CONSOLE_ENABLED:
                 account = await get_account(acc_id)
                 verify_success, fail_reason = await verify_account(account)
                 now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-                
-                async with _conn() as conn:
-                    if verify_success:
-                        await conn.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, acc_id))
-                    elif fail_reason:
-                        other_dict = account.get("other", {})
-                        other_dict['failedReason'] = fail_reason
-                        await conn.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, acc_id))
-                    await conn.commit()
+
+                if verify_success:
+                    await _db.execute("UPDATE accounts SET enabled=1, updated_at=? WHERE id=?", (now, acc_id))
+                elif fail_reason:
+                    other_dict = account.get("other", {}) or {}
+                    other_dict['failedReason'] = fail_reason
+                    await _db.execute("UPDATE accounts SET other=?, updated_at=? WHERE id=?", (json.dumps(other_dict, ensure_ascii=False), now, acc_id))
             except Exception as e:
                 print(f"Error verifying account {acc_id}: {e}")
                 traceback.print_exc()
@@ -1112,36 +1038,34 @@ if CONSOLE_ENABLED:
         """
         now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
         new_account_ids = []
-        
-        async with _conn() as conn:
-            for i, account_data in enumerate(request.accounts):
-                acc_id = str(uuid.uuid4())
-                other_dict = account_data.other or {}
-                other_dict['source'] = 'feed'
-                other_str = json.dumps(other_dict, ensure_ascii=False)
-                
-                await conn.execute(
-                    """
-                    INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-                    """,
-                    (
-                        acc_id,
-                        account_data.label or f"批量账号 {i+1}",
-                        account_data.clientId,
-                        account_data.clientSecret,
-                        account_data.refreshToken,
-                        account_data.accessToken,
-                        other_str,
-                        None,
-                        "never",
-                        now,
-                        now,
-                        0,  # 初始为禁用状态
-                    ),
-                )
-                new_account_ids.append(acc_id)
-            await conn.commit()
+
+        for i, account_data in enumerate(request.accounts):
+            acc_id = str(uuid.uuid4())
+            other_dict = account_data.other or {}
+            other_dict['source'] = 'feed'
+            other_str = json.dumps(other_dict, ensure_ascii=False)
+
+            await _db.execute(
+                """
+                INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """,
+                (
+                    acc_id,
+                    account_data.label or f"批量账号 {i+1}",
+                    account_data.clientId,
+                    account_data.clientSecret,
+                    account_data.refreshToken,
+                    account_data.accessToken,
+                    other_str,
+                    None,
+                    "never",
+                    now,
+                    now,
+                    0,  # 初始为禁用状态
+                ),
+            )
+            new_account_ids.append(acc_id)
 
         # 启动后台任务进行验证,不阻塞当前请求
         if new_account_ids:
@@ -1155,11 +1079,8 @@ if CONSOLE_ENABLED:
 
     @app.get("/v2/accounts")
     async def list_accounts():
-        async with _conn() as conn:
-            conn.row_factory = aiosqlite.Row
-            async with conn.execute("SELECT * FROM accounts ORDER BY created_at DESC") as cursor:
-                rows = await cursor.fetchall()
-                return [_row_to_dict(r) for r in rows]
+        rows = await _db.fetchall("SELECT * FROM accounts ORDER BY created_at DESC")
+        return [_row_to_dict(r) for r in rows]
 
     @app.get("/v2/accounts/{account_id}")
     async def get_account_detail(account_id: str):
@@ -1167,12 +1088,10 @@ if CONSOLE_ENABLED:
 
     @app.delete("/v2/accounts/{account_id}")
     async def delete_account(account_id: str):
-        async with _conn() as conn:
-            cur = await conn.execute("DELETE FROM accounts WHERE id=?", (account_id,))
-            await conn.commit()
-            if cur.rowcount == 0:
-                raise HTTPException(status_code=404, detail="Account not found")
-            return {"deleted": account_id}
+        rowcount = await _db.execute("DELETE FROM accounts WHERE id=?", (account_id,))
+        if rowcount == 0:
+            raise HTTPException(status_code=404, detail="Account not found")
+        return {"deleted": account_id}
 
     @app.patch("/v2/accounts/{account_id}")
     async def update_account(account_id: str, body: AccountUpdate):
@@ -1201,15 +1120,11 @@ if CONSOLE_ENABLED:
         fields.append("updated_at=?"); values.append(now)
         values.append(account_id)
 
-        async with _conn() as conn:
-            conn.row_factory = aiosqlite.Row
-            cur = await conn.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", values)
-            await conn.commit()
-            if cur.rowcount == 0:
-                raise HTTPException(status_code=404, detail="Account not found")
-            async with conn.execute("SELECT * FROM accounts WHERE id=?", (account_id,)) as cursor:
-                row = await cursor.fetchone()
-                return _row_to_dict(row)
+        rowcount = await _db.execute(f"UPDATE accounts SET {', '.join(fields)} WHERE id=?", tuple(values))
+        if rowcount == 0:
+            raise HTTPException(status_code=404, detail="Account not found")
+        row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
+        return _row_to_dict(row)
 
     @app.post("/v2/accounts/{account_id}/refresh")
     async def manual_refresh(account_id: str):
@@ -1287,4 +1202,5 @@ async def startup_event():
 
 @app.on_event("shutdown")
 async def shutdown_event():
-    await _close_global_client()
+    await _close_global_client()
+    await close_db()

+ 371 - 0
db.py

@@ -0,0 +1,371 @@
+"""
+Database abstraction layer supporting SQLite, PostgreSQL, and MySQL.
+Backend selection is based on DATABASE_URL environment variable:
+- postgres://... or postgresql://... -> PostgreSQL
+- mysql://... -> MySQL
+- Not set -> SQLite (default)
+"""
+
+import os
+import json
+import time
+import asyncio
+from pathlib import Path
+from typing import Dict, List, Any, Optional, Tuple
+from abc import ABC, abstractmethod
+
+import aiosqlite
+
+# Optional imports for other backends
+try:
+    import asyncpg
+    HAS_ASYNCPG = True
+except ImportError:
+    HAS_ASYNCPG = False
+
+try:
+    import aiomysql
+    HAS_AIOMYSQL = True
+except ImportError:
+    HAS_AIOMYSQL = False
+
+
+class DatabaseBackend(ABC):
+    """Abstract base class for database backends."""
+
+    @abstractmethod
+    async def initialize(self) -> None:
+        """Initialize connection and ensure schema exists."""
+        pass
+
+    @abstractmethod
+    async def close(self) -> None:
+        """Close database connections."""
+        pass
+
+    @abstractmethod
+    async def execute(self, query: str, params: tuple = ()) -> int:
+        """Execute a query and return affected row count."""
+        pass
+
+    @abstractmethod
+    async def fetchone(self, query: str, params: tuple = ()) -> Optional[Dict[str, Any]]:
+        """Fetch a single row as dict."""
+        pass
+
+    @abstractmethod
+    async def fetchall(self, query: str, params: tuple = ()) -> List[Dict[str, Any]]:
+        """Fetch all rows as list of dicts."""
+        pass
+
+
+class SQLiteBackend(DatabaseBackend):
+    """SQLite database backend using aiosqlite."""
+
+    def __init__(self, db_path: Path):
+        self._db_path = db_path
+        self._initialized = False
+
+    async def initialize(self) -> None:
+        self._db_path.parent.mkdir(parents=True, exist_ok=True)
+        async with aiosqlite.connect(self._db_path) as conn:
+            await conn.execute("PRAGMA journal_mode=WAL;")
+            await conn.execute("""
+                CREATE TABLE IF NOT EXISTS accounts (
+                    id TEXT PRIMARY KEY,
+                    label TEXT,
+                    clientId TEXT,
+                    clientSecret TEXT,
+                    refreshToken TEXT,
+                    accessToken TEXT,
+                    other TEXT,
+                    last_refresh_time TEXT,
+                    last_refresh_status TEXT,
+                    created_at TEXT,
+                    updated_at TEXT,
+                    enabled INTEGER DEFAULT 1,
+                    error_count INTEGER DEFAULT 0,
+                    success_count INTEGER DEFAULT 0
+                )
+            """)
+            # Add columns if missing (migrations)
+            try:
+                async with conn.execute("PRAGMA table_info(accounts)") as cursor:
+                    rows = await cursor.fetchall()
+                    cols = [row[1] for row in rows]
+                    if "enabled" not in cols:
+                        await conn.execute("ALTER TABLE accounts ADD COLUMN enabled INTEGER DEFAULT 1")
+                    if "error_count" not in cols:
+                        await conn.execute("ALTER TABLE accounts ADD COLUMN error_count INTEGER DEFAULT 0")
+                    if "success_count" not in cols:
+                        await conn.execute("ALTER TABLE accounts ADD COLUMN success_count INTEGER DEFAULT 0")
+            except Exception:
+                pass
+            await conn.commit()
+        self._initialized = True
+
+    async def close(self) -> None:
+        pass  # SQLite connections are created per-operation
+
+    def _conn(self) -> aiosqlite.Connection:
+        return aiosqlite.connect(self._db_path)
+
+    async def execute(self, query: str, params: tuple = ()) -> int:
+        async with self._conn() as conn:
+            cursor = await conn.execute(query, params)
+            await conn.commit()
+            return cursor.rowcount
+
+    async def fetchone(self, query: str, params: tuple = ()) -> Optional[Dict[str, Any]]:
+        async with self._conn() as conn:
+            conn.row_factory = aiosqlite.Row
+            async with conn.execute(query, params) as cursor:
+                row = await cursor.fetchone()
+                return dict(row) if row else None
+
+    async def fetchall(self, query: str, params: tuple = ()) -> List[Dict[str, Any]]:
+        async with self._conn() as conn:
+            conn.row_factory = aiosqlite.Row
+            async with conn.execute(query, params) as cursor:
+                rows = await cursor.fetchall()
+                return [dict(row) for row in rows]
+
+
+class PostgresBackend(DatabaseBackend):
+    """PostgreSQL database backend using asyncpg."""
+
+    def __init__(self, dsn: str):
+        self._dsn = dsn
+        self._pool: Optional[asyncpg.pool.Pool] = None
+        self._initialized = False
+
+    async def initialize(self) -> None:
+        if not HAS_ASYNCPG:
+            raise ImportError("asyncpg is required for PostgreSQL support. Install with: pip install asyncpg")
+
+        self._pool = await asyncpg.create_pool(dsn=self._dsn, min_size=1, max_size=20)
+
+        async with self._pool.acquire() as conn:
+            await conn.execute("""
+                CREATE TABLE IF NOT EXISTS accounts (
+                    id TEXT PRIMARY KEY,
+                    label TEXT,
+                    "clientId" TEXT,
+                    "clientSecret" TEXT,
+                    "refreshToken" TEXT,
+                    "accessToken" TEXT,
+                    other TEXT,
+                    last_refresh_time TEXT,
+                    last_refresh_status TEXT,
+                    created_at TEXT,
+                    updated_at TEXT,
+                    enabled INTEGER DEFAULT 1,
+                    error_count INTEGER DEFAULT 0,
+                    success_count INTEGER DEFAULT 0
+                )
+            """)
+        self._initialized = True
+
+    async def close(self) -> None:
+        if self._pool:
+            await self._pool.close()
+            self._pool = None
+            self._initialized = False
+
+    def _convert_placeholders(self, query: str) -> str:
+        """Convert ? placeholders to $1, $2, etc."""
+        result = []
+        param_num = 0
+        i = 0
+        while i < len(query):
+            if query[i] == '?':
+                param_num += 1
+                result.append(f'${param_num}')
+            else:
+                result.append(query[i])
+            i += 1
+        return ''.join(result)
+
+    async def execute(self, query: str, params: tuple = ()) -> int:
+        pg_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            result = await conn.execute(pg_query, *params)
+            # asyncpg returns string like "UPDATE 1"
+            try:
+                return int(result.split()[-1])
+            except (ValueError, IndexError):
+                return 0
+
+    async def fetchone(self, query: str, params: tuple = ()) -> Optional[Dict[str, Any]]:
+        pg_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            row = await conn.fetchrow(pg_query, *params)
+            return dict(row) if row else None
+
+    async def fetchall(self, query: str, params: tuple = ()) -> List[Dict[str, Any]]:
+        pg_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            rows = await conn.fetch(pg_query, *params)
+            return [dict(row) for row in rows]
+
+
+class MySQLBackend(DatabaseBackend):
+    """MySQL database backend using aiomysql."""
+
+    def __init__(self, dsn: str):
+        self._dsn = dsn
+        self._pool = None
+        self._initialized = False
+        self._config = self._parse_dsn(dsn)
+
+    def _parse_dsn(self, dsn: str) -> Dict[str, Any]:
+        """Parse MySQL DSN into connection parameters."""
+        # mysql://user:password@host:port/database
+        from urllib.parse import urlparse, parse_qs
+        parsed = urlparse(dsn)
+        config = {
+            'host': parsed.hostname or 'localhost',
+            'port': parsed.port or 3306,
+            'user': parsed.username or 'root',
+            'password': parsed.password or '',
+            'db': parsed.path.lstrip('/') if parsed.path else 'test',
+        }
+        # Handle SSL
+        query = parse_qs(parsed.query)
+        if 'ssl' in query or 'sslmode' in query or 'ssl-mode' in query:
+            config['ssl'] = True
+        return config
+
+    async def initialize(self) -> None:
+        if not HAS_AIOMYSQL:
+            raise ImportError("aiomysql is required for MySQL support. Install with: pip install aiomysql")
+
+        self._pool = await aiomysql.create_pool(
+            host=self._config['host'],
+            port=self._config['port'],
+            user=self._config['user'],
+            password=self._config['password'],
+            db=self._config['db'],
+            minsize=1,
+            maxsize=20,
+            autocommit=True
+        )
+
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute("""
+                    CREATE TABLE IF NOT EXISTS accounts (
+                        id VARCHAR(255) PRIMARY KEY,
+                        label TEXT,
+                        clientId TEXT,
+                        clientSecret TEXT,
+                        refreshToken TEXT,
+                        accessToken TEXT,
+                        other TEXT,
+                        last_refresh_time TEXT,
+                        last_refresh_status TEXT,
+                        created_at TEXT,
+                        updated_at TEXT,
+                        enabled INT DEFAULT 1,
+                        error_count INT DEFAULT 0,
+                        success_count INT DEFAULT 0
+                    )
+                """)
+        self._initialized = True
+
+    async def close(self) -> None:
+        if self._pool:
+            self._pool.close()
+            await self._pool.wait_closed()
+            self._pool = None
+            self._initialized = False
+
+    def _convert_placeholders(self, query: str) -> str:
+        """Convert ? placeholders to %s for MySQL."""
+        return query.replace('?', '%s')
+
+    async def execute(self, query: str, params: tuple = ()) -> int:
+        mysql_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(mysql_query, params)
+                return cur.rowcount
+
+    async def fetchone(self, query: str, params: tuple = ()) -> Optional[Dict[str, Any]]:
+        mysql_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            async with conn.cursor(aiomysql.DictCursor) as cur:
+                await cur.execute(mysql_query, params)
+                return await cur.fetchone()
+
+    async def fetchall(self, query: str, params: tuple = ()) -> List[Dict[str, Any]]:
+        mysql_query = self._convert_placeholders(query)
+        async with self._pool.acquire() as conn:
+            async with conn.cursor(aiomysql.DictCursor) as cur:
+                await cur.execute(mysql_query, params)
+                return await cur.fetchall()
+
+
+# Global database instance
+_db: Optional[DatabaseBackend] = None
+
+
+def get_database_backend() -> DatabaseBackend:
+    """Get the configured database backend based on DATABASE_URL."""
+    global _db
+    if _db is not None:
+        return _db
+
+    database_url = os.getenv('DATABASE_URL', '').strip()
+
+    if database_url.startswith(('postgres://', 'postgresql://')):
+        # Fix common postgres:// to postgresql:// for asyncpg
+        dsn = database_url.replace('postgres://', 'postgresql://', 1) if database_url.startswith('postgres://') else database_url
+        _db = PostgresBackend(dsn)
+        print(f"[DB] Using PostgreSQL backend")
+    elif database_url.startswith('mysql://'):
+        _db = MySQLBackend(database_url)
+        print(f"[DB] Using MySQL backend")
+    else:
+        # Default to SQLite
+        base_dir = Path(__file__).resolve().parent
+        db_path = base_dir / "data.sqlite3"
+        _db = SQLiteBackend(db_path)
+        print(f"[DB] Using SQLite backend: {db_path}")
+
+    return _db
+
+
+async def init_db() -> DatabaseBackend:
+    """Initialize and return the database backend."""
+    db = get_database_backend()
+    await db.initialize()
+    return db
+
+
+async def close_db() -> None:
+    """Close the database backend."""
+    global _db
+    if _db:
+        await _db.close()
+        _db = None
+
+
+# Helper functions for common operations
+def row_to_dict(row: Dict[str, Any]) -> Dict[str, Any]:
+    """Convert a database row to dict with JSON parsing for 'other' field."""
+    if row is None:
+        return None
+    d = dict(row)
+    if d.get("other"):
+        try:
+            d["other"] = json.loads(d["other"])
+        except Exception:
+            pass
+    # normalize enabled to bool
+    if "enabled" in d and d["enabled"] is not None:
+        try:
+            d["enabled"] = bool(int(d["enabled"]))
+        except Exception:
+            d["enabled"] = bool(d["enabled"])
+    return d

+ 3 - 0
requirements.txt

@@ -4,4 +4,7 @@ pydantic==2.9.2
 python-dotenv==1.0.1
 httpx==0.28.1
 aiosqlite==0.20.0
+asyncpg>=0.30.0
+aiomysql>=0.2.0
+cryptography>=41.0.0
 tiktoken

+ 31 - 63
scripts/account_stats.py

@@ -1,73 +1,40 @@
 #!/usr/bin/env python3
-import sqlite3
 import sys
+import asyncio
 from pathlib import Path
-from datetime import datetime
 
-# --- 配置 ---
-BASE_DIR = Path(__file__).resolve().parent.parent
-DB_PATH = BASE_DIR / "data.sqlite3"
-# --- 配置结束 ---
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-
-def get_db_connection():
-    """建立并返回数据库连接,如果数据库不存在则退出。"""
-    if not DB_PATH.exists():
-        print(f"错误: 数据库文件未找到: {DB_PATH}", file=sys.stderr)
-        sys.exit(1)
-    try:
-        conn = sqlite3.connect(DB_PATH)
-        conn.row_factory = sqlite3.Row
-        return conn
-    except sqlite3.Error as e:
-        print(f"数据库连接错误: {e}", file=sys.stderr)
-        sys.exit(1)
-
-
-def check_table_and_columns(conn: sqlite3.Connection):
-    """检查 'accounts' 表和必要的列是否存在。"""
-    try:
-        cursor = conn.execute("PRAGMA table_info(accounts)")
-        columns = [row['name'] for row in cursor.fetchall()]
-        if not columns:
-            print("错误: 'accounts' 表不存在。", file=sys.stderr)
-            sys.exit(1)
-
-        required_cols = ['id', 'label', 'enabled', 'success_count', 'error_count', 'last_refresh_status', 'last_refresh_time']
-        missing_cols = [col for col in required_cols if col not in columns]
-        if missing_cols:
-            print(f"错误: 'accounts' 表缺少以下列: {', '.join(missing_cols)}", file=sys.stderr)
-            sys.exit(1)
-    except sqlite3.Error as e:
-        print(f"检查表结构时出错: {e}", file=sys.stderr)
-        sys.exit(1)
+from db import init_db, close_db, row_to_dict
 
 
-def gather_stats():
+async def gather_stats():
     """连接数据库,查询并打印全面的账户统计信息。"""
-    conn = get_db_connection()
-    check_table_and_columns(conn)
+    db = await init_db()
 
     try:
-        accounts = conn.execute("SELECT * FROM accounts ORDER BY created_at DESC").fetchall()
-    except sqlite3.Error as e:
+        accounts = await db.fetchall("SELECT * FROM accounts ORDER BY created_at DESC")
+    except Exception as e:
         print(f"查询账户时出错: {e}", file=sys.stderr)
-        conn.close()
+        await close_db()
         sys.exit(1)
 
+    accounts = [row_to_dict(acc) for acc in accounts]
     total_accounts = len(accounts)
+
     if total_accounts == 0:
         print("数据库中没有找到任何账户。")
-        conn.close()
+        await close_db()
         return
 
     # --- 汇总统计 ---
-    enabled_accounts = [acc for acc in accounts if acc['enabled'] == 1]
-    disabled_accounts = [acc for acc in accounts if acc['enabled'] == 0]
-    refresh_failed_accounts = [acc for acc in accounts if acc['last_refresh_status'] == 'failed']
-    never_used_accounts = [acc for acc in accounts if acc['success_count'] == 0]
-    error_accounts = [acc for acc in accounts if acc['error_count'] > 0]
-    total_success_count = sum(acc['success_count'] for acc in accounts)
+    enabled_accounts = [acc for acc in accounts if acc.get('enabled')]
+    disabled_accounts = [acc for acc in accounts if not acc.get('enabled')]
+    refresh_failed_accounts = [acc for acc in accounts if acc.get('last_refresh_status') == 'failed']
+    never_used_accounts = [acc for acc in accounts if acc.get('success_count', 0) == 0]
+    error_accounts = [acc for acc in accounts if acc.get('error_count', 0) > 0]
+    total_success_count = sum(acc.get('success_count', 0) for acc in accounts)
 
     print("--- 账户统计摘要 ---")
     print(f"总账户数: {total_accounts}")
@@ -90,39 +57,40 @@ def gather_stats():
 
     for acc in accounts:
         # 状态 emoji
-        status_icon = "✅" if acc['enabled'] else "❌"
-        if acc['last_refresh_status'] == 'failed':
+        status_icon = "✅" if acc.get('enabled') else "❌"
+        if acc.get('last_refresh_status') == 'failed':
             status_icon = "⚠️"
 
         # 格式化输出
-        enabled_str = "是" if acc['enabled'] else "否"
-        label = acc['label'] if acc['label'] else "(无)"
-        
+        enabled_str = "是" if acc.get('enabled') else "否"
+        label = acc.get('label') or "(无)"
+
         # 截断过长的标签
         if len(label) > 15:
             label = label[:12] + "..."
 
-        last_refresh_time = acc['last_refresh_time'] if acc['last_refresh_time'] else "从未"
+        last_refresh_time = acc.get('last_refresh_time') or "从未"
+        last_refresh_status = acc.get('last_refresh_status') or "never"
 
         print("| {:<10s} | {:<8s} | {:<15s} | {:<5d} | {:<5d} | {:<12s} | {:<20s} |".format(
             status_icon,
             enabled_str,
             label,
-            acc['success_count'],
-            acc['error_count'],
-            acc['last_refresh_status'],
+            acc.get('success_count', 0),
+            acc.get('error_count', 0),
+            last_refresh_status,
             last_refresh_time
         ))
 
     print("-" * len(header))
-    conn.close()
+    await close_db()
 
 
 def main():
     """脚本主入口"""
-    gather_stats()
+    asyncio.run(gather_stats())
     sys.exit(0)
 
 
 if __name__ == "__main__":
-    main()
+    main()

+ 27 - 39
scripts/delete_disabled_zero_success_accounts.py

@@ -1,58 +1,46 @@
 #!/usr/bin/env python3
-import sqlite3
 import sys
+import asyncio
 from pathlib import Path
 
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-BASE_DIR = Path(__file__).resolve().parent.parent
-DB_PATH = BASE_DIR / "data.sqlite3"
+from db import init_db, close_db
 
 
-def delete_disabled_accounts() -> int:
+async def delete_disabled_accounts() -> int:
     """
-    Delete accounts where enabled=0 AND success_count=0 from the SQLite database.
+    Delete accounts where enabled=0 AND success_count=0 from the database.
     Returns the number of rows deleted.
     """
-    if not DB_PATH.exists():
-        print(f"Database not found: {DB_PATH}")
-        return 0
+    db = await init_db()
 
     try:
-        with sqlite3.connect(DB_PATH) as conn:
-            conn.execute("PRAGMA foreign_keys = ON")
-            # Ensure table exists
-            tbl_cur = conn.execute(
-                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='accounts'"
-            )
-            if (tbl_cur.fetchone() or [0])[0] == 0:
-                print("Table 'accounts' not found in database.")
-                return 0
-
-            # Check column 'enabled' existence
-            cols = [row[1] for row in conn.execute("PRAGMA table_info(accounts)").fetchall()]
-            if "enabled" not in cols:
-                print("Column 'enabled' not found in 'accounts' table.")
-                return 0
-            if "success_count" not in cols:
-                print("Column 'success_count' not found in 'accounts' table.")
-                return 0
-
-            # Count first for clear reporting, then delete
-            count = (conn.execute("SELECT COUNT(*) FROM accounts WHERE enabled=0 AND success_count=0").fetchone() or [0])[0]
-            conn.execute("DELETE FROM accounts WHERE enabled=0 AND success_count=0")
-            conn.commit()
-            print(f"Deleted {count} disabled account(s) with zero success count.")
-            return int(count)
-    except sqlite3.Error as e:
-        print(f"SQLite error: {e}", file=sys.stderr)
+        # Count first for clear reporting
+        count_row = await db.fetchone("SELECT COUNT(*) as cnt FROM accounts WHERE enabled=0 AND success_count=0")
+        count = count_row['cnt'] if count_row else 0
+
+        if count > 0:
+            await db.execute("DELETE FROM accounts WHERE enabled=0 AND success_count=0")
+
+        print(f"Deleted {count} disabled account(s) with zero success count.")
+        return int(count)
+    except Exception as e:
+        print(f"Database error: {e}", file=sys.stderr)
         return 0
+    finally:
+        await close_db()
 
 
-def main() -> None:
-    deleted = delete_disabled_accounts()
-    # exit code 0 even if none deleted; non-zero only on sqlite error already handled
+async def main_async() -> None:
+    await delete_disabled_accounts()
     sys.exit(0)
 
 
+def main() -> None:
+    asyncio.run(main_async())
+
+
 if __name__ == "__main__":
-    main()
+    main()

+ 40 - 36
scripts/reset_accounts.py

@@ -3,87 +3,87 @@
 重新启用所有账号的脚本
 将所有账号的 enabled 设置为 1,保留错误和成功次数
 """
-import sqlite3
+import sys
+import asyncio
 from pathlib import Path
 from datetime import datetime
 
-DB_PATH = Path(__file__).resolve().parent.parent / "data.sqlite3"
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
+from db import init_db, close_db, row_to_dict
 
-def reset_all_accounts():
+
+async def reset_all_accounts():
     """重新启用所有账号(不重置错误和成功次数)"""
-    if not DB_PATH.exists():
-        print(f"错误: 数据库文件不存在: {DB_PATH}")
-        return
-    
-    conn = sqlite3.connect(DB_PATH)
-    cursor = conn.cursor()
-    
+    db = await init_db()
+
     try:
         # 获取当前禁用的账号数量
-        cursor.execute("SELECT COUNT(*) FROM accounts WHERE enabled=0")
-        disabled_count = cursor.fetchone()[0]
-        
+        disabled_row = await db.fetchone("SELECT COUNT(*) as cnt FROM accounts WHERE enabled=0")
+        disabled_count = disabled_row['cnt'] if disabled_row else 0
+
         # 获取总账号数量
-        cursor.execute("SELECT COUNT(*) FROM accounts")
-        total_count = cursor.fetchone()[0]
-        
+        total_row = await db.fetchone("SELECT COUNT(*) as cnt FROM accounts")
+        total_count = total_row['cnt'] if total_row else 0
+
         print(f"数据库中共有 {total_count} 个账号")
         print(f"其中 {disabled_count} 个账号已被禁用")
-        
+
         if disabled_count == 0:
             print("所有账号都已启用,无需操作")
+            await close_db()
             return
-        
+
         # 只重新启用账号,不重置错误和成功次数
         now = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")
-        cursor.execute("""
+        await db.execute("""
             UPDATE accounts
             SET enabled=1, updated_at=?
             WHERE enabled=0
         """, (now,))
-        conn.commit()
-        
+
         print(f"✓ 已重新启用 {disabled_count} 个账号")
         print("✓ 保留了所有账号的错误和成功次数")
-        
+
         # 显示更新后的状态
-        cursor.execute("""
+        rows = await db.fetchall("""
             SELECT id, label, enabled, error_count, success_count
             FROM accounts
             ORDER BY created_at DESC
         """)
-        rows = cursor.fetchall()
-        
+
         print("\n当前账号状态:")
         print("-" * 80)
         print(f"{'ID':<38} {'标签':<20} {'启用':<6} {'错误':<6} {'成功':<6}")
         print("-" * 80)
         for row in rows:
-            acc_id, label, enabled, error_count, success_count = row
-            label = label or "(无标签)"
+            acc_id = row['id']
+            label = row['label'] or "(无标签)"
+            enabled = row['enabled']
+            error_count = row['error_count'] or 0
+            success_count = row['success_count'] or 0
             enabled_str = "是" if enabled else "否"
-            print(f"{acc_id:<38} {label:<20} {enabled_str:<6} {error_count or 0:<6} {success_count or 0:<6}")
-    
+            print(f"{acc_id:<38} {label:<20} {enabled_str:<6} {error_count:<6} {success_count:<6}")
+
     finally:
-        cursor.close()
-        conn.close()
+        await close_db()
 
 
-def main():
+async def main_async():
     print("=" * 80)
     print("重新启用所有账号")
     print("=" * 80)
     print()
-    
+
     try:
-        reset_all_accounts()
+        await reset_all_accounts()
     except Exception as e:
         print(f"错误: {e}")
         import traceback
         traceback.print_exc()
         return 1
-    
+
     print()
     print("=" * 80)
     print("操作完成")
@@ -91,5 +91,9 @@ def main():
     return 0
 
 
+def main():
+    return asyncio.run(main_async())
+
+
 if __name__ == "__main__":
-    exit(main())
+    exit(main())

+ 22 - 34
scripts/retry_failed_accounts.py

@@ -3,20 +3,21 @@ import os
 import sys
 import time
 import uuid
-import json
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, Any, Optional, List
+from typing import Dict, Any, Optional
 
 import httpx
-import aiosqlite
 from dotenv import load_dotenv
 
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from db import init_db, close_db, row_to_dict
+
 # --- 配置 ---
-# 脚本将自动查找项目根目录下的 .env 文件并加载
 BASE_DIR = Path(__file__).resolve().parent.parent
-DB_PATH = BASE_DIR / "data.sqlite3"
 load_dotenv(BASE_DIR / ".env")
 # --- 配置结束 ---
 
@@ -44,7 +45,7 @@ def _oidc_headers() -> Dict[str, str]:
     }
 
 async def refresh_single_account_token(
-    conn: aiosqlite.Connection,
+    db,
     account: Dict[str, Any],
     client: httpx.AsyncClient
 ) -> bool:
@@ -77,7 +78,7 @@ async def refresh_single_account_token(
         now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
 
         # 刷新成功: 启用账号,重置错误计数,更新token
-        await conn.execute(
+        await db.execute(
             """
             UPDATE accounts
             SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?,
@@ -86,7 +87,6 @@ async def refresh_single_account_token(
             """,
             (new_access, new_refresh, now, "success", now, account_id),
         )
-        await conn.commit()
         print(f"  [✅] 账号 {label} 刷新成功并已重新启用。")
         return True
 
@@ -97,22 +97,20 @@ async def refresh_single_account_token(
             error_detail = e.response.json().get("error_description", str(e))
         except Exception:
             pass
-        
+
         now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-        await conn.execute(
+        await db.execute(
             "UPDATE accounts SET last_refresh_time=?, last_refresh_status=?, updated_at=? WHERE id=?",
             (now, "failed", now, account_id),
         )
-        await conn.commit()
         print(f"  [❌] 账号 {label} 刷新失败: {error_detail}")
         return False
     except Exception as e:
         now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
-        await conn.execute(
+        await db.execute(
             "UPDATE accounts SET last_refresh_time=?, last_refresh_status=?, updated_at=? WHERE id=?",
             (now, "failed", now, account_id),
         )
-        await conn.commit()
         print(f"  [❌] 账号 {label} 发生未知错误: {e}")
         traceback.print_exc()
         return False
@@ -120,26 +118,17 @@ async def refresh_single_account_token(
 
 async def main():
     """脚本主逻辑。"""
-    if not DB_PATH.exists():
-        print(f"错误: 数据库文件未找到: {DB_PATH}", file=sys.stderr)
-        sys.exit(1)
-
-    try:
-        conn = await aiosqlite.connect(DB_PATH)
-        conn.row_factory = aiosqlite.Row
-    except aiosqlite.Error as e:
-        print(f"数据库连接错误: {e}", file=sys.stderr)
-        sys.exit(1)
+    db = await init_db()
 
     # 查找目标账号
-    cursor = await conn.execute(
+    accounts_to_retry = await db.fetchall(
         "SELECT * FROM accounts WHERE enabled = 0 AND last_refresh_status = 'failed'"
     )
-    accounts_to_retry = await cursor.fetchall()
+    accounts_to_retry = [row_to_dict(acc) for acc in accounts_to_retry]
 
     if not accounts_to_retry:
         print("没有找到因刷新失败而被禁用的账号。")
-        await conn.close()
+        await close_db()
         return
 
     print(f"找到 {len(accounts_to_retry)} 个需要重试的账号...")
@@ -157,24 +146,23 @@ async def main():
                 "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
                 "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
             }
-    
+
     async with httpx.AsyncClient(mounts=mounts, timeout=60.0) as client:
-        for i, acc_row in enumerate(accounts_to_retry):
-            account = dict(acc_row)
+        for i, account in enumerate(accounts_to_retry):
             label = account.get("label") or account.get("id", "未知ID")[:8]
             print(f"\n--- ({i+1}/{len(accounts_to_retry)}) 正在处理账号: {label} ---")
-            
-            is_success = await refresh_single_account_token(conn, account, client)
+
+            is_success = await refresh_single_account_token(db, account, client)
             if is_success:
                 success_count += 1
             else:
                 failure_count += 1
-            
+
             # 在账号之间添加短暂延迟,避免请求过于集中
             if i < len(accounts_to_retry) - 1:
                 await asyncio.sleep(1)
 
-    await conn.close()
+    await close_db()
 
     print("\n--- 操作完成 ---")
     print(f"成功启用: {success_count} 个账号")
@@ -186,4 +174,4 @@ if __name__ == "__main__":
         asyncio.run(main())
     except KeyboardInterrupt:
         print("\n操作被用户中断。")
-        sys.exit(1)
+        sys.exit(1)