embedding.py 889 B

1234567891011121314151617181920212223242526
import logging
import os
from typing import Optional

from langchain.schema.embeddings import Embeddings
from zhipuai import ZhipuAI
  4. class GLMEmbeddings(Embeddings):
  5. def __init__(self):
  6. self.client = ZhipuAI(api_key=os.getenv("Zhipu_API_KEY"))
  7. self.embedding_size = 1024
  8. def embed_query(self, text: str) -> list[float]:
  9. return self.embed_documents([text])[0]
  10. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  11. return self._get_len_safe_embeddings(texts)
  12. def _get_len_safe_embeddings(self, texts: list[str]) -> list[list[float]]:
  13. try:
  14. # 获取embedding响应
  15. response = self.client.embeddings.create(model="embedding-2", input=texts)
  16. data = [item.embedding for item in response.data]
  17. return data
  18. except Exception as e:
  19. print(f"Fail to get embeddings, caused by {e}")
  20. return []