Forráskód Böngészése

refactor: centralize database schema and auto-migration logic

- Define ACCOUNTS_COLUMNS as single source of truth for schema
- Auto-detect and add missing columns on startup for all backends
- Support SQLite, PostgreSQL, and MySQL migrations
- Log migration actions for debugging
CassiopeiaCode 1 hónapja
szülő
commit
ebb402b89e
1 módosított fájl, 120 hozzáadás és 74 törlés
  1. 120 74
      db.py

+ 120 - 74
db.py

@@ -11,11 +11,34 @@ import json
 import time
 import asyncio
 from pathlib import Path
-from typing import Dict, List, Any, Optional, Tuple
+from typing import Dict, List, Any, Optional, Tuple, Set
 from abc import ABC, abstractmethod
 
 import aiosqlite
 
+# Schema version for migrations
+SCHEMA_VERSION = 1
+
+# Define all columns that should exist in the accounts table
+# Format: (column_name, column_type_sqlite, column_type_postgres, column_type_mysql, default_value)
+ACCOUNTS_COLUMNS = [
+    ("id", "TEXT PRIMARY KEY", "TEXT PRIMARY KEY", "VARCHAR(255) PRIMARY KEY", None),
+    ("label", "TEXT", "TEXT", "TEXT", None),
+    ("clientId", "TEXT", "TEXT", "TEXT", None),
+    ("clientSecret", "TEXT", "TEXT", "TEXT", None),
+    ("refreshToken", "TEXT", "TEXT", "TEXT", None),
+    ("accessToken", "TEXT", "TEXT", "TEXT", None),
+    ("other", "TEXT", "TEXT", "TEXT", None),
+    ("last_refresh_time", "TEXT", "TEXT", "TEXT", None),
+    ("last_refresh_status", "TEXT", "TEXT", "TEXT", None),
+    ("created_at", "TEXT", "TEXT", "TEXT", None),
+    ("updated_at", "TEXT", "TEXT", "TEXT", None),
+    ("enabled", "INTEGER DEFAULT 1", "INTEGER DEFAULT 1", "INT DEFAULT 1", "1"),
+    ("error_count", "INTEGER DEFAULT 0", "INTEGER DEFAULT 0", "INT DEFAULT 0", "0"),
+    ("success_count", "INTEGER DEFAULT 0", "INTEGER DEFAULT 0", "INT DEFAULT 0", "0"),
+    ("expires_at", "TEXT", "TEXT", "TEXT", None),
+]
+
 # Optional imports for other backends
 try:
     import asyncpg
@@ -67,6 +90,31 @@ class SQLiteBackend(DatabaseBackend):
         self._initialized = False
         self._conn: Optional[aiosqlite.Connection] = None
 
+    async def _get_existing_columns(self) -> Set[str]:
+        """Get existing column names from accounts table."""
+        try:
+            async with self._conn.execute("PRAGMA table_info(accounts)") as cursor:
+                rows = await cursor.fetchall()
+                return {row[1] for row in rows}
+        except Exception:
+            return set()
+
+    async def _migrate_schema(self) -> None:
+        """Add missing columns to accounts table."""
+        existing_cols = await self._get_existing_columns()
+        if not existing_cols:
+            return  # Table doesn't exist yet, will be created fresh
+        
+        for col_name, col_type, _, _, _ in ACCOUNTS_COLUMNS:
+            if col_name not in existing_cols and "PRIMARY KEY" not in col_type:
+                # Extract just the type without DEFAULT clause for ALTER TABLE
+                base_type = col_type.split(" DEFAULT")[0].strip()
+                try:
+                    await self._conn.execute(f"ALTER TABLE accounts ADD COLUMN {col_name} {base_type}")
+                    print(f"[DB Migration] Added column: {col_name}")
+                except Exception as e:
+                    print(f"[DB Migration] Failed to add column {col_name}: {e}")
+
     async def initialize(self) -> None:
         if self._initialized:
             return
