# codegeex.py
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe: CodeGeeX chat model wrapper exposing OpenAI-style chat-completion APIs.
"""
import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel
  11. class CodegeexChatModel:
  12. def __init__(self, args):
  13. self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
  14. if args.bf16:
  15. self.model = AutoModel.from_pretrained(
  16. args.model_name_or_path,
  17. trust_remote_code=True,
  18. torch_dtype=torch.bfloat16,
  19. ).to(args.device).eval()
  20. else:
  21. self.model = AutoModel.from_pretrained(
  22. args.model_name_or_path,
  23. trust_remote_code=True
  24. ).to(args.device).eval()
  25. print("Model is initialized.")
  26. def stream_chat(self, request: ChatCompletionRequest):
  27. try:
  28. inputs = self.tokenizer.apply_chat_template(
  29. conversation=[msg.model_dump() for msg in request.messages],
  30. add_generation_prompt=True,
  31. return_tensors="pt",
  32. return_dict=True
  33. ).to(self.model.device)
  34. gen_configs = {
  35. "max_new_tokens": request.max_tokens,
  36. "temperature": request.temperature,
  37. "top_p": request.top_p,
  38. "repetition_penalty": request.presence_penalty,
  39. "do_sample": True if request.temperature else request.temperature,
  40. }
  41. length = 0
  42. for i, outputs in enumerate(self.model.stream_generate(**inputs, **gen_configs)):
  43. response = self.tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):-1])
  44. if not response or response[-1] == "�":
  45. continue
  46. resp = ChatCompletionStreamResponse()
  47. resp.choices[0].index = i
  48. resp.choices[0].delta.content = response[length:]
  49. event = Event(id=resp.id, data=resp.json(), event='message')
  50. yield event.dump()
  51. length = len(response)
  52. resp = ChatCompletionStreamResponse()
  53. resp.choices[0].finish_reason = 'stop'
  54. event = Event(id=resp.id, data=resp.json(), event='message')
  55. yield event.dump()
  56. except Exception as e:
  57. resp = ChatCompletionStreamResponse()
  58. resp.choices[0].finish_reason = 'stop'
  59. event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
  60. yield event.dump()
  61. def chat(self, request: ChatCompletionRequest):
  62. try:
  63. response, _ = self.model.chat(
  64. self.tokenizer,
  65. query=request.messages[-1].content,
  66. history=[msg.model_dump() for msg in request.messages[:-1]],
  67. max_new_tokens=request.max_tokens,
  68. temperature=request.temperature,
  69. top_p=request.top_p,
  70. repetition_penalty=request.presence_penalty
  71. )
  72. resp = ChatCompletionResponse()
  73. resp.choices[0].message.content = response
  74. resp.choices[0].finish_reason = 'stop'
  75. return resp.model_dump()
  76. except Exception as e:
  77. return f"请求报错,错误原因:{e}"