2
0
Эх сурвалжийг харах

Merge pull request #49 from XingYu-Zhong/main

add_feature:Project-level architecture diagram
Qinkai 1 жил өмнө
parent
commit
639ec71f07

+ 9 - 4
repodemo/llm/api/codegeex4.py

@@ -2,21 +2,26 @@ import json
 from openai import OpenAI
 import requests
 import os
+from dotenv import load_dotenv
+# 加载 .env 文件
+load_dotenv()
 def codegeex4(messages_list, temperature=0.2, top_p=0.95):
+    # Send `messages_list` to the chat-completions endpoint reached through the
+    # OpenAI client and return the text content of the first returned choice.
+    # The API key, base URL and model name are read from the environment
+    # variables "openai_api_key", "openai_api_base" and "model_name".
+    # `temperature` and `top_p` are forwarded unchanged as sampling parameters.
     openai_api_key = os.getenv("openai_api_key")
     openai_api_base = os.getenv("openai_api_base")
     model_name = os.getenv("model_name")
 
-
     client = OpenAI(
         api_key=openai_api_key,
         base_url=openai_api_base,
     )
-
+ 
     chat_response = client.chat.completions.create(
         model=model_name,
         messages=messages_list,
-        temperature=temperature,
-        top_p=top_p
+        temperature =temperature,
+        top_p=top_p,
+        max_tokens=8192
     )
     return chat_response.choices[0].message.content
+
+

+ 11 - 34
repodemo/prompts/base_prompt.py

