test_embeddings.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. """Unit tests for embedding service."""
  2. import pytest
  3. from src.embeddings import EmbeddingService
  4. from src.utils import generate_text_hash
  5. @pytest.fixture
  6. def embedding_service() -> EmbeddingService:
  7. """Create embedding service instance."""
  8. return EmbeddingService()
  9. def test_encode_single_text(embedding_service: EmbeddingService) -> None:
  10. """Test encoding a single text."""
  11. text = "This is a test sentence for embedding"
  12. text_hash = generate_text_hash(text)
  13. embedding = embedding_service.encode(text, text_hash)
  14. assert isinstance(embedding, list)
  15. assert len(embedding) == embedding_service.embedding_dim
  16. assert all(isinstance(x, float) for x in embedding)
  17. def test_encode_batch(embedding_service: EmbeddingService) -> None:
  18. """Test encoding multiple texts."""
  19. texts = ["First text", "Second text", "Third text"]
  20. embeddings = embedding_service.encode_batch(texts)
  21. assert len(embeddings) == 3
  22. assert all(len(emb) == embedding_service.embedding_dim for emb in embeddings)
  23. def test_cosine_similarity(embedding_service: EmbeddingService) -> None:
  24. """Test cosine similarity calculation."""
  25. text1 = "The cat sat on the mat"
  26. text2 = "A cat was sitting on a mat"
  27. text3 = "Python programming language"
  28. hash1 = generate_text_hash(text1)
  29. hash2 = generate_text_hash(text2)
  30. hash3 = generate_text_hash(text3)
  31. emb1 = embedding_service.encode(text1, hash1)
  32. emb2 = embedding_service.encode(text2, hash2)
  33. emb3 = embedding_service.encode(text3, hash3)
  34. # Similar sentences should have high similarity
  35. sim_similar = embedding_service.cosine_similarity(emb1, emb2)
  36. assert sim_similar > 0.7
  37. # Dissimilar sentences should have lower similarity
  38. sim_different = embedding_service.cosine_similarity(emb1, emb3)
  39. assert sim_different < 0.5
  40. # Same sentence should have similarity close to 1.0
  41. sim_identical = embedding_service.cosine_similarity(emb1, emb1)
  42. assert sim_identical > 0.99
  43. def test_embedding_cache(embedding_service: EmbeddingService) -> None:
  44. """Test the embedding cache."""
  45. text = "This is a test sentence for the embedding cache."
  46. text_hash = generate_text_hash(text)
  47. # First call should generate and cache the embedding
  48. embedding1 = embedding_service.encode(text, text_hash)
  49. # Second call should return the cached embedding
  50. embedding2 = embedding_service.encode(text, text_hash)
  51. assert embedding1 == embedding2
  52. # Save and load the cache
  53. embedding_service.save_cache()
  54. embedding_service.load_cache()
  55. # Third call should still return the cached embedding
  56. embedding3 = embedding_service.encode(text, text_hash)
  57. assert embedding1 == embedding3