@@ -79,40 +127,20 @@ class SQLiteBackend(DatabaseBackend):
         await self._conn.execute("PRAGMA cache_size = -65536; -- 64MB")
         await self._conn.execute("PRAGMA temp_store = MEMORY;")
         
-        await self._conn.execute("""
-            CREATE TABLE IF NOT EXISTS accounts (
-                id TEXT PRIMARY KEY,
-                label TEXT,
-                clientId TEXT,
-                clientSecret TEXT,
-                refreshToken TEXT,
-                accessToken TEXT,
-                other TEXT,
-                last_refresh_time TEXT,
-                last_refresh_status TEXT,
-                created_at TEXT,
-                updated_at TEXT,
-                enabled INTEGER DEFAULT 1,
-                error_count INTEGER DEFAULT 0,
-                success_count INTEGER DEFAULT 0,
-                expires_at TEXT
-            )
+        # Build CREATE TABLE statement from schema definition
+        columns_sql = ", ".join([f"{col[0]} {col[1]}" for col in ACCOUNTS_COLUMNS])
+        await self._conn.execute(f"""
+            CREATE TABLE IF NOT EXISTS accounts ({columns_sql})
         """)
         
+        # Run migrations for existing tables
+        await self._migrate_schema()
+        
         # Create indexes for performance
         await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_accounts_enabled ON accounts (enabled);")
         await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_accounts_created_at ON accounts (created_at);")
         await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_accounts_success_count ON accounts (success_count);")
 
-        # Add expires_at column if missing (migration)
-        try:
-            async with self._conn.execute("PRAGMA table_info(accounts)") as cursor:
-                cols = {row[1] for row in await cursor.fetchall()}
-                if "expires_at" not in cols:
-                    await self._conn.execute("ALTER TABLE accounts ADD COLUMN expires_at TEXT")
-        except Exception:
-            pass
-
         await self._conn.commit()
         self._initialized = True
 
@@ -148,6 +176,32 @@ class PostgresBackend(DatabaseBackend):
         self._pool: "Optional[asyncpg.pool.Pool]" = None
         self._initialized = False
 
+    async def _get_existing_columns(self, conn) -> Set[str]:
+        """Get existing column names from accounts table."""
+        try:
+            rows = await conn.fetch("""
+                SELECT column_name FROM information_schema.columns 
+                WHERE table_name = 'accounts'
+            """)
+            return {row['column_name'] for row in rows}
+        except Exception:
+            return set()
+
+    async def _migrate_schema(self, conn) -> None:
+        """Add missing columns to accounts table."""
+        existing_cols = await self._get_existing_columns(conn)
+        if not existing_cols:
+            return  # Table doesn't exist yet
+        
+        for col_name, _, col_type, _, _ in ACCOUNTS_COLUMNS:
+            if col_name not in existing_cols and "PRIMARY KEY" not in col_type:
+                base_type = col_type.split(" DEFAULT")[0].strip()
+                try:
+                    await conn.execute(f"ALTER TABLE accounts ADD COLUMN IF NOT EXISTS {col_name} {base_type}")
+                    print(f"[DB Migration] Added column: {col_name}")
+                except Exception as e:
+                    print(f"[DB Migration] Failed to add column {col_name}: {e}")
+
     async def initialize(self) -> None:
         if not HAS_ASYNCPG:
             raise ImportError("asyncpg is required for PostgreSQL support. Install with: pip install asyncpg")
@@ -155,30 +209,13 @@ class PostgresBackend(DatabaseBackend):
         self._pool = await asyncpg.create_pool(dsn=self._dsn, min_size=1, max_size=20)
 
         async with self._pool.acquire() as conn:
-            await conn.execute("""
-                CREATE TABLE IF NOT EXISTS accounts (
-                    id TEXT PRIMARY KEY,
-                    label TEXT,
-                    clientId TEXT,
-                    clientSecret TEXT,
-                    refreshToken TEXT,
-                    accessToken TEXT,
-                    other TEXT,
-                    last_refresh_time TEXT,
-                    last_refresh_status TEXT,
-                    created_at TEXT,
-                    updated_at TEXT,
-                    enabled INTEGER DEFAULT 1,
-                    error_count INTEGER DEFAULT 0,
-                    success_count INTEGER DEFAULT 0,
-                    expires_at TEXT
-                )
+            # Build CREATE TABLE statement from schema definition
+            columns_sql = ", ".join([f"{col[0]} {col[2]}" for col in ACCOUNTS_COLUMNS])
+            await conn.execute(f"""
+                CREATE TABLE IF NOT EXISTS accounts ({columns_sql})
             """)
-            # Add column if missing (migration)
-            try:
-                await conn.execute("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS expires_at TEXT")
-            except Exception:
-                pass
+            # Run migrations
+            await self._migrate_schema(conn)
         self._initialized = True
 
     async def close(self) -> None:
@@ -251,6 +288,32 @@ class MySQLBackend(DatabaseBackend):
             config['ssl'] = True
         return config
 
+    async def _get_existing_columns(self, cur) -> Set[str]:
+        """Get existing column names from accounts table."""
+        try:
+            await cur.execute(f"DESCRIBE accounts")
+            rows = await cur.fetchall()
+            return {row[0] if isinstance(row, tuple) else row['Field'] for row in rows}
+        except Exception:
+            return set()
+
+    async def _migrate_schema(self, cur) -> None:
+        """Add missing columns to accounts table."""
+        existing_cols = await self._get_existing_columns(cur)
+        if not existing_cols:
+            return  # Table doesn't exist yet
+        
+        for col_name, _, _, col_type, _ in ACCOUNTS_COLUMNS:
+            if col_name not in existing_cols and "PRIMARY KEY" not in col_type:
+                base_type = col_type.split(" DEFAULT")[0].strip()
+                try:
+                    await cur.execute(f"ALTER TABLE accounts ADD COLUMN {col_name} {base_type}")
+                    print(f"[DB Migration] Added column: {col_name}")
+                except Exception as e:
+                    # Column might already exist
+                    if "Duplicate column" not in str(e):
+                        print(f"[DB Migration] Failed to add column {col_name}: {e}")
+
     async def initialize(self) -> None:
         if not HAS_AIOMYSQL:
             raise ImportError("aiomysql is required for MySQL support. Install with: pip install aiomysql")
@@ -268,30 +331,13 @@ class MySQLBackend(DatabaseBackend):
 
         async with self._pool.acquire() as conn:
             async with conn.cursor() as cur:
-                await cur.execute("""
-                    CREATE TABLE IF NOT EXISTS accounts (
-                        id VARCHAR(255) PRIMARY KEY,
-                        label TEXT,
-                        clientId TEXT,
-                        clientSecret TEXT,
-                        refreshToken TEXT,
-                        accessToken TEXT,
-                        other TEXT,
-                        last_refresh_time TEXT,
-                        last_refresh_status TEXT,
-                        created_at TEXT,
-                        updated_at TEXT,
-                        enabled INT DEFAULT 1,
-                        error_count INT DEFAULT 0,
-                        success_count INT DEFAULT 0,
-                        expires_at TEXT
-                    )
+                # Build CREATE TABLE statement from schema definition
+                columns_sql = ", ".join([f"{col[0]} {col[3]}" for col in ACCOUNTS_COLUMNS])
+                await cur.execute(f"""
+                    CREATE TABLE IF NOT EXISTS accounts ({columns_sql})
                 """)
-                # Add column if missing (migration)
-                try:
-                    await cur.execute("ALTER TABLE accounts ADD COLUMN expires_at TEXT")
-                except Exception:
-                    pass  # Column already exists
+                # Run migrations
+                await self._migrate_schema(cur)
         self._initialized = True
 
     async def close(self) -> None: