|
|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
import chainlit as cl
|
|
|
from chainlit.input_widget import Slider,Switch
|
|
|
import json
|
|
|
@@ -8,10 +9,12 @@ from prompts.base_prompt import (
|
|
|
get_cur_base_user_prompt,
|
|
|
build_message_list,
|
|
|
tools_choose_prompt,
|
|
|
- tools_input_prompt
|
|
|
+ tools_input_prompt,
|
|
|
+ file_summary_prompt
|
|
|
)
|
|
|
from utils.bingsearch import bing_search_prompt
|
|
|
-from utils.tools import unzip_file, get_project_files_with_content,clone_repo,is_valid_json
|
|
|
+from utils.tools import extract_code_text, get_directory_structure, get_mermaid_png, unzip_file, get_project_files_with_content,clone_repo,is_valid_json
|
|
|
+from prompts.base_prompt import project_mermaid_prompt
|
|
|
|
|
|
def tools_choose_agent(input_text):
|
|
|
tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
|
|
|
@@ -131,20 +134,14 @@ async def start():
|
|
|
if res.get("value") == "url":
|
|
|
repo_path =None
|
|
|
while repo_path == None:
|
|
|
- res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX4", timeout=3600).send()
|
|
|
+ res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX2", timeout=3600).send()
|
|
|
if res:
|
|
|
+ await cl.Message(content="").send()
|
|
|
repo_path = clone_repo(res['output'],extract_dir)
|
|
|
if repo_path is None:
|
|
|
await cl.Message(
|
|
|
content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
|
|
|
).send()
|
|
|
-
|
|
|
- files_list = get_project_files_with_content(repo_path)
|
|
|
- cl.user_session.set("project_index", files_list)
|
|
|
- if len(files_list) > 0:
|
|
|
- await cl.Message(
|
|
|
- content=f"已成功上传,您可以开始对项目进行提问!",
|
|
|
- ).send()
|
|
|
elif res.get("value") == "zip":
|
|
|
files = None
|
|
|
while files == None:
|
|
|
@@ -153,15 +150,74 @@ async def start():
|
|
|
accept={"application/zip": [".zip"]},
|
|
|
max_size_mb=50,
|
|
|
).send()
|
|
|
-
|
|
|
+ await cl.Message(content="").send()
|
|
|
text_file = files[0]
|
|
|
- extracted_path = unzip_file(text_file.path, extract_dir)
|
|
|
- files_list = get_project_files_with_content(extracted_path)
|
|
|
- cl.user_session.set("project_index", files_list)
|
|
|
- if len(files_list) > 0:
|
|
|
- await cl.Message(
|
|
|
- content=f"已成功上传,您可以开始对项目进行提问!",
|
|
|
+ repo_path = unzip_file(text_file.path, extract_dir)
|
|
|
+ files_list = get_project_files_with_content(repo_path)
|
|
|
+ cl.user_session.set("project_index", files_list)
|
|
|
+
|
|
|
+ if len(files_list) > 0:
|
|
|
+
|
|
|
+ structure_str = await directory_structure(repo_path)
|
|
|
+ #TODO: 项目文件太多,需要分批处理,这里暂时只取前5个文件
|
|
|
+ top_files_list = sorted(files_list, key=lambda x: len(x["path"]))[:5]
|
|
|
+ index_prompt = ""
|
|
|
+ index_tmp = """###PATH:{path}\n{code}\n"""
|
|
|
+ for index in top_files_list:
|
|
|
+ index_prompt += index_tmp.format(path=index["path"], code=index["content"])
|
|
|
+ file_summary_message_history=[{"role": "user", "content": index_prompt+'\n'+file_summary_prompt}]
|
|
|
+ file_summary_prompt_content = get_cur_base_user_prompt(message_history=file_summary_message_history)
|
|
|
+ file_summary = codegeex4(file_summary_prompt_content, temperature=temperature, top_p=top_p)
|
|
|
+ # index_prompt = ""
|
|
|
+ # index_tmp = """###PATH:{path}\n{code}\n"""
|
|
|
+ # for index in files_list:
|
|
|
+ # index_prompt += index_tmp.format(path=index["path"], code=index["content"])
|
|
|
+ message_history=[{"role": "user", "content": structure_str+'\n'+project_mermaid_prompt}]
|
|
|
+ prompt_content = get_cur_base_user_prompt(message_history=message_history)
|
|
|
+
|
|
|
+ if len(prompt_content) / 4 < 120000:
|
|
|
+ mermaid_structure = repo_path+"/structure.png"
|
|
|
+ is_bad = False
|
|
|
+ if not os.path.exists(mermaid_structure):
|
|
|
+ result_mermaid = ""
|
|
|
+ retry_attempts = 2
|
|
|
+ for _ in range(retry_attempts):
|
|
|
+ stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
|
|
|
+ print(stream)
|
|
|
+ for part in stream:
|
|
|
+ if token := (part or " "):
|
|
|
+ result_mermaid+=token
|
|
|
+
|
|
|
+ result_mermaid = extract_code_text(result_mermaid, "mermaid")
|
|
|
+ if result_mermaid!="":
|
|
|
+ break
|
|
|
+
|
|
|
+ if result_mermaid:
|
|
|
+ get_mermaid_png(result_mermaid,mermaid_structure)
|
|
|
+ else:
|
|
|
+ is_bad = True
|
|
|
+ if is_bad:
|
|
|
+ await cl.Message(f"架构图生成失败。但不影响后续项目问答。\n{file_summary}\n- ......").send()
|
|
|
+ else:
|
|
|
+ img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
|
|
|
+ await cl.Message(
|
|
|
+ content=f"已成功上传,这是项目的目录架构图,您可以开始对项目进行提问!\n{file_summary}\n- ......",
|
|
|
+ elements=[img_mermard_structure],
|
|
|
).send()
|
|
|
+ else:
|
|
|
+ await cl.Message("项目太大了,请换小一点的项目。").send()
|
|
|
+ cl.user_session.set("project_message_history", [])
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
[email protected](type="tool")
|
|
|
+async def directory_structure(repo_path):
|
|
|
+ current_step = cl.context.current_step
|
|
|
+ current_step.input = "加载项目目录"
|
|
|
+ structure_str = get_directory_structure(repo_path)
|
|
|
+ current_step.output = "```shell\n"+structure_str+"\n```"
|
|
|
+ return structure_str
|
|
|
|
|
|
|
|
|
@cl.step(type="tool")
|
|
|
@@ -181,15 +237,15 @@ async def main(message: cl.Message):
|
|
|
message_history = cl.user_session.get("message_history")
|
|
|
|
|
|
tool_name = tools_choose_agent(message.content)
|
|
|
+ print(tool_name)
|
|
|
is_online = cl.user_session.get("is_online")
|
|
|
-
|
|
|
- if chat_profile == "联网聊天":
|
|
|
+
|
|
|
+ if chat_profile == "chat聊天":
|
|
|
if "online_query" in tool_name and is_online:
|
|
|
prompt_tmp = await bing_search_tool(message.content)
|
|
|
message_history.append({"role": "tool", "content": prompt_tmp})
|
|
|
message_history.append({"role": "user", "content": message.content})
|
|
|
prompt_content = get_cur_base_user_prompt(message_history=message_history)
|
|
|
-
|
|
|
elif chat_profile == "项目问答":
|
|
|
message_history.append({"role": "user", "content": message.content})
|
|
|
project_index = cl.user_session.get("project_index")
|
|
|
@@ -201,12 +257,13 @@ async def main(message: cl.Message):
|
|
|
prompt_content = get_cur_base_user_prompt(
|
|
|
message_history=message_history,
|
|
|
index_prompt=index_prompt,
|
|
|
- judge_context=tool_name[0],
|
|
|
)
|
|
|
|
|
|
else:
|
|
|
prompt_content = get_cur_base_user_prompt(message_history=message_history)
|
|
|
-
|
|
|
+ else:
|
|
|
+ message_history.append({"role": "user", "content": message.content})
|
|
|
+ prompt_content = get_cur_base_user_prompt(message_history=message_history)
|
|
|
|
|
|
|
|
|
msg = cl.Message(content="")
|