migrate.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. """Database migration script for Cognio."""
import sqlite3
import sys
import time
from pathlib import Path
  5. def get_current_version(conn: sqlite3.Connection) -> int:
  6. """Get current database schema version."""
  7. try:
  8. cursor = conn.execute("SELECT version FROM schema_version ORDER BY version DESC LIMIT 1")
  9. result = cursor.fetchone()
  10. return result[0] if result else 0
  11. except sqlite3.OperationalError:
  12. return 0
  13. def init_version_table(conn: sqlite3.Connection) -> None:
  14. """Initialize schema version tracking table."""
  15. conn.execute(
  16. """
  17. CREATE TABLE IF NOT EXISTS schema_version (
  18. version INTEGER PRIMARY KEY,
  19. applied_at INTEGER NOT NULL,
  20. description TEXT
  21. )
  22. """
  23. )
  24. conn.commit()
  25. def migration_001_initial_schema(conn: sqlite3.Connection) -> None:
  26. """Initial database schema."""
  27. print(" Applying migration 001: Initial schema")
  28. conn.execute(
  29. """
  30. CREATE TABLE IF NOT EXISTS memories (
  31. id TEXT PRIMARY KEY,
  32. text TEXT NOT NULL,
  33. text_hash TEXT,
  34. embedding BLOB,
  35. project TEXT,
  36. tags TEXT,
  37. created_at INTEGER,
  38. updated_at INTEGER
  39. )
  40. """
  41. )
  42. conn.execute("CREATE INDEX IF NOT EXISTS idx_project ON memories(project)")
  43. conn.execute("CREATE INDEX IF NOT EXISTS idx_created ON memories(created_at)")
  44. conn.execute("CREATE INDEX IF NOT EXISTS idx_hash ON memories(text_hash)")
  45. conn.execute(
  46. "INSERT INTO schema_version (version, applied_at, description) VALUES (?, ?, ?)",
  47. (1, int(__import__("time").time()), "Initial schema"),
  48. )
  49. conn.commit()
  50. def migration_002_add_archived_flag(conn: sqlite3.Connection) -> None:
  51. """Add archived flag for soft delete."""
  52. print(" Applying migration 002: Add archived flag")
  53. # Check if column exists
  54. cursor = conn.execute("PRAGMA table_info(memories)")
  55. columns = [col[1] for col in cursor.fetchall()]
  56. if "archived" not in columns:
  57. conn.execute("ALTER TABLE memories ADD COLUMN archived INTEGER DEFAULT 0")
  58. conn.execute(
  59. "INSERT INTO schema_version (version, applied_at, description) VALUES (?, ?, ?)",
  60. (2, int(__import__("time").time()), "Add archived flag for soft delete"),
  61. )
  62. conn.commit()
  63. # Migration registry
  64. MIGRATIONS = {
  65. 1: migration_001_initial_schema,
  66. 2: migration_002_add_archived_flag,
  67. }
  68. def run_migrations(db_path: str) -> None:
  69. """Run pending migrations."""
  70. print(f"Running migrations on: {db_path}")
  71. # Ensure database directory exists
  72. Path(db_path).parent.mkdir(parents=True, exist_ok=True)
  73. # Connect to database
  74. conn = sqlite3.connect(db_path)
  75. try:
  76. # Initialize version tracking
  77. init_version_table(conn)
  78. # Get current version
  79. current_version = get_current_version(conn)
  80. print(f"Current schema version: {current_version}")
  81. # Get max version available
  82. max_version = max(MIGRATIONS.keys()) if MIGRATIONS else 0
  83. if current_version >= max_version:
  84. print("Database is up to date!")
  85. return
  86. # Apply pending migrations
  87. print(f"Migrating from version {current_version} to {max_version}")
  88. for version in range(current_version + 1, max_version + 1):
  89. if version in MIGRATIONS:
  90. migration_func = MIGRATIONS[version]
  91. migration_func(conn)
  92. print(f" Migration {version} completed")
  93. else:
  94. print(f" Warning: Migration {version} not found")
  95. print("\nAll migrations completed successfully!")
  96. print(f"Current version: {get_current_version(conn)}")
  97. except Exception as e:
  98. print(f"Error during migration: {e}")
  99. conn.rollback()
  100. sys.exit(1)
  101. finally:
  102. conn.close()
  103. if __name__ == "__main__":
  104. db_path = sys.argv[1] if len(sys.argv) > 1 else "./data/memory.db"
  105. run_migrations(db_path)