@@ -1,6 +1,8 @@
 base_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"""
 
-repo_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码。请根据用户给出的项目仓库中的代码,以及用户提出的需求,生成新的代码或者更改已有代码。输出格式:\n\n###PATH:{PATH}\n{CODE}"""
+project_mermaid_prompt = """请你根据项目目录为这个项目生成一个架构图。请使用mermaid语言生成这个项目的核心架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid,需要graph LR形式,尽量精简节点,过滤掉对理解项目无关紧要的节点(ex:二进制文件,git文件,日志文件等等),只保留核心节点。"""
+
+file_summary_prompt = """请你为每个文件提供一句话的总结,描述这个文件的作用、内容、格式等等。输出格式:-filename: 文件名\n-summary: 文件总结 """
 
 web_search_prompy = """你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
 
@@ -37,8 +39,7 @@ Here is some examples about tools choosing:
 Input:
 tools_summary: {
     "online_query": "Questions need to be queried on the Internet to ensure accurate answers",
-    "project_qa": "Questions need to be answered specific to the project",
-    "project_modify": "The problem is that we need to modify the project"
+    "project_qa": "Questions need to be answered specific to the project"
 }
 input_text: "今天星期几"
 
@@ -56,33 +57,12 @@ Output:
 }
 
 
-Input:
-tools_summary: {
-    "online_query": "Questions need to be queried on the Internet to ensure accurate answers",
-    "project_qa": "Questions need to be answered specific to the project",
-    "project_modify": "The problem is that we need to modify the project"
-}
-input_text: "请你帮我把项目的readme改成韩文"
-
-Output:
-{
-    "thoughts": {
-        "text": "用户需要将项目的readme文件翻译成韩文。",
-        "reasoning": "根据工具概要,project_modify专用于项目修改,这与用户的需求相符。",
-        "criticism": "需要确保用户对翻译后的韩文内容满意,因为翻译质量可能影响项目的整体感受。",
-        "speak": "我们将使用project_modify来修改项目的readme文件。请确认您希望使用的韩文翻译内容。"
-    },
-    "tool": {
-        "name": ["project_modify"]
-    }
-}
 
 
 Input:
 tools_summary: {
     "online_query": "Questions need to be queried on the Internet to ensure accurate answers",
-    "project_qa": "Questions need to be answered specific to the project",
-    "project_modify": "The problem is that we need to modify the project"
+    "project_qa": "Questions need to be answered specific to the project"
 }
 input_text: "你是谁"
 
@@ -102,8 +82,7 @@ Output:
 Input:
 tools_summary: {
     "online_query": "Questions need to be queried on the Internet to ensure accurate answers",
-    "project_qa": "Questions need to be answered specific to the project",
-    "project_modify": "The problem is that we need to modify the project"
+    "project_qa": "Questions need to be answered specific to the project"
 }
 input_text: "解释一下项目"
 
@@ -148,8 +127,7 @@ Don't output in markdown format, something like ```json or ```,just output in th
 Input:
 tools_summary: {
     "online_query": "Questions need to be queried on the Internet to ensure accurate answers",
-    "project_qa": "Questions need to be answered specific to the project",
-    "project_modify": "The problem is that we need to modify the project"
+    "project_qa": "Questions need to be answered specific to the project"
 }
 """
 tools_input_prompt = """
@@ -172,7 +150,7 @@ def build_message_list(result):
 
     return message_list
 
-def get_cur_base_user_prompt(message_history, index_prompt=None, judge_context=""):
+def get_cur_base_user_prompt(message_history, index_prompt=None):
     user_prompt_tmp = """<|user|>\n{user_input}"""
     assistant_prompt_tmp = """<|assistant|>\n{assistant_input}"""
     history_prompt = ""
@@ -188,10 +166,9 @@ def get_cur_base_user_prompt(message_history, index_prompt=None, judge_context="
             )
 
 
-    if "project_modify" not in judge_context:
-        result = base_system_prompt + history_prompt + """<|assistant|>\n"""
-    else:
-        result = repo_system_prompt + history_prompt + """<|assistant|>\n"""
+
+    result = base_system_prompt + history_prompt + """<|assistant|>\n"""
+   
     
     message_list = build_message_list(result)
     # print(message_list)

+ 3 - 1
repodemo/requirements.txt

@@ -2,4 +2,6 @@ chainlit==1.1.305
 beautifulsoup4
 python-dotenv
 gitpython
-openai==1.35.4
+openai==1.35.4
+
+mermaid-py

+ 79 - 22
repodemo/run.py

@@ -1,3 +1,4 @@
+import os
 import chainlit as cl
 from chainlit.input_widget import Slider,Switch
 import json
@@ -8,10 +9,12 @@ from prompts.base_prompt import (
     get_cur_base_user_prompt,
     build_message_list,
     tools_choose_prompt,
-    tools_input_prompt
+    tools_input_prompt,
+    file_summary_prompt
 )
 from utils.bingsearch import bing_search_prompt
-from utils.tools import unzip_file, get_project_files_with_content,clone_repo,is_valid_json
+from utils.tools import extract_code_text, get_directory_structure, get_mermaid_png, unzip_file, get_project_files_with_content,clone_repo,is_valid_json
+from prompts.base_prompt import project_mermaid_prompt
 
 def tools_choose_agent(input_text):
     tools_prompt = tools_choose_prompt+tools_input_prompt.format(input_text=input_text)
@@ -131,20 +134,14 @@ async def start():
         if res.get("value") == "url":
             repo_path =None
             while repo_path == None:
-                res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX4", timeout=3600).send()
+                res = await cl.AskUserMessage(content="请你在下面消息框中提供GitHub仓库URL? ex:https://github.com/THUDM/CodeGeeX2", timeout=3600).send()
                 if res:
+                    await cl.Message(content="").send()
                     repo_path = clone_repo(res['output'],extract_dir)
                     if repo_path is None:
                         await cl.Message(
                                 content=f"您的github链接无法正常下载,请检查项目链接或github网络连通情况。",
                             ).send()
-            
-            files_list = get_project_files_with_content(repo_path)
-            cl.user_session.set("project_index", files_list)
-            if len(files_list) > 0:
-                await cl.Message(
-                    content=f"已成功上传,您可以开始对项目进行提问!",
-                ).send()
         elif res.get("value") == "zip":
             files = None
             while files == None:
@@ -153,15 +150,74 @@ async def start():
                     accept={"application/zip": [".zip"]},
                     max_size_mb=50,
                 ).send()
-
+            await cl.Message(content="").send()
             text_file = files[0]
-            extracted_path = unzip_file(text_file.path, extract_dir)
-            files_list = get_project_files_with_content(extracted_path)
-            cl.user_session.set("project_index", files_list)
-            if len(files_list) > 0:
-                await cl.Message(
-                    content=f"已成功上传,您可以开始对项目进行提问!",
+            repo_path = unzip_file(text_file.path, extract_dir)
+        files_list = get_project_files_with_content(repo_path)
+        cl.user_session.set("project_index", files_list)
+        
+        if len(files_list) > 0:
+            
+            structure_str = await directory_structure(repo_path)
+            #TODO: 项目文件太多,需要分批处理,这里暂时只取前5个文件
+            top_files_list = sorted(files_list, key=lambda x: len(x["path"]))[:5]
+            index_prompt = ""
+            index_tmp = """###PATH:{path}\n{code}\n"""
+            for index in top_files_list:
+                index_prompt += index_tmp.format(path=index["path"], code=index["content"])
+            file_summary_message_history=[{"role": "user", "content": index_prompt+'\n'+file_summary_prompt}]
+            file_summary_prompt_content = get_cur_base_user_prompt(message_history=file_summary_message_history)
+            file_summary = codegeex4(file_summary_prompt_content, temperature=temperature, top_p=top_p)
+            # index_prompt = ""
+            # index_tmp = """###PATH:{path}\n{code}\n"""
+            # for index in files_list:
+            #     index_prompt += index_tmp.format(path=index["path"], code=index["content"])
+            message_history=[{"role": "user", "content": structure_str+'\n'+project_mermaid_prompt}]
+            prompt_content = get_cur_base_user_prompt(message_history=message_history)
+
+            if len(prompt_content) / 4 < 120000:
+                mermaid_structure = repo_path+"/structure.png"
+                is_bad = False
+                if not os.path.exists(mermaid_structure):
+                    result_mermaid = ""
+                    retry_attempts = 2
+                    for _ in range(retry_attempts):
+                        stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
+                        print(stream)
+                        for part in stream:
+                            if token := (part or " "):
+                                result_mermaid+=token
+
+                        result_mermaid = extract_code_text(result_mermaid, "mermaid")
+                        if result_mermaid!="":
+                            break
+                    
+                    if result_mermaid:
+                        get_mermaid_png(result_mermaid,mermaid_structure)
+                    else:
+                        is_bad = True
+                if is_bad:
+                    await cl.Message(f"架构图生成失败。但不影响后续项目问答。\n{file_summary}\n- ......").send()
+                else:
+                    img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
+                    await cl.Message(
+                    content=f"已成功上传,这是项目的目录架构图,您可以开始对项目进行提问!\n{file_summary}\n- ......",
+                    elements=[img_mermard_structure],
                 ).send()
