test_api.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. """Integration tests for API endpoints."""
  2. from collections.abc import Generator
  3. import pytest
  4. from httpx import ASGITransport, AsyncClient
  5. from src.config import settings
  6. from src.database import db
  7. from src.main import app
  8. @pytest.fixture(autouse=True)
  9. def setup_test_db() -> Generator[None, None, None]:
  10. """Setup test database before each test."""
  11. settings.api_key = None
  12. db.db_path = ":memory:"
  13. db.connect()
  14. yield
  15. db.close()
  16. @pytest.mark.asyncio
  17. async def test_root_endpoint() -> None:
  18. """Test root endpoint."""
  19. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  20. response = await client.get("/")
  21. assert response.status_code == 200
  22. data = response.json()
  23. assert data["name"] == "Cognio"
  24. @pytest.mark.asyncio
  25. async def test_health_check() -> None:
  26. """Test health check endpoint."""
  27. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  28. response = await client.get("/health")
  29. assert response.status_code == 200
  30. data = response.json()
  31. assert data["status"] == "ok"
  32. assert "db" in data
  33. assert "fts" in data
  34. assert "embedding_model" in data
  35. @pytest.mark.asyncio
  36. async def test_save_memory() -> None:
  37. """Test saving a memory."""
  38. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  39. response = await client.post(
  40. "/memory/save",
  41. json={"text": "Test memory", "project": "TEST", "tags": ["test", "example"]},
  42. )
  43. assert response.status_code == 200
  44. data = response.json()
  45. assert data["saved"] is True
  46. assert data["duplicate"] is False
  47. assert data["reason"] == "created"
  48. assert "id" in data
  49. @pytest.mark.asyncio
  50. async def test_save_duplicate_memory() -> None:
  51. """Test saving a duplicate memory."""
  52. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  53. memory_data = {"text": "Duplicate test", "project": "TEST", "tags": ["test"]}
  54. # Save first time
  55. response1 = await client.post("/memory/save", json=memory_data)
  56. assert response1.status_code == 200
  57. data1 = response1.json()
  58. assert data1["duplicate"] is False
  59. # Save second time (should be duplicate)
  60. response2 = await client.post("/memory/save", json=memory_data)
  61. assert response2.status_code == 200
  62. data2 = response2.json()
  63. assert data2["duplicate"] is True
  64. assert data2["reason"] == "duplicate"
  65. assert data2["id"] == data1["id"]
  66. @pytest.mark.asyncio
  67. async def test_search_memory() -> None:
  68. """Test searching memories."""
  69. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  70. # Save some test memories
  71. await client.post(
  72. "/memory/save",
  73. json={
  74. "text": "Python is a programming language",
  75. "project": "TEST",
  76. "tags": ["python"],
  77. },
  78. )
  79. await client.post(
  80. "/memory/save",
  81. json={
  82. "text": "JavaScript is used for web development",
  83. "project": "TEST",
  84. "tags": ["javascript"],
  85. },
  86. )
  87. # Search for Python
  88. response = await client.get(
  89. "/memory/search",
  90. params={"q": "programming language", "limit": 5, "threshold": 0.3},
  91. )
  92. assert response.status_code == 200
  93. data = response.json()
  94. assert "results" in data
  95. assert len(data["results"]) > 0
  96. assert data["results"][0]["score"] > 0.3
  97. @pytest.mark.asyncio
  98. async def test_list_memories() -> None:
  99. """Test listing memories."""
  100. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  101. # Save multiple memories
  102. for i in range(5):
  103. await client.post(
  104. "/memory/save",
  105. json={"text": f"Memory {i}", "project": "TEST", "tags": [f"tag{i}"]},
  106. )
  107. # List all memories
  108. response = await client.get("/memory/list", params={"page": 1, "limit": 10})
  109. assert response.status_code == 200
  110. data = response.json()
  111. assert "memories" in data
  112. assert len(data["memories"]) == 5
  113. assert data["total_items"] == 5
  114. @pytest.mark.asyncio
  115. async def test_get_memory() -> None:
  116. """Test getting a single memory by ID."""
  117. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  118. # Save a memory
  119. save_response = await client.post(
  120. "/memory/save",
  121. json={"text": "Test memory content", "project": "TEST", "tags": ["test"]},
  122. )
  123. memory_id = save_response.json()["id"]
  124. # Get the memory
  125. get_response = await client.get(f"/memory/{memory_id}")
  126. assert get_response.status_code == 200
  127. data = get_response.json()
  128. assert data["id"] == memory_id
  129. assert data["text"] == "Test memory content"
  130. assert data["project"] == "TEST"
  131. assert "test" in data["tags"]
  132. # Try to get non-existent memory
  133. get_response2 = await client.get("/memory/nonexistent-id")
  134. assert get_response2.status_code == 404
  135. @pytest.mark.asyncio
  136. async def test_delete_memory() -> None:
  137. """Test deleting a memory."""
  138. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  139. # Save a memory
  140. save_response = await client.post(
  141. "/memory/save", json={"text": "Memory to delete", "project": "TEST"}
  142. )
  143. memory_id = save_response.json()["id"]
  144. # Delete it
  145. delete_response = await client.delete(f"/memory/{memory_id}")
  146. assert delete_response.status_code == 200
  147. assert delete_response.json()["deleted"] is True
  148. # Try to delete again (should fail)
  149. delete_response2 = await client.delete(f"/memory/{memory_id}")
  150. assert delete_response2.status_code == 404
  151. @pytest.mark.asyncio
  152. async def test_get_stats() -> None:
  153. """Test getting statistics."""
  154. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  155. # Save some memories
  156. await client.post(
  157. "/memory/save", json={"text": "Memory 1", "project": "PROJECT_A", "tags": ["tag1"]}
  158. )
  159. await client.post(
  160. "/memory/save", json={"text": "Memory 2", "project": "PROJECT_B", "tags": ["tag2"]}
  161. )
  162. # Get stats
  163. response = await client.get("/memory/stats")
  164. assert response.status_code == 200
  165. data = response.json()
  166. assert data["total_memories"] == 2
  167. assert data["total_projects"] == 2
  168. assert "memories_by_project" in data
  169. assert data["memories_by_project"]["PROJECT_A"] == 1
  170. assert data["memories_by_project"]["PROJECT_B"] == 1