run.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import chainlit as cl
  2. from chainlit.input_widget import Slider,Switch
  3. import json
  4. import re
  5. from llm.api.codegeex4 import codegeex4
  6. from prompts.base_prompt import (
  7. get_cur_base_user_prompt,
  8. build_message_list,
  9. tools_choose_prompt,
  10. tools_input_prompt
  11. )
  12. from utils.bingsearch import bing_search_prompt
  13. from utils.tools import unzip_file, get_project_files_with_content,clone_repo,is_valid_json
  14. def tools_choose_agent(input_text):
  15. tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
  16. message_list = build_message_list(tools_prompt)
  17. judge_tmp = codegeex4(
  18. messages_list=message_list,
  19. temperature=0.2,
  20. top_p=0.95,
  21. )
  22. judge_context = ""
  23. for part in judge_tmp:
  24. judge_context += part
  25. attempt = 1
  26. max_attempts = 10
  27. while not is_valid_json(judge_context) and attempt <= max_attempts:
  28. judge_tmp = codegeex4(
  29. messages_list=message_list,
  30. temperature=0.2,
  31. top_p=0.95,
  32. )
  33. judge_context = ""
  34. for part in judge_tmp:
  35. judge_context += part
  36. attempt += 1
  37. match = re.search(r'\{.*\}', judge_context, re.DOTALL)
  38. if match:
  39. dict_str = match.group()
  40. response = json.loads(dict_str)
  41. else:
  42. response = json.loads(judge_context)
  43. tool_name = response["tool"]["name"]
  44. return tool_name
  45. @cl.set_chat_profiles
  46. async def chat_profile():
  47. return [
  48. cl.ChatProfile(
  49. name="联网聊天",
  50. markdown_description="聊天demo:支持多轮对话。支持联网回答用户问题。默认联网,如不联网在输入框左边关闭联网功能。",
  51. starters=[
  52. cl.Starter(
  53. label="请你用python写一个快速排序。",
  54. message="请你用python写一个快速排序。",
  55. ),
  56. cl.Starter(
  57. label="请你介绍一下自己。",
  58. message="请你介绍一下自己。",
  59. ),
  60. cl.Starter(
  61. label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  62. message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  63. ),
  64. cl.Starter(
  65. label="我是一个python初学者,请你告诉我怎么才能学好python。",
  66. message="我是一个python初学者,请你告诉我怎么才能学好python。",
  67. ),
  68. ],
  69. ),
  70. cl.ChatProfile(
  71. name="项目问答",
  72. markdown_description="项目级能力demo:支持上传本地zip压缩包项目,支持输入GitHub链接项目,可以进行项目问答和对项目进行修改。",
  73. ),
  74. ]
  75. @cl.on_settings_update
  76. async def setup_agent(settings):
  77. temperature = settings["temperature"]
  78. top_p = settings["top_p"]
  79. is_online = settings["is_online"]
  80. cl.user_session.set("temperature", temperature)
  81. cl.user_session.set("top_p", top_p)
  82. cl.user_session.set("is_online", is_online)
  83. @cl.on_chat_start
  84. async def start():
  85. settings = await cl.ChatSettings(
  86. [
  87. Slider(
  88. id="temperature",
  89. label="CodeGeeX4 - Temperature",
  90. initial=0.2,
  91. min=0,
  92. max=1,
  93. step=0.1,
  94. ),
  95. Slider(
  96. id="top_p",
  97. label="CodeGeeX4 - top_p",
  98. initial=0.95,
  99. min=0,
  100. max=1,
  101. step=0.1,
  102. ),
  103. Switch(
  104. id="is_online",
  105. label="CodeGeeX4 - is_online",
  106. initial=True
  107. ),
  108. ]
  109. ).send()
  110. temperature = settings["temperature"]
  111. top_p = settings["top_p"]
  112. is_online = settings["is_online"]
  113. cl.user_session.set("temperature", temperature)
  114. cl.user_session.set("top_p", top_p)
  115. cl.user_session.set("is_online", is_online)
  116. cl.user_session.set("message_history", [])
  117. chat_profile = cl.user_session.get("chat_profile")
  118. extract_dir = "repodata"
  119. if chat_profile == "项目问答":
  120. res = await cl.AskActionMessage(
  121. content="请选择项目上传方式",
  122. actions=[
  123. cl.Action(name="zip", value="zip", label="本地上传zip文件"),
  124. cl.Action(name="url", value="url", label="上传GitHub链接"),
  125. ],
  126. ).send()
  127. if res.get("value") == "url":
  128. repo_path =None
  129. while repo_path == None:
  130. res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX4", timeout=3600).send()
  131. if res:
  132. repo_path = clone_repo(res['output'],extract_dir)
  133. if repo_path is None:
  134. await cl.Message(
  135. content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
  136. ).send()
  137. files_list = get_project_files_with_content(repo_path)
  138. cl.user_session.set("project_index", files_list)
  139. if len(files_list) > 0:
  140. await cl.Message(
  141. content=f"已成功上传,您可以开始对项目进行提问!",
  142. ).send()
  143. elif res.get("value") == "zip":
  144. files = None
  145. while files == None:
  146. files = await cl.AskFileMessage(
  147. content="请上传项目zip压缩文件!",
  148. accept={"application/zip": [".zip"]},
  149. max_size_mb=50,
  150. ).send()
  151. text_file = files[0]
  152. extracted_path = unzip_file(text_file.path, extract_dir)
  153. files_list = get_project_files_with_content(extracted_path)
  154. cl.user_session.set("project_index", files_list)
  155. if len(files_list) > 0:
  156. await cl.Message(
  157. content=f"已成功上传,您可以开始对项目进行提问!",
  158. ).send()
  159. @cl.step(type="tool")
  160. async def bing_search_tool(search_text):
  161. current_step = cl.context.current_step
  162. # Simulate a running task
  163. current_step.input = search_text
  164. prompt_tmp = bing_search_prompt(search_text)
  165. current_step.output = prompt_tmp
  166. return prompt_tmp
  167. @cl.on_message
  168. async def main(message: cl.Message):
  169. chat_profile = cl.user_session.get("chat_profile")
  170. message_history = cl.user_session.get("message_history")
  171. tool_name = tools_choose_agent(message.content)
  172. is_online = cl.user_session.get("is_online")
  173. if chat_profile == "联网聊天":
  174. if "online_query" in tool_name and is_online:
  175. prompt_tmp = await bing_search_tool(message.content)
  176. message_history.append({"role": "tool", "content": prompt_tmp})
  177. message_history.append({"role": "user", "content": message.content})
  178. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  179. elif chat_profile == "项目问答":
  180. message_history.append({"role": "user", "content": message.content})
  181. project_index = cl.user_session.get("project_index")
  182. index_prompt = ""
  183. index_tmp = """###PATH:{path}\n{code}\n"""
  184. for index in project_index:
  185. index_prompt += index_tmp.format(path=index["path"], code=index["content"])
  186. if len(tool_name)>0:
  187. prompt_content = get_cur_base_user_prompt(
  188. message_history=message_history,
  189. index_prompt=index_prompt,
  190. judge_context=tool_name[0],
  191. )
  192. else:
  193. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  194. msg = cl.Message(content="")
  195. await msg.send()
  196. temperature = cl.user_session.get("temperature")
  197. top_p = cl.user_session.get("top_p")
  198. if len(prompt_content) / 4 < 120000:
  199. stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
  200. for part in stream:
  201. if token := (part or " "):
  202. await msg.stream_token(token)
  203. else:
  204. await msg.stream_token("项目太大了,请换小一点的项目。")
  205. message_history.append({"role": "assistant", "content": msg.content})
  206. await msg.update()