test_embedding_queue.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """Unit tests for EmbeddingQueue behavior (fast and deterministic)."""
  2. import asyncio
  3. import pytest
  4. from src.embedding_queue import EmbeddingQueue
  5. class DummyEmbeddingService:
  6. def __init__(self) -> None:
  7. self.cache: dict[str, list[float]] = {}
  8. self.calls: list[list[str]] = []
  9. def encode_batch(self, texts: list[str]) -> list[list[float]]:
  10. self.calls.append(list(texts))
  11. return [[float(len(t)), 0.0, 0.0] for t in texts]
  12. @pytest.mark.asyncio
  13. async def test_add_task_uses_cache(monkeypatch: pytest.MonkeyPatch) -> None:
  14. dummy = DummyEmbeddingService()
  15. dummy.cache["hash-1"] = [1.0, 2.0, 3.0]
  16. monkeypatch.setattr("src.embedding_queue.embedding_service", dummy)
  17. queue = EmbeddingQueue(batch_size=2, batch_timeout=1.0)
  18. result = await queue.add_task("text", "hash-1")
  19. assert result == [1.0, 2.0, 3.0]
  20. assert dummy.calls == []
  21. @pytest.mark.asyncio
  22. async def test_process_queue_batches_and_updates_cache(monkeypatch: pytest.MonkeyPatch) -> None:
  23. dummy = DummyEmbeddingService()
  24. monkeypatch.setattr("src.embedding_queue.embedding_service", dummy)
  25. queue = EmbeddingQueue(batch_size=2, batch_timeout=1.0)
  26. loop = asyncio.get_running_loop()
  27. fut1 = loop.create_future()
  28. fut2 = loop.create_future()
  29. await queue.queue.put(("one", "h1", fut1))
  30. await queue.queue.put(("two", "h2", fut2))
  31. await queue._process_queue()
  32. assert fut1.done() and fut2.done()
  33. assert fut1.result() == [3.0, 0.0, 0.0]
  34. assert fut2.result() == [3.0, 0.0, 0.0]
  35. assert "h1" in dummy.cache and "h2" in dummy.cache
  36. @pytest.mark.asyncio
  37. async def test_process_queue_error_propagates(monkeypatch: pytest.MonkeyPatch) -> None:
  38. class ErrorEmbeddingService:
  39. def __init__(self) -> None:
  40. self.cache: dict[str, list[float]] = {}
  41. def encode_batch(self, texts: list[str]) -> list[list[float]]: # pragma: no cover
  42. raise RuntimeError("embedding failure")
  43. error_service = ErrorEmbeddingService()
  44. monkeypatch.setattr("src.embedding_queue.embedding_service", error_service)
  45. queue = EmbeddingQueue(batch_size=1, batch_timeout=1.0)
  46. loop = asyncio.get_running_loop()
  47. fut = loop.create_future()
  48. await queue.queue.put(("text", "hash", fut))
  49. await queue._process_queue()
  50. assert fut.done()
  51. with pytest.raises(RuntimeError):
  52. fut.result()