
Fixed bugs in the code, optimized the stream-chat feature, raised the maximum length limit, and added error handling.

XingYu-Zhong 1 year ago
Parent
Commit
9ccf7b7581
2 changed files with 9 additions and 13 deletions
  1. repodemo/llm/local/codegeex4.py (+3 −11)
  2. repodemo/run_local.py (+6 −2)

+ 3 - 11
repodemo/llm/local/codegeex4.py

@@ -2,14 +2,7 @@ from pydantic import Field
 from transformers import AutoModel, AutoTokenizer
 from typing import Iterator
 import torch
-class StreamProcessor:
-    def __init__(self):
-        self.previous_str = ""
 
-    def get_new_part(self, new_str):
-        new_part = new_str[len(self.previous_str):]
-        self.previous_str = new_str
-        return new_part
 class CodegeexChatModel():
     device: str = Field(description="device to load the model")
     tokenizer = Field(description="model's tokenizer")
@@ -31,7 +24,7 @@ class CodegeexChatModel():
             response, _ = self.model.chat(
                 self.tokenizer,
                 query=prompt,
-                max_length=4012,
+                max_length=120000,
                 temperature=temperature,
                 top_p=top_p
             )
@@ -42,14 +35,13 @@ class CodegeexChatModel():
     def stream_chat(self,prompt,temperature=0.2,top_p=0.95):
 
         try:
-            stream_processor = StreamProcessor()
             for response, _ in self.model.stream_chat(
                     self.tokenizer,
                     query=prompt,
-                    max_length=4012,
+                    max_length=120000,
                     temperature=temperature,
                     top_p=top_p
             ):
-                yield stream_processor.get_new_part(response)
+                yield response
         except Exception as e:
             yield f'error: {e}'
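
The stream_chat of ChatGLM-style models yields the cumulative response generated so far, not just the newly produced tokens, so after this change each value the generator yields repeats all previously streamed text; the delta extraction that StreamProcessor used to perform here moves to the caller (see the run_local.py hunk below). A minimal sketch of consuming the cumulative stream under that assumption; the prompt string is illustrative:

previous = ""
for chunk in llm.stream_chat("explain this repository", temperature=0.2, top_p=0.95):
    new_part = chunk[len(previous):]  # suffix added since the last yield
    previous = chunk
    print(new_part, end="", flush=True)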

+ 6 - 2
repodemo/run_local.py

@@ -5,7 +5,6 @@ from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_j
 from utils.tools import unzip_file,get_project_files_with_content
 from utils.bingsearch import bing_search_prompt
 from llm.local.codegeex4 import CodegeexChatModel
-
 local_model_path = '<your_local_model_path>'
 llm = CodegeexChatModel(local_model_path)
 
@@ -155,8 +154,13 @@ async def main(message: cl.Message):
     
     if len(prompt_content)/4<120000:
         stream =  llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
+        stream_processor = StreamProcessor()
         for part in stream:
-            if token := (part or " "):
+            if isinstance(part, str):
+                text = stream_processor.get_new_part(part)
+            elif isinstance(part, dict):
+                text = stream_processor.get_new_part(part['name']+part['content'])
+            if token := (text or " "):
                 await msg.stream_token(token)
     else:
         await msg.stream_token("项目太大了,请换小一点的项目。")
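
Note that the new loop instantiates StreamProcessor, but this commit deletes that class from codegeex4.py, and the import line only brings in CodegeexChatModel, so the class must now be defined in (or imported into) run_local.py for this code to run. A minimal sketch of the missing definition, reusing the body removed above, together with a guarded version of the loop; the final else branch is an assumption added to keep `text` from being unbound on an unexpected chunk type:

class StreamProcessor:
    def __init__(self):
        self.previous_str = ""

    def get_new_part(self, new_str):
        # Return only the suffix added since the previous cumulative string.
        new_part = new_str[len(self.previous_str):]
        self.previous_str = new_str
        return new_part

# Inside the async message handler:
stream_processor = StreamProcessor()
for part in stream:
    if isinstance(part, str):
        text = stream_processor.get_new_part(part)
    elif isinstance(part, dict):
        text = stream_processor.get_new_part(part['name'] + part['content'])
    else:
        text = ""  # assumption: skip unexpected chunk types instead of raising NameError
    if token := (text or " "):
        await msg.stream_token(token)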