Sfoglia il codice sorgente

feat: Update project structure generation logic

XingYu-Zhong 1 anno fa
parent
commit
91145f55ab

+ 9 - 4
repodemo/llm/api/codegeex4.py

@@ -2,21 +2,26 @@ import json
 from openai import OpenAI
 import requests
 import os
+from dotenv import load_dotenv
+# 加载 .env 文件
+load_dotenv()
 def codegeex4(messages_list, temperature=0.2, top_p=0.95):
     openai_api_key = os.getenv("openai_api_key")
     openai_api_base = os.getenv("openai_api_base")
     model_name = os.getenv("model_name")
 
-
     client = OpenAI(
         api_key=openai_api_key,
         base_url=openai_api_base,
     )
-
+ 
     chat_response = client.chat.completions.create(
         model=model_name,
         messages=messages_list,
-        temperature=temperature,
-        top_p=top_p
+        temperature =temperature,
+        top_p=top_p,
+        max_tokens=8192
     )
     return chat_response.choices[0].message.content
+
+

+ 1 - 1
repodemo/prompts/base_prompt.py

@@ -1,6 +1,6 @@
 base_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"""
 
-project_mermaid_prompt = """请你分析项目内容后为这个项目生成一个架构图。请使用mermaid语言生成这个项目的架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid。"""
+project_mermaid_prompt = """请你根据项目目录为这个项目生成一个架构图。请使用mermaid语言生成这个项目的核心架构图,请你确保mermaid的逻辑是正确的且能被解析的,只需要输出mermaid,需要graph LR形式,尽量精简节点。"""
 
 web_search_prompy = """你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
 

+ 22 - 11
repodemo/run.py

@@ -155,38 +155,48 @@ async def start():
         files_list = get_project_files_with_content(repo_path)
         cl.user_session.set("project_index", files_list)
         if len(files_list) > 0:
-            await directory_structure(repo_path)
+            structure_str = await directory_structure(repo_path)
 
             index_prompt = ""
             index_tmp = """###PATH:{path}\n{code}\n"""
             for index in files_list:
                 index_prompt += index_tmp.format(path=index["path"], code=index["content"])
-            message_history=[{"role": "user", "content": project_mermaid_prompt}]
-            prompt_content = get_cur_base_user_prompt(message_history=message_history, index_prompt=index_prompt)
+            message_history=[{"role": "user", "content": structure_str+project_mermaid_prompt}]
+            prompt_content = get_cur_base_user_prompt(message_history=message_history)
+            # prompt_content.insert(-1,{"role": "user", "content": project_mermaid_prompt})
             print(prompt_content)
             if len(prompt_content) / 4 < 120000:
                 mermaid_structure = repo_path+"/structure.png"
+                is_bad = False
                 if not os.path.exists(mermaid_structure):
                     result_mermaid = ""
-                    while result_mermaid == "":
+                    retry_attempts = 2
+                    for _ in range(retry_attempts):
                         stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
-                        
+                        print(stream)
                         for part in stream:
                             if token := (part or " "):
                                 result_mermaid+=token
 
                         result_mermaid = extract_code_text(result_mermaid, "mermaid")
-        
-                    mermaid_structure = repo_path+"/structure.png"
-
-                    get_mermaid_png(result_mermaid,mermaid_structure)
-                img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
-                await cl.Message(
+                        if result_mermaid!="":
+                            break
+                    
+                    if result_mermaid:
+                        get_mermaid_png(result_mermaid,mermaid_structure)
+                    else:
+                        is_bad = True
+                if is_bad:
+                    await cl.Message("架构图生成失败。但不影响后续项目问答。").send()
+                else:
+                    img_mermard_structure = cl.Image(path=mermaid_structure, name="structure", display="inline",size="large")
+                    await cl.Message(
                     content=f"已成功上传,这是项目的架构图,您可以开始对项目进行提问!",
                     elements=[img_mermard_structure],
                 ).send()
             else:
                 await cl.Message("项目太大了,请换小一点的项目。").send()
+            cl.user_session.set("project_message_history", [])
             
 
             
@@ -197,6 +207,7 @@ async def directory_structure(repo_path):
     current_step.input = "加载项目目录"
     structure_str = get_directory_structure(repo_path)
     current_step.output = "```shell\n"+structure_str+"\n```"
+    return structure_str
 
 
 @cl.step(type="tool")

+ 3 - 2
repodemo/utils/tools.py

@@ -9,10 +9,11 @@ from mermaid.graph import Graph
 
 
 def get_mermaid_png(mermaid_str:str,png_file_path:str):
-    graph = Graph('example-flowchart', mermaid_str)
+    size_comment = "%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#ffcc00', 'edgeLabelBackground':'#ffffff', 'fontSize': 18, 'background': 'transparent', 'lineColor': '#000000' }}}%%\n"
+    graph = Graph('example-flowchart', size_comment+mermaid_str)
     graphe = md.Mermaid(graph)
     graphe.to_png(png_file_path)
-    
+
 def is_valid_json(json_string):
     try:
         match = re.search(r'\{.*\}', json_string, re.DOTALL)