+            else:
+                await cl.Message("项目太大了,请换小一点的项目。").send()
+            cl.user_session.set("project_message_history", [])
+            
+
+            
+
+@cl.step(type="tool")
+async def directory_structure(repo_path):
+    # Build the directory listing for `repo_path` with get_directory_structure(),
+    # record it on the current Chainlit step, and return the listing string.
+    current_step = cl.context.current_step
+    # Step input label shown for this step (the literal means "load project directory").
+    current_step.input = "加载项目目录"
+    structure_str = get_directory_structure(repo_path)
+    # The step output is the listing wrapped in a ```shell fenced block.
+    current_step.output = "```shell\n"+structure_str+"\n```"
+    return structure_str
 
 
 @cl.step(type="tool")
@@ -181,15 +237,15 @@ async def main(message: cl.Message):
     message_history = cl.user_session.get("message_history")
     
     tool_name = tools_choose_agent(message.content)
+    print(tool_name)
     is_online = cl.user_session.get("is_online")
-
-    if chat_profile == "联网聊天":
+    
+    if chat_profile == "chat聊天":
         if "online_query" in tool_name and is_online:
             prompt_tmp = await bing_search_tool(message.content)
             message_history.append({"role": "tool", "content": prompt_tmp})
         message_history.append({"role": "user", "content": message.content})
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
-
     elif chat_profile == "项目问答":
         message_history.append({"role": "user", "content": message.content})
         project_index = cl.user_session.get("project_index")
