# run.py — Chainlit demo application (chat + project Q&A) for CodeGeeX4.
import json
import os
import re

import chainlit as cl
from chainlit.input_widget import Slider, Switch

from llm.api.codegeex4 import codegeex4
from prompts.base_prompt import (
    get_cur_base_user_prompt,
    build_message_list,
    tools_choose_prompt,
    tools_input_prompt,
    file_summary_prompt,
    project_mermaid_prompt,
)
from utils.bingsearch import bing_search_prompt
from utils.tools import (
    extract_code_text,
    get_directory_structure,
    get_mermaid_png,
    unzip_file,
    get_project_files_with_content,
    clone_repo,
    is_valid_json,
)
  17. def tools_choose_agent(input_text):
  18. tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
  19. message_list = build_message_list(tools_prompt)
  20. judge_tmp = codegeex4(
  21. messages_list=message_list,
  22. temperature=0.2,
  23. top_p=0.95,
  24. )
  25. judge_context = ""
  26. for part in judge_tmp:
  27. judge_context += part
  28. attempt = 1
  29. max_attempts = 10
  30. while not is_valid_json(judge_context) and attempt <= max_attempts:
  31. judge_tmp = codegeex4(
  32. messages_list=message_list,
  33. temperature=0.2,
  34. top_p=0.95,
  35. )
  36. judge_context = ""
  37. for part in judge_tmp:
  38. judge_context += part
  39. attempt += 1
  40. match = re.search(r'\{.*\}', judge_context, re.DOTALL)
  41. if match:
  42. dict_str = match.group()
  43. response = json.loads(dict_str)
  44. else:
  45. response = json.loads(judge_context)
  46. tool_name = response["tool"]["name"]
  47. return tool_name
  48. @cl.set_chat_profiles
  49. async def chat_profile():
  50. return [
  51. cl.ChatProfile(
  52. name="chat聊天",
  53. markdown_description="聊天demo:支持多轮对话。支持联网回答用户问题(需要在输入框左边打开联网开关)。默认联网,如不联网在输入框左边关闭联网功能。",
  54. starters=[
  55. cl.Starter(
  56. label="请你用python写一个快速排序。",
  57. message="请你用python写一个快速排序。",
  58. ),
  59. cl.Starter(
  60. label="请你介绍一下自己。",
  61. message="请你介绍一下自己。",
  62. ),
  63. cl.Starter(
  64. label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  65. message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  66. ),
  67. cl.Starter(
  68. label="我是一个python初学者,请你告诉我怎么才能学好python。",
  69. message="我是一个python初学者,请你告诉我怎么才能学好python。",
  70. ),
  71. ],
  72. ),
  73. cl.ChatProfile(
  74. name="项目问答",
  75. markdown_description="项目级能力demo:支持上传本地zip压缩包项目,支持输入GitHub链接项目,可以进行项目问答和对项目进行修改。",
  76. ),
  77. ]
  78. @cl.on_settings_update
  79. async def setup_agent(settings):
  80. temperature = settings["temperature"]
  81. top_p = settings["top_p"]
  82. is_online = settings["is_online"]
  83. cl.user_session.set("temperature", temperature)
  84. cl.user_session.set("top_p", top_p)
  85. cl.user_session.set("is_online", is_online)
  86. @cl.on_chat_start
  87. async def start():
  88. settings = await cl.ChatSettings(
  89. [
  90. Slider(
  91. id="temperature",
  92. label="CodeGeeX4 - Temperature",
  93. initial=0.2,
  94. min=0,
  95. max=1,
  96. step=0.1,
  97. ),
  98. Slider(
  99. id="top_p",
  100. label="CodeGeeX4 - top_p",
  101. initial=0.95,
  102. min=0,
  103. max=1,
  104. step=0.1,
  105. ),
  106. Switch(
  107. id="is_online",
  108. label="CodeGeeX4 - is_online",
  109. initial=False
  110. ),
  111. ]
  112. ).send()
  113. temperature = settings["temperature"]
  114. top_p = settings["top_p"]
  115. is_online = settings["is_online"]
  116. cl.user_session.set("temperature", temperature)
  117. cl.user_session.set("top_p", top_p)
  118. cl.user_session.set("is_online", is_online)
  119. cl.user_session.set("message_history", [])
  120. chat_profile = cl.user_session.get("chat_profile")
  121. extract_dir = "repodata"
  122. if chat_profile == "项目问答":
  123. res = await cl.AskActionMessage(
  124. content="请选择项目上传方式",
  125. actions=[
  126. cl.Action(name="zip", value="zip", label="本地上传zip文件"),
  127. cl.Action(name="url", value="url", label="上传GitHub链接"),
  128. ],
  129. ).send()
  130. if res.get("value") == "url":
  131. repo_path =None
  132. while repo_path == None:
  133. res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX2", timeout=3600).send()
  134. if res:
  135. await cl.Message(content="").send()
  136. repo_path = clone_repo(res['output'],extract_dir)
  137. if repo_path is None:
  138. await cl.Message(
  139. content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
  140. ).send()
  141. elif res.get("value") == "zip":
  142. files = None
  143. while files == None:
  144. files = await cl.AskFileMessage(
  145. content="请上传项目zip压缩文件!",
  146. accept={"application/zip": [".zip"]},
  147. max_size_mb=50,
  148. ).send()
  149. await cl.Message(content="").send()
  150. text_file = files[0]
  151. repo_path = unzip_file(text_file.path, extract_dir)
  152. files_list = get_project_files_with_content(repo_path)
  153. cl.user_session.set("project_index", files_list)
  154. if len(files_list) > 0:
  155. structure_str = await directory_structure(repo_path)
  156. #TODO: 项目文件太多,需要分批处理,这里暂时只取前5个文件
  157. top_files_list = sorted(files_list, key=lambda x: len(x["path"]))[:5]
  158. index_prompt = ""
  159. index_tmp = """###PATH:{path}\n{code}\n"""
  160. for index in top_files_list:
  161. index_prompt += index_tmp.format(path=index["path"], code=index["content"])
  162. file_summary_message_history=[{"role": "user", "content": index_prompt+'\n'+file_summary_prompt}]
  163. file_summary_prompt_content = get_cur_base_user_prompt(message_history=file_summary_message_history)
  164. file_summary = codegeex4(file_summary_prompt_content, temperature=temperature, top_p=top_p)
  165. # index_prompt = ""
  166. # index_tmp = """###PATH:{path}\n{code}\n"""
  167. # for index in files_list:
  168. # index_prompt += index_tmp.format(path=index["path"], code=index["content"])
  169. message_history=[{"role": "user", "content": structure_str+'\n'+project_mermaid_prompt}]
  170. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  171. if len(prompt_content) / 4 < 120000:
  172. mermaid_structure = repo_path+"/structure.png"
  173. is_bad = False
  174. if not os.path.exists(mermaid_structure):
  175. result_mermaid = ""
  176. retry_attempts = 2
  177. for _ in range(retry_attempts):
  178. stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
  179. print(stream)
  180. for part in stream:
  181. if token := (part or " "):
  182. result_mermaid+=token
  183. result_mermaid = extract_code_text(result_mermaid, "mermaid")
  184. if result_mermaid!="":
  185. break
  186. if result_mermaid:
  187. get_mermaid_png(result_mermaid,mermaid_structure)
  188. else:
  189. is_bad = True
  190. if is_bad:
  191. await cl.Message(f"架构图生成失败。但不影响后续项目问答。\n{file_summary}\n- ......").send()
  192. else:
  193. img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
  194. await cl.Message(
  195. content=f"已成功上传,这是项目的目录架构图,您可以开始对项目进行提问!\n{file_summary}\n- ......",
  196. elements=[img_mermard_structure],
  197. ).send()
  198. else:
  199. await cl.Message("项目太大了,请换小一点的项目。").send()
  200. cl.user_session.set("project_message_history", [])
  201. @cl.step(type="tool")
  202. async def directory_structure(repo_path):
  203. current_step = cl.context.current_step
  204. current_step.input = "加载项目目录"
  205. structure_str = get_directory_structure(repo_path)
  206. current_step.output = "```shell\n"+structure_str+"\n```"
  207. return structure_str
  208. @cl.step(type="tool")
  209. async def bing_search_tool(search_text):
  210. current_step = cl.context.current_step
  211. # Simulate a running task
  212. current_step.input = search_text
  213. prompt_tmp = bing_search_prompt(search_text)
  214. current_step.output = prompt_tmp
  215. return prompt_tmp
  216. @cl.on_message
  217. async def main(message: cl.Message):
  218. chat_profile = cl.user_session.get("chat_profile")
  219. message_history = cl.user_session.get("message_history")
  220. tool_name = tools_choose_agent(message.content)
  221. print(tool_name)
  222. is_online = cl.user_session.get("is_online")
  223. if chat_profile == "chat聊天":
  224. if "online_query" in tool_name and is_online:
  225. prompt_tmp = await bing_search_tool(message.content)
  226. message_history.append({"role": "tool", "content": prompt_tmp})
  227. message_history.append({"role": "user", "content": message.content})
  228. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  229. elif chat_profile == "项目问答":
  230. message_history.append({"role": "user", "content": message.content})
  231. project_index = cl.user_session.get("project_index")
  232. index_prompt = ""
  233. index_tmp = """###PATH:{path}\n{code}\n"""
  234. for index in project_index:
  235. index_prompt += index_tmp.format(path=index["path"], code=index["content"])
  236. if len(tool_name)>0:
  237. prompt_content = get_cur_base_user_prompt(
  238. message_history=message_history,
  239. index_prompt=index_prompt,
  240. )
  241. else:
  242. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  243. else:
  244. message_history.append({"role": "user", "content": message.content})
  245. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  246. msg = cl.Message(content="")
  247. await msg.send()
  248. temperature = cl.user_session.get("temperature")
  249. top_p = cl.user_session.get("top_p")
  250. if len(prompt_content) / 4 < 120000:
  251. stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
  252. for part in stream:
  253. if token := (part or " "):
  254. await msg.stream_token(token)
  255. else:
  256. await msg.stream_token("项目太大了,请换小一点的项目。")
  257. message_history.append({"role": "assistant", "content": msg.content})
  258. await msg.update()