Przeglądaj źródła

update project structure png

XingYu-Zhong 1 rok temu
rodzic
commit
bb106517c8
3 zmienionych plików z 22 dodań i 17 usunięć
  1. 1 1
      repodemo/prompts/base_prompt.py
  2. 15 15
      repodemo/run.py
  3. 6 1
      repodemo/utils/tools.py

+ 1 - 1
repodemo/prompts/base_prompt.py

@@ -1,6 +1,6 @@
 base_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"""
 
-project_mermaid_prompt = """请你根据项目目录为这个项目生成一个架构图。请使用mermaid语言生成这个项目的核心架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid,需要graph LR形式,尽量精简节点。"""
+project_mermaid_prompt = """请你根据项目目录为这个项目生成一个架构图。请使用mermaid语言生成这个项目的核心架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid,需要graph LR形式,尽量精简节点,过滤掉对理解项目无关紧要的节点(ex:二进制文件,git文件,日志文件等等),只保留核心节点。"""
 
 web_search_prompy = """你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
 

+ 15 - 15
repodemo/run.py

@@ -130,12 +130,12 @@ async def start():
                 cl.Action(name="url", value="url", label="上传GitHub链接"),
             ],
         ).send()
-  
         if res.get("value") == "url":
             repo_path =None
             while repo_path == None:
                 res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX2", timeout=3600).send()
                 if res:
+                    await cl.Message(content="").send()
                     repo_path = clone_repo(res['output'],extract_dir)
                     if repo_path is None:
                         await cl.Message(
@@ -149,22 +149,20 @@ async def start():
                     accept={"application/zip": [".zip"]},
                     max_size_mb=50,
                 ).send()
-
+            await cl.Message(content="").send()
             text_file = files[0]
             repo_path = unzip_file(text_file.path, extract_dir)
         files_list = get_project_files_with_content(repo_path)
         cl.user_session.set("project_index", files_list)
         if len(files_list) > 0:
             structure_str = await directory_structure(repo_path)
-
-            index_prompt = ""
-            index_tmp = """###PATH:{path}\n{code}\n"""
-            for index in files_list:
-                index_prompt += index_tmp.format(path=index["path"], code=index["content"])
-            message_history=[{"role": "user", "content": structure_str+project_mermaid_prompt}]
+            # index_prompt = ""
+            # index_tmp = """###PATH:{path}\n{code}\n"""
+            # for index in files_list:
+            #     index_prompt += index_tmp.format(path=index["path"], code=index["content"])
+            message_history=[{"role": "user", "content": structure_str+'\n'+project_mermaid_prompt}]
             prompt_content = get_cur_base_user_prompt(message_history=message_history)
-            # prompt_content.insert(-1,{"role": "user", "content": project_mermaid_prompt})
-            print(prompt_content)
+
             if len(prompt_content) / 4 < 120000:
                 mermaid_structure = repo_path+"/structure.png"
                 is_bad = False
@@ -191,7 +189,7 @@ async def start():
                 else:
                     img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
                     await cl.Message(
-                    content=f"已成功上传,这是项目的架构图,您可以开始对项目进行提问!",
+                    content=f"已成功上传,这是项目的目录架构图,您可以开始对项目进行提问!",
                     elements=[img_mermard_structure],
                 ).send()
             else:
@@ -227,15 +225,15 @@ async def main(message: cl.Message):
     message_history = cl.user_session.get("message_history")
     
     tool_name = tools_choose_agent(message.content)
+    print(tool_name)
     is_online = cl.user_session.get("is_online")
-
-    if chat_profile == "联网聊天":
+    
+    if chat_profile == "chat聊天":
         if "online_query" in tool_name and is_online:
             prompt_tmp = await bing_search_tool(message.content)
             message_history.append({"role": "tool", "content": prompt_tmp})
         message_history.append({"role": "user", "content": message.content})
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
-
     elif chat_profile == "项目问答":
         message_history.append({"role": "user", "content": message.content})
         project_index = cl.user_session.get("project_index")
@@ -251,7 +249,9 @@ async def main(message: cl.Message):
             
         else:
             prompt_content = get_cur_base_user_prompt(message_history=message_history)
-        
+    else:
+        message_history.append({"role": "user", "content": message.content})
+        prompt_content = get_cur_base_user_prompt(message_history=message_history)
         
     
     msg = cl.Message(content="")

+ 6 - 1
repodemo/utils/tools.py

@@ -133,13 +133,18 @@ def get_directory_structure(directory_path):
     structure = []
 
     for root, dirs, files in os.walk(directory_path):
+        # Check if the current directory is .git and skip if it is
+        if '.git' in root.split(os.sep):
+            continue
+        
         level = root.replace(directory_path, '').count(os.sep)
         indent = ' ' * 4 * level
         structure.append(f"{indent}{os.path.basename(root)}/")
+        
         sub_indent = ' ' * 4 * (level + 1)
         for file in files:
             structure.append(f"{sub_indent}{file}")
-
+    
     return "\n".join(structure)
 
 def filter_data(obj):