run_local.py

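# Chainlit demo app for serving CodeGeeX4 from a local checkpoint. It exposes
# three chat profiles: plain chat, web-search Q&A, and zip-upload project Q&A.
# Launch with the Chainlit CLI: chainlit run run_local.py
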
import chainlit as cl
from chainlit.input_widget import Slider

from llm.local.codegeex4 import CodegeexChatModel
from prompts.base_prompt import (
    judge_task_prompt,
    get_cur_base_user_prompt,
    web_judge_task_prompt,
)
from utils.bingsearch import bing_search_prompt
from utils.tools import unzip_file, get_project_files_with_content

local_model_path = "<your_local_model_path>"
llm = CodegeexChatModel(local_model_path)
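
# Assumed interface of CodegeexChatModel, inferred from how `llm` is used below:
#   llm.chat(prompt, temperature=...) -> str with the complete response
#   llm.stream_chat(prompt, temperature=..., top_p=...) -> iterable of cumulative
#   outputs, each either a str or a dict with "name" and "content" keys.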


class StreamProcessor:
    """Turns a stream of cumulative strings into incremental deltas."""

    def __init__(self):
        self.previous_str = ""

    def get_new_part(self, new_str):
        # Everything after the previously seen prefix is new.
        new_part = new_str[len(self.previous_str):]
        self.previous_str = new_str
        return new_part
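
# Example of the delta logic (hypothetical chunks): if stream_chat yields the
# cumulative strings "He", "Hell", "Hello", get_new_part returns "He", "ll",
# "o", so each piece of new text reaches the UI exactly once.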


@cl.set_chat_profiles
async def chat_profile():
    return [
        cl.ChatProfile(
            name="Chat",
            markdown_description="Chat demo: supports multi-turn conversation.",
            starters=[
                cl.Starter(
                    label="Write a quicksort in Python.",
                    message="Write a quicksort in Python.",
                ),
                cl.Starter(
                    label="Please introduce yourself.",
                    message="Please introduce yourself.",
                ),
                cl.Starter(
                    label="Write a Python script to send a daily email report automatically, and guide me through setting it up.",
                    message="Write a Python script to send a daily email report automatically, and guide me through setting it up.",
                ),
                cl.Starter(
                    label="I am a Python beginner; how can I learn Python well?",
                    message="I am a Python beginner; how can I learn Python well?",
                ),
            ],
        ),
        cl.ChatProfile(
            name="Web Q&A",
            markdown_description="Web-search demo: answers user questions with live web results.",
        ),
        cl.ChatProfile(
            name="Upload local project",
            markdown_description="Repository-level demo: upload a local project as a zip archive, then ask questions about it or request changes to it.",
        ),
    ]
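
# Note: profiles are matched by their exact name strings in the handlers below,
# so these names must stay in sync with the comparisons in start() and main().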


@cl.on_chat_start
async def start():
    settings = await cl.ChatSettings(
        [
            Slider(
                id="temperature",
                label="CodeGeeX4 - Temperature",
                initial=0.2,
                min=0,
                max=1,
                step=0.1,
            ),
            Slider(
                id="top_p",
                label="CodeGeeX4 - top_p",
                initial=0.95,
                min=0,
                max=1,
                step=0.1,
            ),
        ]
    ).send()
    temperature = settings["temperature"]
    top_p = settings["top_p"]
    cl.user_session.set("temperature", temperature)
    cl.user_session.set("top_p", top_p)
    cl.user_session.set("message_history", [])

    chat_profile = cl.user_session.get("chat_profile")
    extract_dir = "repodata"
    if chat_profile == "Chat":
        pass
    elif chat_profile == "Upload local project":
        files = None
        while files is None:
            files = await cl.AskFileMessage(
                content="Please upload your project as a .zip archive!",
                accept={"application/zip": [".zip"]},
                max_size_mb=50,
            ).send()
        text_file = files[0]
        extracted_path = unzip_file(text_file.path, extract_dir)
        files_list = get_project_files_with_content(extracted_path)
        cl.user_session.set("project_index", files_list)
        if len(files_list) > 0:
            await cl.Message(
                content="Upload successful. You can now ask questions about the project!",
            ).send()
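
# Assumed contracts for the helpers used above, inferred from this file:
#   unzip_file(zip_path, extract_dir) -> path to the extracted project root
#   get_project_files_with_content(root) -> list of {"path": str, "content": str}
#   bing_search_prompt(query) -> prompt string with web search results inlined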


@cl.on_message
async def main(message: cl.Message):
    chat_profile = cl.user_session.get("chat_profile")
    message_history = cl.user_session.get("message_history")
    message_history.append({"role": "user", "content": message.content})

    if chat_profile == "Chat":
        prompt_content = get_cur_base_user_prompt(message_history=message_history)
    elif chat_profile == "Web Q&A":
        # Ask the model whether the question needs a web search; the judge
        # prompt is expected to answer in Chinese ("是" means "yes").
        judge_context = llm.chat(
            web_judge_task_prompt.format(user_input=message.content), temperature=0.2
        )
        print(judge_context)
        message_history.pop()
        if "是" in judge_context:
            prompt_tmp = bing_search_prompt(message.content)
            message_history.append({"role": "user", "content": prompt_tmp})
        else:
            message_history.append({"role": "user", "content": message.content})
        prompt_content = get_cur_base_user_prompt(message_history=message_history)
    elif chat_profile == "Upload local project":
        # Classify the request; "正常" ("ordinary") means the question can be
        # answered without attaching the project index to the prompt.
        judge_context = llm.chat(
            judge_task_prompt.format(user_input=message.content), temperature=0.2
        )
        project_index = cl.user_session.get("project_index")
        index_prompt = ""
        index_tmp = """###PATH:{path}\n{code}\n"""
        for index in project_index:
            index_prompt += index_tmp.format(path=index["path"], code=index["content"])
        print(judge_context)
        prompt_content = (
            get_cur_base_user_prompt(
                message_history=message_history,
                index_prompt=index_prompt,
                judge_context=judge_context,
            )
            if "正常" not in judge_context
            else get_cur_base_user_prompt(message_history=message_history)
        )

    msg = cl.Message(content="")
    await msg.send()

    temperature = cl.user_session.get("temperature")
    top_p = cl.user_session.get("top_p")
    # Rough size guard: ~4 characters per token against a 120k-token context window.
    if len(prompt_content) / 4 < 120000:
        stream = llm.stream_chat(prompt_content, temperature=temperature, top_p=top_p)
        stream_processor = StreamProcessor()
        for part in stream:
            if isinstance(part, str):
                text = stream_processor.get_new_part(part)
            elif isinstance(part, dict):
                text = stream_processor.get_new_part(part["name"] + part["content"])
            else:
                continue
            # Stream at least a space so the UI keeps updating on empty deltas.
            await msg.stream_token(text or " ")
    else:
        await msg.stream_token("The project is too large; please try a smaller one.")
    message_history.append({"role": "assistant", "content": msg.content})
    await msg.update()