test_memory.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. """Unit tests for memory service."""
  2. from collections.abc import Generator
  3. from datetime import datetime, timedelta
  4. import pytest
  5. from src.database import db
  6. from src.memory import memory_service
  7. from src.models import SaveMemoryRequest
  8. @pytest.fixture(autouse=True)
  9. def setup_test_db() -> Generator[None, None, None]:
  10. """Setup test database before each test."""
  11. db.db_path = ":memory:"
  12. db.connect()
  13. yield
  14. db.close()
  15. def test_save_memory() -> None:
  16. """Test saving a memory."""
  17. request = SaveMemoryRequest(
  18. text="Python is a programming language", project="TEST", tags=["python", "test"]
  19. )
  20. memory_id, is_duplicate, reason = memory_service.save_memory(request)
  21. assert memory_id is not None
  22. assert is_duplicate is False
  23. assert reason == "created"
  24. def test_save_duplicate_memory() -> None:
  25. """Test duplicate detection."""
  26. request = SaveMemoryRequest(text="Duplicate text", project="TEST", tags=["test"])
  27. # Save first time
  28. id1, dup1, _ = memory_service.save_memory(request)
  29. assert dup1 is False
  30. # Save again (should be duplicate)
  31. id2, dup2, reason2 = memory_service.save_memory(request)
  32. assert dup2 is True
  33. assert reason2 == "duplicate"
  34. assert id1 == id2
  35. def test_search_memory() -> None:
  36. """Test semantic search."""
  37. # Save some memories
  38. memory_service.save_memory(
  39. SaveMemoryRequest(text="Python is great for AI", project="AI", tags=["python"])
  40. )
  41. memory_service.save_memory(
  42. SaveMemoryRequest(text="JavaScript for web development", project="WEB", tags=["js"])
  43. )
  44. memory_service.save_memory(
  45. SaveMemoryRequest(text="Machine learning with Python", project="AI", tags=["python", "ml"])
  46. )
  47. # Search for Python-related
  48. results = memory_service.search_memory(query="Python programming", limit=5, threshold=0.3)
  49. assert len(results) > 0
  50. # Should find Python-related memories
  51. assert any("Python" in r.text for r in results)
  52. def test_search_with_project_filter() -> None:
  53. """Test search with project filtering."""
  54. memory_service.save_memory(
  55. SaveMemoryRequest(text="AI project memory", project="AI", tags=["ai"])
  56. )
  57. memory_service.save_memory(
  58. SaveMemoryRequest(text="Web project memory", project="WEB", tags=["web"])
  59. )
  60. results = memory_service.search_memory(
  61. query="project memory", project="AI", limit=5, threshold=0.3
  62. )
  63. assert len(results) > 0
  64. assert all(r.project == "AI" for r in results)
  65. def test_search_with_date_range() -> None:
  66. """Test search with date filtering."""
  67. # Save a memory
  68. memory_service.save_memory(
  69. SaveMemoryRequest(text="Recent memory", project="TEST", tags=["test"])
  70. )
  71. # Search with date range
  72. now = datetime.now()
  73. yesterday = (now - timedelta(days=1)).isoformat()
  74. tomorrow = (now + timedelta(days=1)).isoformat()
  75. results = memory_service.search_memory(
  76. query="memory", after_date=yesterday, before_date=tomorrow, threshold=0.1
  77. )
  78. assert len(results) > 0
  79. def test_list_memories() -> None:
  80. """Test listing memories."""
  81. # Save multiple memories
  82. for i in range(5):
  83. memory_service.save_memory(
  84. SaveMemoryRequest(text=f"Memory {i}", project="TEST", tags=[f"tag{i}"])
  85. )
  86. memories, _ = memory_service.list_memories(page=1, limit=10)
  87. assert len(memories) == 5
  88. def test_list_with_project_filter() -> None:
  89. """Test listing with project filter."""
  90. memory_service.save_memory(SaveMemoryRequest(text="AI memory", project="AI", tags=["ai"]))
  91. memory_service.save_memory(SaveMemoryRequest(text="WEB memory", project="WEB", tags=["web"]))
  92. memories, total = memory_service.list_memories(project="AI", page=1, limit=10)
  93. assert total == 1
  94. assert memories[0].project == "AI"
  95. def test_list_with_relevance_sort() -> None:
  96. """Test listing with relevance sorting."""
  97. memory_service.save_memory(
  98. SaveMemoryRequest(text="Python programming language", project="CODE", tags=["python"])
  99. )
  100. memory_service.save_memory(
  101. SaveMemoryRequest(text="JavaScript is cool", project="CODE", tags=["js"])
  102. )
  103. memories, _ = memory_service.list_memories(
  104. page=1, limit=10, sort="relevance", search_query="Python"
  105. )
  106. assert len(memories) > 0
  107. # First result should be most relevant (already checked by len > 0)
  108. assert memories[0].score is not None
  109. def test_delete_memory() -> None:
  110. """Test deleting a memory."""
  111. request = SaveMemoryRequest(text="Memory to delete", project="TEST", tags=["test"])
  112. memory_id, _, _ = memory_service.save_memory(request)
  113. # Delete it
  114. deleted = memory_service.delete_memory(memory_id)
  115. assert deleted is True
  116. # Try to delete again (should fail)
  117. deleted_again = memory_service.delete_memory(memory_id)
  118. assert deleted_again is False
  119. def test_bulk_delete() -> None:
  120. """Test bulk deletion."""
  121. # Save memories in different projects
  122. memory_service.save_memory(
  123. SaveMemoryRequest(text="Memory 1", project="PROJECT_A", tags=["test"])
  124. )
  125. memory_service.save_memory(
  126. SaveMemoryRequest(text="Memory 2", project="PROJECT_A", tags=["test"])
  127. )
  128. memory_service.save_memory(
  129. SaveMemoryRequest(text="Memory 3", project="PROJECT_B", tags=["test"])
  130. )
  131. # Bulk delete PROJECT_A
  132. count = memory_service.bulk_delete(project="PROJECT_A")
  133. assert count == 2
  134. # Check remaining
  135. memories, total = memory_service.list_memories(page=1, limit=10)
  136. assert total == 1
  137. assert memories[0].project == "PROJECT_B"
  138. def test_get_stats() -> None:
  139. """Test statistics retrieval."""
  140. # Save some test data
  141. memory_service.save_memory(
  142. SaveMemoryRequest(text="Memory 1", project="PROJECT_A", tags=["tag1", "tag2"])
  143. )
  144. memory_service.save_memory(
  145. SaveMemoryRequest(text="Memory 2", project="PROJECT_B", tags=["tag1"])
  146. )
  147. stats = memory_service.get_stats()
  148. assert stats["total_memories"] == 2
  149. assert stats["total_projects"] == 2
  150. assert "PROJECT_A" in stats["by_project"]
  151. assert "PROJECT_B" in stats["by_project"]
  152. assert "tag1" in stats["top_tags"]
  153. def test_export_json() -> None:
  154. """Test JSON export."""
  155. memory_service.save_memory(
  156. SaveMemoryRequest(text="Export test", project="TEST", tags=["export"])
  157. )
  158. data = memory_service.export_memories(format="json")
  159. assert isinstance(data, dict)
  160. assert "memories" in data
  161. assert len(data["memories"]) > 0
  162. def test_export_markdown() -> None:
  163. """Test Markdown export."""
  164. memory_service.save_memory(
  165. SaveMemoryRequest(text="Export test", project="TEST", tags=["export"])
  166. )
  167. data = memory_service.export_memories(format="markdown")
  168. assert isinstance(data, str)
  169. assert "Export test" in data
  170. assert "TEST" in data
  171. def test_search_memory_minimal_truncation() -> None:
  172. """Test search with minimal payload and truncation."""
  173. long_text = "This is a long memory used to test minimal payload truncation behavior."
  174. memory_service.save_memory(SaveMemoryRequest(text=long_text, project="TEST", tags=["minimal"]))
  175. results = memory_service.search_memory(
  176. query="minimal payload",
  177. limit=1,
  178. threshold=0.0,
  179. minimal=True,
  180. max_chars_per_item=10,
  181. )
  182. assert len(results) == 1
  183. # Should be truncated and end with ellipsis character
  184. assert len(results[0].text) == 11
  185. assert results[0].text.endswith("…")
  186. def test_reembed_mismatched(monkeypatch: pytest.MonkeyPatch) -> None:
  187. """Test reembed_mismatched scans and re-embeds only mismatched items."""
  188. from src.models import Memory
  189. mem1 = Memory(
  190. id="1",
  191. text="needs embed",
  192. summary=None,
  193. text_hash="h1",
  194. embedding=None,
  195. project="TEST",
  196. tags=["t1"],
  197. created_at=0,
  198. updated_at=0,
  199. )
  200. mem2 = Memory(
  201. id="2",
  202. text="already ok",
  203. summary=None,
  204. text_hash="h2",
  205. embedding=[0.1, 0.2, 0.3],
  206. project="TEST",
  207. tags=["t2"],
  208. created_at=0,
  209. updated_at=0,
  210. )
  211. calls: dict[str, list] = {"updates": []}
  212. def fake_list_memories(limit: int = 500, offset: int = 0, **_: object) -> list[Memory]:
  213. return [mem1, mem2] if offset == 0 else []
  214. def fake_encode_batch(texts: list[str]) -> list[list[float]]:
  215. return [[1.0, 2.0, 3.0] for _ in texts]
  216. def fake_update_embedding(mem_id: str, emb: list[float]) -> bool:
  217. calls["updates"].append((mem_id, emb))
  218. return True
  219. monkeypatch.setattr("src.memory.db.list_memories", fake_list_memories)
  220. monkeypatch.setattr("src.memory.embedding_service.embedding_dim", 3)
  221. monkeypatch.setattr("src.memory.embedding_service.encode_batch", fake_encode_batch)
  222. monkeypatch.setattr("src.memory.db.update_embedding", fake_update_embedding)
  223. stats = memory_service.reembed_mismatched(page_size=10)
  224. assert stats["scanned"] == 2
  225. assert stats["reembedded"] == 1
  226. assert calls["updates"][0][0] == "1"