@@ -201,12 +257,13 @@ async def main(message: cl.Message):
             prompt_content = get_cur_base_user_prompt(
                     message_history=message_history,
                     index_prompt=index_prompt,
-                    judge_context=tool_name[0],
                 )
             
         else:
             prompt_content = get_cur_base_user_prompt(message_history=message_history)
-        
+    else:
+        message_history.append({"role": "user", "content": message.content})
+        prompt_content = get_cur_base_user_prompt(message_history=message_history)
         
     
     msg = cl.Message(content="")

+ 44 - 0
repodemo/utils/tools.py

@@ -4,6 +4,15 @@ import zipfile
 import git
 import urllib.parse
 import re
+import mermaid as md
+from mermaid.graph import Graph
+
+
+def get_mermaid_png(mermaid_str:str,png_file_path:str):
+    # Render the Mermaid source `mermaid_str` to a PNG written at `png_file_path`,
+    # using the Graph and Mermaid objects imported from the `mermaid` package.
+    # `size_comment` is a Mermaid `%%{init: ...}%%` directive that is prepended to
+    # the diagram text; it sets the theme and theme variables (colours, font size,
+    # transparent background) for the rendered graph.
+    size_comment = "%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#ffcc00', 'edgeLabelBackground':'#ffffff', 'fontSize': 18, 'background': 'transparent', 'lineColor': '#000000' }}}%%\n"
+    # 'example-flowchart' is the title passed to Graph; the script is directive + diagram.
+    graph = Graph('example-flowchart', size_comment+mermaid_str)
+    graphe = md.Mermaid(graph)
+    # NOTE(review): how to_png produces the image (local vs. remote rendering) is
+    # not visible here - confirm against the mermaid-py documentation.
+    graphe.to_png(png_file_path)
 
 def is_valid_json(json_string):
     try:
@@ -101,7 +110,42 @@ def get_project_files_with_content(project_dir):
                 continue
 
     return files_list
+def extract_code_text(text,language):
+    # Return the contents of the first fenced code block in `text` that opens
+    # with "```{language}", with surrounding whitespace stripped.
+    # Returns "" when the opening fence or a closing "```" after it is not found.
+    
+    start_marker = f"```{language}"
+    end_marker = "```"
+    start_index = text.find(start_marker)
+    # Search for the closing fence only after the end of the opening fence, so
+    # the opening "```" itself is never matched as the terminator.
+    end_index = text.find(end_marker, start_index + len(start_marker))
+    # print(start_index, end_index,language)
+    if start_index == -1 or end_index == -1 or start_index >= end_index:
+        return ""
+
+    # Skip past the opening fence so only the block body is returned.
+    start_index += len(start_marker)
+    return text[start_index:end_index].strip()
+
+def get_directory_structure(directory_path):
+    """
+    Get the file structure under the given directory and return it as a string.
+
+    Each directory is listed as "name/" and indented four spaces per nesting
+    level; its files are listed one level deeper. Lines are joined with "\n".
+
+    :param directory_path: str, directory path
+    :return: str, file structure
+    """
+    structure = []
 
+    for root, dirs, files in os.walk(directory_path):
+        # Check if the current directory is .git and skip if it is
+        # (any directory whose path has a ".git" component is left out).
+        # NOTE(review): `dirs` is not pruned, so os.walk still descends into
+        # .git; those directories are only skipped here, one by one.
+        if '.git' in root.split(os.sep):
+            continue
+        
+        # Nesting depth = number of path separators remaining in `root` once
+        # `directory_path` has been removed from it.
+        level = root.replace(directory_path, '').count(os.sep)
+        indent = ' ' * 4 * level
+        structure.append(f"{indent}{os.path.basename(root)}/")
+        
+        # Files are indented one level deeper than the directory that holds them.
+        sub_indent = ' ' * 4 * (level + 1)
+        for file in files:
+            structure.append(f"{sub_indent}{file}")
+    
+    return "\n".join(structure)
 
 def filter_data(obj):
     LANGUAGE_TAG = {