|
|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
import chainlit as cl
|
|
|
from chainlit.input_widget import Slider,Switch
|
|
|
import json
|
|
|
@@ -11,7 +12,8 @@ from prompts.base_prompt import (
|
|
|
tools_input_prompt
|
|
|
)
|
|
|
from utils.bingsearch import bing_search_prompt
|
|
|
-from utils.tools import unzip_file, get_project_files_with_content,clone_repo,is_valid_json
|
|
|
+from utils.tools import extract_code_text, get_directory_structure, get_mermaid_png, unzip_file, get_project_files_with_content,clone_repo,is_valid_json
|
|
|
+from prompts.base_prompt import project_mermaid_prompt
|
|
|
|
|
|
def tools_choose_agent(input_text):
|
|
|
tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
|
|
|
@@ -139,13 +141,6 @@ async def start():
|
|
|
await cl.Message(
|
|
|
content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
|
|
|
).send()
|
|
|
-
|
|
|
- files_list = get_project_files_with_content(repo_path)
|
|
|
- cl.user_session.set("project_index", files_list)
|
|
|
- if len(files_list) > 0:
|
|
|
- await cl.Message(
|
|
|
- content=f"已成功上传,您可以开始对项目进行提问!",
|
|
|
- ).send()
|
|
|
elif res.get("value") == "zip":
|
|
|
files = None
|
|
|
while files == None:
|
|
|
@@ -156,13 +151,52 @@ async def start():
|
|
|
).send()
|
|
|
|
|
|
text_file = files[0]
|
|
|
- extracted_path = unzip_file(text_file.path, extract_dir)
|
|
|
- files_list = get_project_files_with_content(extracted_path)
|
|
|
- cl.user_session.set("project_index", files_list)
|
|
|
- if len(files_list) > 0:
|
|
|
+ repo_path = unzip_file(text_file.path, extract_dir)
|
|
|
+ files_list = get_project_files_with_content(repo_path)
|
|
|
+ cl.user_session.set("project_index", files_list)
|
|
|
+ if len(files_list) > 0:
|
|
|
+ await directory_structure(repo_path)
|
|
|
+
|
|
|
+ index_prompt = ""
|
|
|
+ index_tmp = """###PATH:{path}\n{code}\n"""
|
|
|
+ for index in files_list:
|
|
|
+ index_prompt += index_tmp.format(path=index["path"], code=index["content"])
|
|
|
+ message_history=[{"role": "user", "content": project_mermaid_prompt}]
|
|
|
+ prompt_content = get_cur_base_user_prompt(message_history=message_history, index_prompt=index_prompt)
|
|
|
+ print(prompt_content)
|
|
|
+ if len(prompt_content) / 4 < 120000:
|
|
|
+ mermaid_structure = repo_path+"/structure.png"
|
|
|
+ if not os.path.exists(mermaid_structure):
|
|
|
+ result_mermaid = ""
|
|
|
+ while result_mermaid == "":
|
|
|
+ stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
|
|
|
+
|
|
|
+ for part in stream:
|
|
|
+ if token := (part or " "):
|
|
|
+ result_mermaid+=token
|
|
|
+
|
|
|
+ result_mermaid = extract_code_text(result_mermaid, "mermaid")
|
|
|
+
|
|
|
+ mermaid_structure = repo_path+"/structure.png"
|
|
|
+
|
|
|
+ get_mermaid_png(result_mermaid,mermaid_structure)
|
|
|
+ img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
|
|
|
await cl.Message(
|
|
|
- content=f"已成功上传,您可以开始对项目进行提问!",
|
|
|
+ content=f"已成功上传,这是项目的架构图,您可以开始对项目进行提问!",
|
|
|
+ elements=[img_mermard_structure],
|
|
|
).send()
|
|
|
+ else:
|
|
|
+ await cl.Message("项目太大了,请换小一点的项目。").send()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
[email protected](type="tool")
|
|
|
+async def directory_structure(repo_path):
|
|
|
+ current_step = cl.context.current_step
|
|
|
+ current_step.input = "加载项目目录"
|
|
|
+ structure_str = get_directory_structure(repo_path)
|
|
|
+ current_step.output = "```shell\n"+structure_str+"\n```"
|
|
|
|
|
|
|
|
|
@cl.step(type="tool")
|