Browse Source

feature:add mermaid_png function

XingYu-Zhong 1 year ago
parent
commit
5025da65e4
4 changed files with 90 additions and 14 deletions
  1. 2 0
      repodemo/prompts/base_prompt.py
  2. 3 1
      repodemo/requirements.txt
  3. 47 13
      repodemo/run.py
  4. 38 0
      repodemo/utils/tools.py

+ 2 - 0
repodemo/prompts/base_prompt.py

@@ -1,5 +1,7 @@
 base_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"""
 
+project_mermaid_prompt = """请你分析项目内容后为这个项目生成一个架构图。请使用mermaid语言生成这个项目的架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid。"""
+
 web_search_prompy = """你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
 
 # Note

+ 3 - 1
repodemo/requirements.txt

@@ -2,4 +2,6 @@ chainlit==1.1.305
 beautifulsoup4
 python-dotenv
 gitpython
-openai==1.35.4
+openai==1.35.4
+
+mermaid-py

+ 47 - 13
repodemo/run.py

@@ -1,3 +1,4 @@
+import os
 import chainlit as cl
 from chainlit.input_widget import Slider,Switch
 import json
@@ -11,7 +12,8 @@ from prompts.base_prompt import (
     tools_input_prompt
 )
 from utils.bingsearch import bing_search_prompt
-from utils.tools import unzip_file, get_project_files_with_content,clone_repo,is_valid_json
+from utils.tools import extract_code_text, get_directory_structure, get_mermaid_png, unzip_file, get_project_files_with_content,clone_repo,is_valid_json
+from prompts.base_prompt import project_mermaid_prompt
 
 def tools_choose_agent(input_text):
     tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
@@ -139,13 +141,6 @@ async def start():
                         await cl.Message(
                                 content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
                             ).send()
-            
-            files_list = get_project_files_with_content(repo_path)
-            cl.user_session.set("project_index", files_list)
-            if len(files_list) > 0:
-                await cl.Message(
-                    content=f"已成功上传,您可以开始对项目进行提问!",
-                ).send()
         elif res.get("value") == "zip":
             files = None
             while files == None:
@@ -156,13 +151,52 @@ async def start():
                 ).send()
 
             text_file = files[0]
-            extracted_path = unzip_file(text_file.path, extract_dir)
-            files_list = get_project_files_with_content(extracted_path)
-            cl.user_session.set("project_index", files_list)
-            if len(files_list) > 0:
+            repo_path = unzip_file(text_file.path, extract_dir)
+        files_list = get_project_files_with_content(repo_path)
+        cl.user_session.set("project_index", files_list)
+        if len(files_list) > 0:
+            await directory_structure(repo_path)
+
+            index_prompt = ""
+            index_tmp = """###PATH:{path}\n{code}\n"""
+            for index in files_list:
+                index_prompt += index_tmp.format(path=index["path"], code=index["content"])
+            message_history=[{"role": "user", "content": project_mermaid_prompt}]
+            prompt_content = get_cur_base_user_prompt(message_history=message_history, index_prompt=index_prompt)
+            print(prompt_content)
+            if len(prompt_content) / 4 < 120000:
+                mermaid_structure = repo_path+"/structure.png"
+                if not os.path.exists(mermaid_structure):
+                    result_mermaid = ""
+                    while result_mermaid == "":
+                        stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
+                        
+                        for part in stream:
+                            if token := (part or " "):
+                                result_mermaid+=token
+
+                        result_mermaid = extract_code_text(result_mermaid, "mermaid")
+        
+                    mermaid_structure = repo_path+"/structure.png"
+
+                    get_mermaid_png(result_mermaid,mermaid_structure)
+                img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
                 await cl.Message(
-                    content=f"已成功上传,您可以开始对项目进行提问!",
+                    content=f"已成功上传,这是项目的架构图,您可以开始对项目进行提问!",
+                    elements=[img_mermard_structure],
                 ).send()
+            else:
+                await cl.Message("项目太大了,请换小一点的项目。").send()
+            
+
+            
+
[email protected](type="tool")
+async def directory_structure(repo_path):
+    current_step = cl.context.current_step
+    current_step.input = "加载项目目录"
+    structure_str = get_directory_structure(repo_path)
+    current_step.output = "```shell\n"+structure_str+"\n```"
 
 
 @cl.step(type="tool")

+ 38 - 0
repodemo/utils/tools.py

@@ -4,7 +4,15 @@ import zipfile
 import git
 import urllib.parse
 import re
+import mermaid as md
+from mermaid.graph import Graph
 
+
+def get_mermaid_png(mermaid_str:str,png_file_path:str):
+    graph = Graph('example-flowchart', mermaid_str)
+    graphe = md.Mermaid(graph)
+    graphe.to_png(png_file_path)
+    
 def is_valid_json(json_string):
     try:
         match = re.search(r'\{.*\}', json_string, re.DOTALL)
@@ -101,7 +109,37 @@ def get_project_files_with_content(project_dir):
                 continue
 
     return files_list
+def extract_code_text(text,language):
+    
+    start_marker = f"```{language}"
+    end_marker = "```"
+    start_index = text.find(start_marker)
+    end_index = text.find(end_marker, start_index + len(start_marker))
+    # print(start_index, end_index,language)
+    if start_index == -1 or end_index == -1 or start_index >= end_index:
+        return ""
+
+    start_index += len(start_marker)
+    return text[start_index:end_index].strip()
+
+def get_directory_structure(directory_path):
+    """
+    获取指定目录下的文件结构并返回为字符串格式。
+
+    :param directory_path: str, 目录路径
+    :return: str, 文件结构
+    """
+    structure = []
+
+    for root, dirs, files in os.walk(directory_path):
+        level = root.replace(directory_path, '').count(os.sep)
+        indent = ' ' * 4 * level
+        structure.append(f"{indent}{os.path.basename(root)}/")
+        sub_indent = ' ' * 4 * (level + 1)
+        for file in files:
+            structure.append(f"{sub_indent}{file}")
 
+    return "\n".join(structure)
 
 def filter_data(obj):
     LANGUAGE_TAG = {