
feat: Add Engram hashed N-gram retrieval index

- Add engram.py with tokenizer compression, multi-head hashing, and bucket generation
- Add SQLite engram_index table with bucket/memory_id/hits schema
- Wire Engram updates into save/delete/bulk-delete paths
- Integrate Engram candidate retrieval in hybrid and semantic-only search
- Add backfill on startup and minimal tests
- Configurable via engram_enabled, ngram_sizes, heads, buckets, limits
0xReLogic committed 4 days ago
Parent
Commit
c056e78bd0
6 changed files with 351 additions and 14 deletions
  1. src/config.py (+9 -0)
  2. src/database.py (+171 -7)
  3. src/engram.py (+96 -0)
  4. src/main.py (+1 -0)
  5. src/memory.py (+54 -7)
  6. tests/test_engram.py (+20 -0)

+ 9 - 0
src/config.py

@@ -29,6 +29,15 @@ class Settings(BaseSettings):
     hybrid_mode: str = "candidate"  # candidate | rerank
     hybrid_rerank_topk: int = 100
 
+    # Engram O(1) retrieval (hashed N-gram index)
+    engram_enabled: bool = True
+    engram_ngram_sizes: str = "2,3"
+    engram_num_heads: int = 4
+    engram_num_buckets: int = 1000003
+    engram_candidate_limit: int = 200
+    engram_min_hits: int = 2
+    engram_query_bucket_limit: int = 500
+
     # Performance
     max_text_length: int = 10000
     batch_size: int = 32
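
The new knobs are plain pydantic fields on Settings, so they can be overridden at construction time or through environment variables (the exact variable names depend on the project's BaseSettings env configuration, which this diff does not show). A minimal sketch of programmatic overrides:

    from src.config import Settings

    # Hypothetical override values; the defaults are in the hunk above.
    settings = Settings(
        engram_enabled=True,
        engram_ngram_sizes="2,3,4",    # also index 4-token windows
        engram_num_heads=4,            # independent hash heads per n-gram
        engram_num_buckets=1_000_003,  # prime modulus spreads collisions
        engram_candidate_limit=200,    # max candidates per query
        engram_min_hits=2,             # require >= 2 shared buckets
    )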

+ 171 - 7
src/database.py

@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Any
 
 from .config import settings
+from .engram import engram_index
 from .models import Memory
 from .utils import get_timestamp
 
@@ -132,6 +133,27 @@ class Database:
             logger.warning(f"FTS5 not available or initialization failed: {e}")
             self.fts_ready = False
 
+        # Engram hashed n-gram index (best-effort)
+        try:
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS engram_index (
+                    bucket INTEGER NOT NULL,
+                    memory_id TEXT NOT NULL,
+                    hits INTEGER NOT NULL,
+                    PRIMARY KEY (bucket, memory_id)
+                )
+                """
+            )
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_engram_bucket ON engram_index(bucket)"
+            )
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_engram_memory ON engram_index(memory_id)"
+            )
+        except sqlite3.OperationalError as e:
+            logger.warning(f"Engram index init failed: {e}")
+
         self.conn.commit()
         logger.info("Database schema initialized")
 
@@ -139,6 +161,120 @@ class Database:
         """Return whether FTS is ready for use."""
         return self.fts_ready
 
+    def backfill_engram_index(self) -> None:
+        """Backfill Engram index for existing memories if empty."""
+        if not settings.engram_enabled:
+            return
+        if self.conn is None:
+            raise RuntimeError(_DB_NOT_CONNECTED_ERROR)
+
+        cursor = self.conn.cursor()
+        try:
+            cursor.execute("SELECT COUNT(*) AS count FROM engram_index")
+            count = cursor.fetchone()["count"]
+        except sqlite3.OperationalError as e:
+            logger.warning(f"Engram backfill skipped: {e}")
+            return
+
+        if count and count > 0:
+            return
+
+        logger.info("Engram backfill: rebuilding index for existing memories...")
+        cursor.execute("SELECT id, text FROM memories WHERE archived = 0")
+        rows = cursor.fetchall()
+        for row in rows:
+            memory_id = row["id"]
+            text = row["text"] or ""
+            bucket_counts = engram_index.bucket_counts(text)
+            for bucket, hits in bucket_counts.items():
+                cursor.execute(
+                    """
+                    INSERT INTO engram_index (bucket, memory_id, hits)
+                    VALUES (?, ?, ?)
+                    ON CONFLICT(bucket, memory_id) DO UPDATE SET hits = excluded.hits
+                    """,
+                    (int(bucket), memory_id, int(hits)),
+                )
+        self.conn.commit()
+        logger.info("Engram backfill: complete (entries=%s)", len(rows))
+
+    def upsert_engram_index(self, memory_id: str, text: str) -> None:
+        """Upsert Engram hashed n-gram buckets for a memory."""
+        if not settings.engram_enabled or not text:
+            return
+        if self.conn is None:
+            raise RuntimeError(_DB_NOT_CONNECTED_ERROR)
+
+        bucket_counts = engram_index.bucket_counts(text)
+        if not bucket_counts:
+            return
+
+        cursor = self.conn.cursor()
+        for bucket, hits in bucket_counts.items():
+            cursor.execute(
+                """
+                INSERT INTO engram_index (bucket, memory_id, hits)
+                VALUES (?, ?, ?)
+                ON CONFLICT(bucket, memory_id) DO UPDATE SET hits = excluded.hits
+                """,
+                (int(bucket), memory_id, int(hits)),
+            )
+        self.conn.commit()
+
+    def delete_engram_for_ids(self, memory_ids: list[str]) -> None:
+        """Remove Engram buckets for memory IDs."""
+        if not settings.engram_enabled or not memory_ids:
+            return
+        if self.conn is None:
+            raise RuntimeError(_DB_NOT_CONNECTED_ERROR)
+
+        placeholders = ",".join(["?"] * len(memory_ids))
+        self.execute(
+            f"DELETE FROM engram_index WHERE memory_id IN ({placeholders})",
+            tuple(memory_ids),
+        )
+        self.commit()
+
+    def engram_search_candidates(
+        self, query: str, project: str | None = None, limit: int | None = None
+    ) -> list[tuple[str, int]]:
+        """Return candidate memory IDs using Engram hashed n-gram lookup."""
+        if not settings.engram_enabled:
+            return []
+        if self.conn is None:
+            raise RuntimeError(_DB_NOT_CONNECTED_ERROR)
+
+        buckets = engram_index.buckets_for_query(query)
+        if not buckets:
+            return []
+
+        limit = int(limit or getattr(settings, "engram_candidate_limit", 200))
+        min_hits = int(getattr(settings, "engram_min_hits", 2))
+
+        placeholders = ",".join(["?"] * len(buckets))
+        sql = (
+            "SELECT e.memory_id AS id, SUM(e.hits) AS hits "
+            "FROM engram_index e "
+            "JOIN memories m ON m.id = e.memory_id "
+            "WHERE m.archived = 0 AND e.bucket IN ("
+            + placeholders
+            + ")"
+        )
+        params: list[Any] = [*buckets]
+        if project:
+            sql += " AND m.project = ?"
+            params.append(project)
+        sql += " GROUP BY e.memory_id HAVING SUM(e.hits) >= ? ORDER BY hits DESC LIMIT ?"
+        params.extend([min_hits, limit])
+
+        try:
+            cursor = self.execute(sql, tuple(params))
+            rows = cursor.fetchall()
+            return [(row["id"], int(row["hits"])) for row in rows]
+        except sqlite3.OperationalError as e:
+            logger.warning(f"Engram search failed: {e}")
+            return []
+
     def close(self) -> None:
         """Close database connection."""
         if self.conn:
@@ -184,6 +320,11 @@ class Database:
         )
         self.commit()
 
+        try:
+            self.upsert_engram_index(memory.id, memory.text)
+        except Exception as e:
+            logger.warning(f"Engram index update failed: {e}")
+
     def get_memory_by_id(self, memory_id: str) -> Memory | None:
         """Retrieve a memory by ID."""
         cursor = self.execute("SELECT * FROM memories WHERE id = ?", (memory_id,))
@@ -256,8 +397,14 @@ class Database:
     def delete_memory(self, memory_id: str) -> bool:
         """Delete a memory by ID (hard delete)."""
         cursor = self.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
+        deleted = cursor.rowcount > 0
+        if deleted:
+            try:
+                self.delete_engram_for_ids([memory_id])
+            except Exception as e:
+                logger.warning(f"Engram index cleanup failed: {e}")
         self.commit()
-        return cursor.rowcount > 0
+        return deleted
 
     def update_embedding(self, memory_id: str, embedding: list[float]) -> bool:
         """Update embedding vector for a memory and touch updated_at."""
@@ -275,24 +422,41 @@ class Database:
         cursor = self.execute(
             "UPDATE memories SET archived = 1 WHERE id = ? AND archived = 0", (memory_id,)
         )
+        archived = cursor.rowcount > 0
+        if archived:
+            try:
+                self.delete_engram_for_ids([memory_id])
+            except Exception as e:
+                logger.warning(f"Engram index cleanup failed: {e}")
         self.commit()
-        return cursor.rowcount > 0
+        return archived
 
     def bulk_delete(self, project: str | None = None, before_timestamp: int | None = None) -> int:
         """Bulk delete memories (hard delete)."""
-        query = "DELETE FROM memories WHERE 1=1"
+        base = "FROM memories WHERE 1=1"
         params: list[Any] = []
 
         if project:
-            query += _PROJECT_FILTER_SQL
+            base += _PROJECT_FILTER_SQL
             params.append(project)
 
-        if before_timestamp:
-            query += " AND created_at < ?"
+        if before_timestamp is not None:
+            base += " AND created_at < ?"
             params.append(before_timestamp)
 
-        cursor = self.execute(query, tuple(params))
+        ids: list[str] = []
+        if settings.engram_enabled:
+            cursor = self.execute(f"SELECT id {base}", tuple(params))
+            ids = [row["id"] for row in cursor.fetchall()]
+
+        cursor = self.execute(f"DELETE {base}", tuple(params))
         self.commit()
+
+        if ids:
+            try:
+                self.delete_engram_for_ids(ids)
+            except Exception as e:
+                logger.warning(f"Engram index cleanup failed: {e}")
         return cursor.rowcount
 
     def get_memories_by_ids(
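
End to end, the lookup is: hash the query into buckets, sum per-memory hits across those buckets, and keep memories that clear engram_min_hits. A usage sketch, assuming the module-level db instance that main.py connects (the import path is a guess from the surrounding code):

    from src.database import db

    db.connect()
    # Returns (memory_id, summed_hits) pairs, best first. Empty when Engram
    # is disabled, the query is too short to form an n-gram, or no memory
    # clears the engram_min_hits threshold.
    for memory_id, hits in db.engram_search_candidates(query="vector index tuning", limit=50):
        print(memory_id, hits)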

+ 96 - 0
src/engram.py

@@ -0,0 +1,96 @@
+"""Engram: hashed N-gram retrieval index for Cognio."""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import re
+import unicodedata
+from collections import Counter
+from collections.abc import Iterable
+
+from .config import settings
+
+logger = logging.getLogger(__name__)
+
+
+class EngramIndex:
+    """Hashed N-gram index for O(1)-style candidate retrieval."""
+
+    def _normalize_text(self, text: str) -> str:
+        normalized = unicodedata.normalize("NFKC", text)
+        normalized = normalized.lower()
+        normalized = re.sub(r"\s+", " ", normalized).strip()
+        return normalized
+
+    def _tokenize(self, text: str) -> list[str]:
+        normalized = self._normalize_text(text)
+        if not normalized:
+            return []
+        return re.findall(r"[a-z0-9]+", normalized)
+
+    def _parse_ngram_sizes(self) -> list[int]:
+        raw = getattr(settings, "engram_ngram_sizes", "2,3")
+        sizes: list[int] = []
+        if isinstance(raw, str):
+            for part in raw.split(","):
+                part = part.strip()
+                if not part:
+                    continue
+                try:
+                    sizes.append(int(part))
+                except ValueError:
+                    continue
+        elif isinstance(raw, Iterable):
+            for item in raw:
+                try:
+                    sizes.append(int(item))
+                except (TypeError, ValueError):
+                    continue
+        sizes = sorted({s for s in sizes if s > 0})
+        return sizes or [2, 3]
+
+    def _num_heads(self) -> int:
+        return max(1, int(getattr(settings, "engram_num_heads", 4)))
+
+    def _num_buckets(self) -> int:
+        return max(1024, int(getattr(settings, "engram_num_buckets", 1000003)))
+
+    def _bucket_limit(self) -> int:
+        return max(0, int(getattr(settings, "engram_query_bucket_limit", 500)))
+
+    def _hash_ngram(self, ngram: list[str], head: int, num_buckets: int) -> int:
+        key = f"{head}|{' '.join(ngram)}"
+        digest = hashlib.blake2b(key.encode("utf-8"), digest_size=8).digest()
+        return int.from_bytes(digest, "big") % num_buckets
+
+    def buckets_for_text(self, text: str) -> list[int]:
+        tokens = self._tokenize(text)
+        if not tokens:
+            return []
+        ngram_sizes = self._parse_ngram_sizes()
+        num_heads = self._num_heads()
+        num_buckets = self._num_buckets()
+
+        buckets: list[int] = []
+        for idx in range(len(tokens)):
+            for n in ngram_sizes:
+                if idx + 1 < n:
+                    continue
+                ngram = tokens[idx - n + 1 : idx + 1]
+                for head in range(num_heads):
+                    buckets.append(self._hash_ngram(ngram, head, num_buckets))
+        return buckets
+
+    def bucket_counts(self, text: str) -> Counter[int]:
+        return Counter(self.buckets_for_text(text))
+
+    def buckets_for_query(self, query: str) -> list[int]:
+        buckets = list(dict.fromkeys(self.buckets_for_text(query)))
+        limit = self._bucket_limit()
+        if limit and len(buckets) > limit:
+            buckets = buckets[:limit]
+        return buckets
+
+
+engram_index = EngramIndex()
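
A usage sketch for the class above, under the default settings (2- and 3-grams, 4 heads):

    from src.engram import EngramIndex

    index = EngramIndex()
    buckets = index.buckets_for_text("Hello world from Cognio")
    # 4 tokens -> 3 bigrams + 2 trigrams = 5 n-grams, each hashed by
    # 4 heads: 20 bucket ids in total (collisions can repeat values).
    assert len(buckets) == 20

    # buckets_for_query deduplicates and applies engram_query_bucket_limit.
    assert len(index.buckets_for_query("Hello world from Cognio")) <= 20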

+ 1 - 0
src/main.py

@@ -53,6 +53,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     logger.info("Starting Cognio server...")
     settings.ensure_db_dir()
     db.connect()
+    db.backfill_engram_index()
     embedding_service.load_model()
 
     # Optionally trigger background re-embedding for mismatched dimensions
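
The startup hook is idempotent by construction: backfill_engram_index() scans memories only while the engram_index table is empty, so restarts do not trigger a rebuild. Roughly:

    from src.database import db  # import path assumed from this hunk

    db.connect()
    db.backfill_engram_index()  # first run: indexes all non-archived memories
    db.backfill_engram_index()  # later runs: COUNT(*) > 0, returns immediately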

+ 54 - 7
src/memory.py

@@ -237,6 +237,28 @@ class MemoryService:
             # 1) Get FTS candidates (id, bm25 rank), lower rank is better
             candidates = db.fts_search_candidates(query=query, project=project, limit=100)
 
+            # Engram hashed n-gram candidates (merge with FTS)
+            engram_candidates: list[tuple[str, int]] = []
+            if settings.engram_enabled:
+                try:
+                    engram_candidates = db.engram_search_candidates(
+                        query=query,
+                        project=project,
+                        limit=getattr(settings, "engram_candidate_limit", 200),
+                    )
+                except Exception as e:
+                    logger.warning(f"Engram candidate lookup failed: {e}")
+
+            if engram_candidates:
+                merged: dict[str, float] = {mid: float(rank) for mid, rank in candidates}
+                for mid, hits in engram_candidates:
+                    rank = 1.0 / max(float(hits), 1.0)
+                    if mid in merged:
+                        merged[mid] = min(merged[mid], rank)
+                    else:
+                        merged[mid] = rank
+                candidates = sorted(merged.items(), key=lambda x: x[1])
+
             after_ts: int | None = None
             before_ts: int | None = None
             if after_date:
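
A toy walk-through of the merge above, with invented values. Note that bm25 ranks and the 1/hits mapping live on different scales; the sketch only mirrors the lower-is-better min() convention the code relies on:

    fts_candidates = [("m1", 0.8), ("m2", 2.5)]  # (memory_id, bm25 rank)
    engram = [("m2", 10), ("m3", 4)]             # (memory_id, summed hits)

    merged = {mid: float(rank) for mid, rank in fts_candidates}
    for mid, hits in engram:
        rank = 1.0 / max(float(hits), 1.0)       # m2 -> 0.1, m3 -> 0.25
        merged[mid] = min(merged.get(mid, rank), rank)

    assert sorted(merged.items(), key=lambda x: x[1]) == [
        ("m2", 0.1), ("m3", 0.25), ("m1", 0.8),
    ]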
@@ -456,16 +478,41 @@ class MemoryService:
             except ValueError:
                 before_ts = None
 
-        all_memories = db.get_all_memories(
-            project=project,
-            tags=tags,
-            after_timestamp=after_ts,
-            before_timestamp=before_ts,
-        )
+        base_memories: list[Memory] | None = None
+        if settings.engram_enabled:
+            try:
+                engram_candidates = db.engram_search_candidates(
+                    query=query,
+                    project=project,
+                    limit=getattr(settings, "engram_candidate_limit", 200),
+                )
+            except Exception as e:
+                logger.warning(f"Engram candidate lookup failed: {e}")
+                engram_candidates = []
+
+            if engram_candidates:
+                candidate_ids = [mid for mid, _ in engram_candidates]
+                base_memories = db.get_memories_by_ids(
+                    ids=candidate_ids,
+                    project=project,
+                    tags=tags,
+                    after_timestamp=after_ts,
+                    before_timestamp=before_ts,
+                )
+                if not base_memories:
+                    base_memories = None
+
+        if base_memories is None:
+            base_memories = db.get_all_memories(
+                project=project,
+                tags=tags,
+                after_timestamp=after_ts,
+                before_timestamp=before_ts,
+            )
 
         emb_dim = embedding_service.embedding_dim
         mems_with_emb = [
-            m for m in all_memories if (m.embedding is not None and len(m.embedding) == emb_dim)
+            m for m in base_memories if (m.embedding is not None and len(m.embedding) == emb_dim)
         ]
         if not mems_with_emb:
             return []
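
The semantic-only path applies the same idea in miniature: take the Engram candidate slice when it is non-empty, otherwise fall back to the full scan. A condensed sketch of that control flow (narrowed_memories is illustrative, not this module's API, and the remaining filters are assumed to default to None):

    from src.database import db

    def narrowed_memories(query: str, project: str | None = None):
        ids = [mid for mid, _ in db.engram_search_candidates(query=query, project=project)]
        candidates = db.get_memories_by_ids(ids=ids, project=project) if ids else None
        return candidates or db.get_all_memories(project=project)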

+ 20 - 0
tests/test_engram.py

@@ -0,0 +1,20 @@
+from src.config import settings
+from src.engram import EngramIndex
+
+
+def test_engram_buckets_for_text():
+    index = EngramIndex()
+    buckets = index.buckets_for_text("Hello world from Cognio")
+    assert buckets
+    assert all(isinstance(bucket, int) for bucket in buckets)
+
+
+def test_engram_bucket_limit():
+    index = EngramIndex()
+    original_limit = settings.engram_query_bucket_limit
+    try:
+        settings.engram_query_bucket_limit = 1
+        buckets = index.buckets_for_query("hello world from cognio")
+        assert len(buckets) <= 1
+    finally:
+        settings.engram_query_bucket_limit = original_limit
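
Two further checks one might add (sketches, not part of this commit): hashing is pure and deterministic, so identical text must map to identical buckets, and n-grams are order-sensitive:

    def test_engram_buckets_deterministic():
        index = EngramIndex()
        text = "hashed n-gram retrieval"
        assert index.buckets_for_text(text) == index.buckets_for_text(text)

    def test_engram_order_sensitivity():
        index = EngramIndex()
        # "hello world" and "world hello" form different bigrams, so their
        # bucket multisets should differ with overwhelming probability.
        assert sorted(index.buckets_for_text("hello world")) != sorted(
            index.buckets_for_text("world hello")
        )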