# embedding.py
import os
from typing import Any

from llama_index.core.base.embeddings.base import BaseEmbedding
from pydantic import Field
from zhipuai import ZhipuAI
  5. class GLMEmbeddings(BaseEmbedding):
  6. client = Field(description="embedding model client")
  7. embedding_size: float = Field(description="embedding size")
  8. def __init__(self):
  9. super().__init__(model_name='GLM', embed_batch_size=64)
  10. self.client = ZhipuAI(api_key=os.getenv("Zhipu_API_KEY"))
  11. self.embedding_size = 1024
  12. def _get_query_embedding(self, query: str) -> list[float]:
  13. return self._get_text_embeddings([query])[0]
  14. def _get_text_embedding(self, text: str) -> list[float]:
  15. return self._get_text_embeddings([text])[0]
  16. def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
  17. return self._get_len_safe_embeddings(texts)
  18. async def _aget_query_embedding(self, query: str) -> list[float]:
  19. return self._get_query_embedding(query)
  20. def _get_len_safe_embeddings(self, texts: list[str]) -> list[list[float]]:
  21. try:
  22. # 获取embedding响应
  23. response = self.client.embeddings.create(model="embedding-2", input=texts)
  24. data = [item.embedding for item in response.data]
  25. return data
  26. except Exception as e:
  27. print(f"Fail to get embeddings, caused by {e}")
  28. return []