run.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import chainlit as cl
  2. from chainlit.input_widget import Slider
  3. from llm.api.codegeex4 import codegeex4
  4. from prompts.base_prompt import (
  5. judge_task_prompt,
  6. get_cur_base_user_prompt,
  7. web_judge_task_prompt,
  8. )
  9. from utils.bingsearch import bing_search_prompt
  10. from utils.tools import unzip_file, get_project_files_with_content
  11. @cl.set_chat_profiles
  12. async def chat_profile():
  13. return [
  14. cl.ChatProfile(
  15. name="chat聊天",
  16. markdown_description="聊天demo:支持多轮对话。",
  17. starters=[
  18. cl.Starter(
  19. label="请你用python写一个快速排序。",
  20. message="请你用python写一个快速排序。",
  21. ),
  22. cl.Starter(
  23. label="请你介绍一下自己。",
  24. message="请你介绍一下自己。",
  25. ),
  26. cl.Starter(
  27. label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  28. message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  29. ),
  30. cl.Starter(
  31. label="我是一个python初学者,请你告诉我怎么才能学好python。",
  32. message="我是一个python初学者,请你告诉我怎么才能学好python。",
  33. ),
  34. ],
  35. ),
  36. cl.ChatProfile(
  37. name="联网问答",
  38. markdown_description="联网能力demo:支持联网回答用户问题。",
  39. ),
  40. cl.ChatProfile(
  41. name="上传本地项目",
  42. markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
  43. ),
  44. ]
  45. @cl.on_chat_start
  46. async def start():
  47. settings = await cl.ChatSettings(
  48. [
  49. Slider(
  50. id="temperature",
  51. label="CodeGeeX4 - Temperature",
  52. initial=0.2,
  53. min=0,
  54. max=1,
  55. step=0.1,
  56. ),
  57. Slider(
  58. id="top_p",
  59. label="CodeGeeX4 - top_p",
  60. initial=0.95,
  61. min=0,
  62. max=1,
  63. step=0.1,
  64. ),
  65. ]
  66. ).send()
  67. temperature = settings["temperature"]
  68. top_p = settings["top_p"]
  69. cl.user_session.set("temperature", temperature)
  70. cl.user_session.set("top_p", top_p)
  71. cl.user_session.set("message_history", [])
  72. chat_profile = cl.user_session.get("chat_profile")
  73. extract_dir = "repodata"
  74. if chat_profile == "chat聊天":
  75. pass
  76. elif chat_profile == "上传本地项目":
  77. files = None
  78. while files == None:
  79. files = await cl.AskFileMessage(
  80. content="请上传项目zip压缩文件!",
  81. accept={"application/zip": [".zip"]},
  82. max_size_mb=50,
  83. ).send()
  84. text_file = files[0]
  85. extracted_path = unzip_file(text_file.path, extract_dir)
  86. files_list = get_project_files_with_content(extracted_path)
  87. cl.user_session.set("project_index", files_list)
  88. if len(files_list) > 0:
  89. await cl.Message(
  90. content=f"已成功上传,您可以开始对项目进行提问!",
  91. ).send()
  92. @cl.on_message
  93. async def main(message: cl.Message):
  94. chat_profile = cl.user_session.get("chat_profile")
  95. message_history = cl.user_session.get("message_history")
  96. message_history.append({"role": "user", "content": message.content})
  97. if chat_profile == "chat聊天":
  98. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  99. elif chat_profile == "联网问答":
  100. judge_tmp = codegeex4(
  101. web_judge_task_prompt.format(user_input=message.content),
  102. temperature=0.2,
  103. top_p=0.95,
  104. )
  105. judge_context = "\n".join(judge_tmp)
  106. print(judge_context)
  107. message_history.pop()
  108. if "是" in judge_context:
  109. prompt_tmp = bing_search_prompt(message.content)
  110. message_history.append({"role": "user", "content": prompt_tmp})
  111. else:
  112. message_history.append({"role": "user", "content": message.content})
  113. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  114. elif chat_profile == "上传本地项目":
  115. judge_tmp = codegeex4(
  116. judge_task_prompt.format(user_input=message.content),
  117. temperature=0.2,
  118. top_p=0.95,
  119. )
  120. judge_context = ""
  121. for part in judge_tmp:
  122. judge_context += part
  123. project_index = cl.user_session.get("project_index")
  124. index_prompt = ""
  125. index_tmp = """###PATH:{path}\n{code}\n"""
  126. for index in project_index:
  127. index_prompt += index_tmp.format(path=index["path"], code=index["content"])
  128. print(judge_context)
  129. prompt_content = (
  130. get_cur_base_user_prompt(
  131. message_history=message_history,
  132. index_prompt=index_prompt,
  133. judge_context=judge_context,
  134. )
  135. if "正常" not in judge_context
  136. else get_cur_base_user_prompt(message_history=message_history)
  137. )
  138. msg = cl.Message(content="")
  139. await msg.send()
  140. temperature = cl.user_session.get("temperature")
  141. top_p = cl.user_session.get("top_p")
  142. if len(prompt_content) / 4 < 120000:
  143. stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
  144. for part in stream:
  145. if token := (part or " "):
  146. await msg.stream_token(token)
  147. else:
  148. await msg.stream_token("项目太大了,请换小一点的项目。")
  149. message_history.append({"role": "assistant", "content": msg.content})
  150. await msg.update()