@@ -29,16 +29,21 @@ class CodegeexChatModel:
     def stream_chat(self, request: ChatCompletionRequest):
         try:
+            inputs = self.tokenizer.apply_chat_template(
+                conversation=[msg.model_dump() for msg in request.messages],
+                add_generation_prompt=True,
+                return_tensors="pt",
+                return_dict=True
+            ).to(self.model.device)
+            gen_configs = {
+                "max_new_tokens": request.max_tokens,
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "repetition_penalty": request.presence_penalty
+            }
             length = 0
-            for i, (response, _) in enumerate(self.model.stream_chat(
-                self.tokenizer,
-                query=request.messages[-1].content,
-                history=[msg.model_dump() for msg in request.messages[:-1]],
-                max_new_tokens=request.max_tokens,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                repetition_penalty=request.presence_penalty
-            )):
+            for i, outputs in enumerate(self.model.stream_generate(**inputs, **gen_configs)):
+                response = self.tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):-1])
                 resp = ChatCompletionStreamResponse()
                 resp.choices[0].index = i
                 resp.choices[0].delta.content = response[length:]
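For context on the new code path: instead of the legacy ChatGLM-style `model.stream_chat(tokenizer, query, history, ...)` call, the full message list is now rendered through the tokenizer's chat template and unpacked into `stream_generate`. Below is a minimal sketch of that prompt-building step in isolation; the checkpoint name is a placeholder assumption, and `stream_generate` itself comes from the model's `trust_remote_code` implementation rather than stock `transformers`.

```python
# Minimal sketch of the new prompt path, assuming a transformers tokenizer
# that ships a chat template. The checkpoint name is an assumption here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/codegeex4-all-9b", trust_remote_code=True
)

messages = [{"role": "user", "content": "Write a hello-world in Python."}]

# return_dict=True yields a BatchEncoding (input_ids + attention_mask),
# so the result can be unpacked into the model call as **inputs.
inputs = tokenizer.apply_chat_template(
    conversation=messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)

# Each tensor yielded by the model's stream_generate holds the whole
# sequence generated so far. Slicing off the prompt tokens (and the final
# token, presumably EOS, hence the -1) recovers only the new text:
prompt_len = len(inputs["input_ids"][0])
# response = tokenizer.decode(outputs.tolist()[0][prompt_len:-1])
```

One behavior carried over from the old call: the request's `presence_penalty` is still mapped onto `repetition_penalty`, which is a related but not identical sampling control (OpenAI's presence penalty is additive, `repetition_penalty` is multiplicative).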