Jelajahi Sumber

fix(db): 修复多数据库后端的健壮性和一致性问题

根据代码审查的建议,本次提交修复了 PostgreSQL 后端的一个严重 Bug,并对数据库抽象层、应用和管理脚本进行了一系列健壮性优化。

主要变更如下:

- fix(db): 修正了 PostgreSQL 的建表语句。移除了 clientId 等列名周围的双引号,避免了因大小写敏感导致的 column does not exist 错误。

- refactor(db): 改进了类型注解。将 asyncpg.pool.Pool 的类型提示改为字符串前向引用,以防止在未安装 asyncpg 依赖时启动失败。同时修正了 row_to_dict 的返回类型注解为 Optional。

- refactor(app): 增强了后台任务的健壮性。在 _refresh_stale_tokens 循环任务中添加了对数据库连接 _db 是否已初始化的检查。

- refactor(scripts): 统一了脚本的环境配置。为所有管理脚本添加了 load_dotenv 逻辑,确保它们能和主应用一样正确加载 .env 文件中的 DATABASE_URL。

- chore: 清理了代码中未使用的导入,包括 app.py 和 scripts/reset_accounts.py。
CassiopeiaCode 2 bulan lalu
induk
melakukan
3fbf751a03

+ 4 - 1
app.py

@@ -17,7 +17,7 @@ from dotenv import load_dotenv
 import httpx
 import tiktoken
 
-from db import get_database_backend, init_db, close_db, row_to_dict
+from db import init_db, close_db, row_to_dict
 
 # ------------------------------------------------------------------------------
 # Tokenizer
@@ -195,6 +195,9 @@ async def _refresh_stale_tokens():
     while True:
         try:
             await asyncio.sleep(300)  # 5 minutes
+            if _db is None:
+                print("[Error] Database not initialized, skipping token refresh cycle.")
+                continue
             now = time.time()
             rows = await _db.fetchall("SELECT id, last_refresh_time FROM accounts WHERE enabled=1")
             for row in rows:

+ 6 - 6
db.py

@@ -136,7 +136,7 @@ class PostgresBackend(DatabaseBackend):
 
     def __init__(self, dsn: str):
         self._dsn = dsn
-        self._pool: Optional[asyncpg.pool.Pool] = None
+        self._pool: "Optional[asyncpg.pool.Pool]" = None
         self._initialized = False
 
     async def initialize(self) -> None:
@@ -150,10 +150,10 @@ class PostgresBackend(DatabaseBackend):
                 CREATE TABLE IF NOT EXISTS accounts (
                     id TEXT PRIMARY KEY,
                     label TEXT,
-                    "clientId" TEXT,
-                    "clientSecret" TEXT,
-                    "refreshToken" TEXT,
-                    "accessToken" TEXT,
+                    clientId TEXT,
+                    clientSecret TEXT,
+                    refreshToken TEXT,
+                    accessToken TEXT,
                     other TEXT,
                     last_refresh_time TEXT,
                     last_refresh_status TEXT,
@@ -352,7 +352,7 @@ async def close_db() -> None:
 
 
 # Helper functions for common operations
-def row_to_dict(row: Dict[str, Any]) -> Dict[str, Any]:
+def row_to_dict(row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
     """Convert a database row to dict with JSON parsing for 'other' field."""
     if row is None:
         return None

+ 5 - 0
scripts/account_stats.py

@@ -2,6 +2,11 @@
 import sys
 import asyncio
 from pathlib import Path
+from dotenv import load_dotenv
+
+# Load .env file from parent directory
+BASE_DIR = Path(__file__).resolve().parent.parent
+load_dotenv(BASE_DIR / ".env")
 
 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

+ 5 - 0
scripts/delete_disabled_zero_success_accounts.py

@@ -2,6 +2,11 @@
 import sys
 import asyncio
 from pathlib import Path
+from dotenv import load_dotenv
+
+# Load .env file from parent directory
+BASE_DIR = Path(__file__).resolve().parent.parent
+load_dotenv(BASE_DIR / ".env")
 
 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

+ 6 - 1
scripts/reset_accounts.py

@@ -7,11 +7,16 @@ import sys
 import asyncio
 from pathlib import Path
 from datetime import datetime
+from dotenv import load_dotenv
+
+# Load .env file from parent directory
+BASE_DIR = Path(__file__).resolve().parent.parent
+load_dotenv(BASE_DIR / ".env")
 
 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-from db import init_db, close_db, row_to_dict
+from db import init_db, close_db
 
 
 async def reset_all_accounts():