# codegeex4.py — CodeGeeX4 chat-model wrapper
  1. import torch
  2. from pydantic import Field
  3. from transformers import AutoModel, AutoTokenizer
  4. class CodegeexChatModel:
  5. device: str = Field(description="device to load the model")
  6. tokenizer = Field(description="model's tokenizer")
  7. model = Field(description="Codegeex model")
  8. temperature: float = Field(description="temperature to use for the model.")
  9. def __init__(self, model_name_or_path):
  10. super().__init__()
  11. self.device = "cuda" if torch.cuda.is_available() else "cpu"
  12. self.tokenizer = AutoTokenizer.from_pretrained(
  13. model_name_or_path, trust_remote_code=True
  14. )
  15. self.model = (
  16. AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
  17. .to(self.device)
  18. .eval()
  19. )
  20. print("Model has been initialized.")
  21. def chat(self, prompt, temperature=0.2, top_p=0.95):
  22. try:
  23. response, _ = self.model.chat(
  24. self.tokenizer,
  25. query=prompt,
  26. max_length=120000,
  27. temperature=temperature,
  28. top_p=top_p,
  29. )
  30. return response
  31. except Exception as e:
  32. return f"error: {e}"
  33. def stream_chat(self, prompt, temperature=0.2, top_p=0.95):
  34. try:
  35. for response, _ in self.model.stream_chat(
  36. self.tokenizer,
  37. query=prompt,
  38. max_length=120000,
  39. temperature=temperature,
  40. top_p=top_p,
  41. ):
  42. yield response
  43. except Exception as e:
  44. yield f"error: {e}"