# run_local.py — Chainlit demo app serving CodeGeeX4 from a locally hosted model.
  1. import chainlit as cl
  2. from chainlit.input_widget import Slider
  3. from llm.api.codegeex4 import codegeex4
  4. from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt
  5. from utils.tools import unzip_file,get_project_files_with_content
  6. from utils.bingsearch import bing_search_prompt
  7. from llm.local.codegeex4 import CodegeexChatModel
  8. local_model_path = '<your_local_model_path>'
  9. llm = CodegeexChatModel(local_model_path)
  10. class StreamProcessor:
  11. def __init__(self):
  12. self.previous_str = ""
  13. def get_new_part(self, new_str):
  14. new_part = new_str[len(self.previous_str):]
  15. self.previous_str = new_str
  16. return new_part
  17. @cl.set_chat_profiles
  18. async def chat_profile():
  19. return [
  20. cl.ChatProfile(
  21. name="chat聊天",
  22. markdown_description="聊天demo:支持多轮对话。",
  23. starters = [
  24. cl.Starter(
  25. label="请你用python写一个快速排序。",
  26. message="请你用python写一个快速排序。",
  27. ),
  28. cl.Starter(
  29. label="请你介绍一下自己。",
  30. message="请你介绍一下自己。",
  31. ),
  32. cl.Starter(
  33. label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  34. message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
  35. ),
  36. cl.Starter(
  37. label="我是一个python初学者,请你告诉我怎么才能学好python。",
  38. message="我是一个python初学者,请你告诉我怎么才能学好python。",
  39. )
  40. ]
  41. ),
  42. cl.ChatProfile(
  43. name="联网问答",
  44. markdown_description="联网能力dome:支持联网回答用户问题。",
  45. ),
  46. cl.ChatProfile(
  47. name="上传本地项目",
  48. markdown_description="项目级能力dome:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
  49. )
  50. ]
  51. @cl.on_chat_start
  52. async def start():
  53. settings = await cl.ChatSettings(
  54. [
  55. Slider(
  56. id="temperature",
  57. label="CodeGeeX4 - Temperature",
  58. initial=0.2,
  59. min=0,
  60. max=1,
  61. step=0.1,
  62. ),
  63. Slider(
  64. id="top_p",
  65. label="CodeGeeX4 - top_p",
  66. initial=0.95,
  67. min=0,
  68. max=1,
  69. step=0.1,
  70. ),
  71. ]
  72. ).send()
  73. temperature = settings["temperature"]
  74. top_p = settings["top_p"]
  75. cl.user_session.set('temperature',temperature)
  76. cl.user_session.set('top_p',top_p)
  77. cl.user_session.set(
  78. "message_history",
  79. []
  80. )
  81. chat_profile = cl.user_session.get("chat_profile")
  82. extract_dir = 'repodata'
  83. if chat_profile == "chat聊天":
  84. pass
  85. elif chat_profile =="上传本地项目":
  86. files = None
  87. while files == None:
  88. files = await cl.AskFileMessage(
  89. content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50
  90. ).send()
  91. text_file = files[0]
  92. extracted_path = unzip_file(text_file.path,extract_dir)
  93. files_list = get_project_files_with_content(extracted_path)
  94. cl.user_session.set("project_index",files_list)
  95. if len(files_list)>0:
  96. await cl.Message(
  97. content=f"已成功上传,您可以开始对项目进行提问!",
  98. ).send()
  99. @cl.on_message
  100. async def main(message: cl.Message):
  101. chat_profile = cl.user_session.get("chat_profile")
  102. message_history = cl.user_session.get("message_history")
  103. message_history.append({"role": "user", "content": message.content})
  104. if chat_profile == "chat聊天":
  105. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  106. elif chat_profile=="联网问答":
  107. judge_context = llm.chat(web_judge_task_prompt.format(user_input=message.content),temperature=0.2)
  108. print(judge_context)
  109. message_history.pop()
  110. if '是' in judge_context:
  111. prompt_tmp = bing_search_prompt(message.content)
  112. message_history.append({"role": "user", "content": prompt_tmp})
  113. else:
  114. message_history.append({"role": "user", "content": message.content})
  115. prompt_content = get_cur_base_user_prompt(message_history=message_history)
  116. elif chat_profile =="上传本地项目" :
  117. judge_context = llm.chat(judge_task_prompt.format(user_input=message.content),temperature=0.2)
  118. project_index = cl.user_session.get("project_index")
  119. index_prompt = ""
  120. index_tmp = """###PATH:{path}\n{code}\n"""
  121. for index in project_index:
  122. index_prompt+=index_tmp.format(path=index['path'],code=index['content'])
  123. print(judge_context)
  124. prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history)
  125. msg = cl.Message(content="")
  126. await msg.send()
  127. temperature = cl.user_session.get("temperature")
  128. top_p = cl.user_session.get('top_p')
  129. if len(prompt_content)/4<120000:
  130. stream = llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
  131. for part in stream:
  132. if token := (part or " "):
  133. await msg.stream_token(token)
  134. else:
  135. await msg.stream_token("项目太大了,请换小一点的项目。")
  136. message_history.append({"role": "assistant", "content": msg.content})
  137. await msg.update()