```python
from typing import Any

import torch
from pydantic import BaseModel, ConfigDict, Field
from transformers import AutoModel, AutoTokenizer


class CodegeexChatModel(BaseModel):
    # Tokenizers and torch modules are not pydantic-native types, so allow them.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    device: str = Field(default="cpu", description="device to load the model on")
    tokenizer: Any = Field(default=None, description="the model's tokenizer")
    model: Any = Field(default=None, description="the CodeGeeX model")
    temperature: float = Field(default=0.2, description="sampling temperature for generation")

    def __init__(self, model_name_or_path: str, **kwargs):
        super().__init__(**kwargs)
        # Prefer GPU when available; fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        # trust_remote_code is required because CodeGeeX ships custom model code.
        self.model = (
            AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
            .to(self.device)
            .eval()
        )
        print("Model has been initialized.")

    def chat(self, prompt: str, temperature: float = 0.2, top_p: float = 0.95) -> str:
        """Return a single, complete response to the prompt."""
        try:
            # model.chat() returns (response, history); the history is discarded here.
            response, _ = self.model.chat(
                self.tokenizer,
                query=prompt,
                max_length=120000,
                temperature=temperature,
                top_p=top_p,
            )
            return response
        except Exception as e:
            return f"error: {e}"

    def stream_chat(self, prompt: str, temperature: float = 0.2, top_p: float = 0.95):
        """Yield incrementally longer responses as generation progresses."""
        try:
            for response, _ in self.model.stream_chat(
                self.tokenizer,
                query=prompt,
                max_length=120000,
                temperature=temperature,
                top_p=top_p,
            ):
                yield response
        except Exception as e:
            yield f"error: {e}"
```
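For reference, here is a minimal usage sketch of the class above. The checkpoint ID is an assumption: substitute the local path or Hugging Face ID of whichever CodeGeeX checkpoint you actually serve.

```python
# Assumed checkpoint ID; replace with your own local path or model ID.
llm = CodegeexChatModel("THUDM/codegeex4-all-9b")

# Blocking call: returns the full response in one piece.
print(llm.chat("Write a Python function that reverses a string."))

# Streaming call: each yielded value is the response generated so far,
# so printing the last one gives the complete answer.
for partial in llm.stream_chat("Explain what this class does."):
    pass
print(partial)
```

Because `stream_chat` yields cumulative responses rather than deltas, a UI would typically redraw the output area with each yielded value instead of appending.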