소스 검색

Merge pull request #14 from XingYu-Zhong/main

Eliminate PEP 8 warnings
Shaobo 1 년 전
부모
커밋
73c3720049
7개의 파일이 변경되었으며, 289줄이 추가되고 262줄이 삭제되었습니다
  1. 19 21
      repodemo/llm/api/codegeex4.py
  2. 20 16
      repodemo/llm/local/codegeex4.py
  3. 16 13
      repodemo/prompts/base_prompt.py
  4. 69 61
      repodemo/run.py
  5. 68 62
      repodemo/run_local.py
  6. 16 10
      repodemo/utils/bingsearch.py
  7. 81 79
      repodemo/utils/tools.py

+ 19 - 21
repodemo/llm/api/codegeex4.py

@@ -1,41 +1,39 @@
 import requests
 import json
 
-URL = "" #the url you deploy codegeex service
+URL = ""  # the url you deploy codegeex service
+
+
 def codegeex4(prompt, temperature=0.8, top_p=0.8):
     url = URL
-    headers = {
-        'Content-Type': 'application/json'
-    }
+    headers = {"Content-Type": "application/json"}
     data = {
-        'inputs': prompt,
-        'parameters': {
-            'best_of':1,
-            'do_sample': True,
-            'max_new_tokens': 4012,
-            'temperature': temperature,
-            'top_p': top_p,
-            'stop': ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"],
-        }
+        "inputs": prompt,
+        "parameters": {
+            "best_of": 1,
+            "do_sample": True,
+            "max_new_tokens": 4012,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop": ["<|endoftext|>", "<|user|>", "<|observation|>", "<|assistant|>"],
+        },
     }
     response = requests.post(url, json=data, headers=headers, verify=False, stream=True)
 
     if response.status_code == 200:
         for line in response.iter_lines():
             if line:
-                decoded_line = line.decode('utf-8').replace('data:', '').strip()
+                decoded_line = line.decode("utf-8").replace("data:", "").strip()
                 if decoded_line:
                     try:
 
                         content = json.loads(decoded_line)
-                    
-                        token_text = content.get('token', {}).get('text', '')
-                        if '<|endoftext|>' in token_text:
-                            break  
+
+                        token_text = content.get("token", {}).get("text", "")
+                        if "<|endoftext|>" in token_text:
+                            break
                         yield token_text
                     except json.JSONDecodeError:
                         continue
     else:
-        print('请求失败:', response.status_code)
-
-
+        print("请求失败:", response.status_code)

+ 20 - 16
repodemo/llm/local/codegeex4.py

@@ -3,45 +3,49 @@ from transformers import AutoModel, AutoTokenizer
 from typing import Iterator
 import torch
 
-class CodegeexChatModel():
+
+class CodegeexChatModel:
     device: str = Field(description="device to load the model")
     tokenizer = Field(description="model's tokenizer")
     model = Field(description="Codegeex model")
     temperature: float = Field(description="temperature to use for the model.")
 
-    def __init__(self,model_name_or_path):
+    def __init__(self, model_name_or_path):
         super().__init__()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
-        self.model = AutoModel.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True
-        ).to(self.device).eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path, trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+            .to(self.device)
+            .eval()
+        )
         print("Model has been initialized.")
 
-    def chat(self, prompt,temperature=0.2,top_p=0.95):
+    def chat(self, prompt, temperature=0.2, top_p=0.95):
         try:
             response, _ = self.model.chat(
                 self.tokenizer,
                 query=prompt,
                 max_length=120000,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
             )
             return response
         except Exception as e:
             return f"error:{e}"
 
-    def stream_chat(self,prompt,temperature=0.2,top_p=0.95):
+    def stream_chat(self, prompt, temperature=0.2, top_p=0.95):
 
         try:
             for response, _ in self.model.stream_chat(
-                    self.tokenizer,
-                    query=prompt,
-                    max_length=120000,
-                    temperature=temperature,
-                    top_p=top_p
+                self.tokenizer,
+                query=prompt,
+                max_length=120000,
+                temperature=temperature,
+                top_p=top_p,
             ):
                 yield response
         except Exception as e:
-            yield f'error: {e}'
+            yield f"error: {e}"

+ 16 - 13
repodemo/prompts/base_prompt.py

@@ -4,7 +4,7 @@ repo_system_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeG
 
 judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改/正常),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字,如果用户输入的是一个与项目无关的问题则只需要输出正常两个字。<|user|>\n{user_input}<|assistant|>\n"""
 
-web_judge_task_prompt ="""<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n"""
+web_judge_task_prompt = """<|system|>\n你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。<|user|>\n{user_input}\n这个问题需要进行联网来回答吗?仅回答“是”或者“否”。<|assistant|>\n"""
 
 # judge_task_prompt = """<|system|>\n你是一位任务分类专家,请你对用户的输入进行分类(问答/修改),如果用户的输入是对项目进行提问则只需要输出问答两个字,如果用户的输入是对项目进行修改或增加则只需要输出修改两个字。<|user|>\n{user_input}<|assistant|>\n"""
 web_search_prompy = """
@@ -19,24 +19,27 @@ web_search_prompy = """
 - 除了代码和特定的名称和引用外,您的答案必须使用与问题相同的语言来撰写。
 """.lstrip()
 
-def get_cur_base_user_prompt(message_history,index_prompt = None,judge_context = ""):
+
+def get_cur_base_user_prompt(message_history, index_prompt=None, judge_context=""):
     user_prompt_tmp = """<|user|>\n{user_input}"""
     assistant_prompt_tmp = """<|assistant|>\n{assistant_input}"""
     history_prompt = ""
-    for i,message in enumerate(message_history):
-        if message['role'] == 'user':
-            if i==0 and index_prompt is not None:
-                history_prompt += "<|user|>\n"+index_prompt+message['content']
+    for i, message in enumerate(message_history):
+        if message["role"] == "user":
+            if i == 0 and index_prompt is not None:
+                history_prompt += "<|user|>\n" + index_prompt + message["content"]
             else:
-                history_prompt += user_prompt_tmp.format(user_input=message['content'])
-        elif message['role'] ==  'assistant':
-            history_prompt += assistant_prompt_tmp.format(assistant_input=message['content'])
-   
+                history_prompt += user_prompt_tmp.format(user_input=message["content"])
+        elif message["role"] == "assistant":
+            history_prompt += assistant_prompt_tmp.format(
+                assistant_input=message["content"]
+            )
+
     # print("修改" not in judge_context)
     # print(judge_context)
     if "修改" not in judge_context:
-        result = base_system_prompt+history_prompt+"""<|assistant|>\n"""
+        result = base_system_prompt + history_prompt + """<|assistant|>\n"""
     else:
-        result = repo_system_prompt+history_prompt+"""<|assistant|>\n"""
+        result = repo_system_prompt + history_prompt + """<|assistant|>\n"""
     print(result)
-    return result
+    return result

+ 69 - 61
repodemo/run.py

@@ -1,8 +1,12 @@
 import chainlit as cl
 from chainlit.input_widget import Slider
 from llm.api.codegeex4 import codegeex4
-from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt
-from utils.tools import unzip_file,get_project_files_with_content
+from prompts.base_prompt import (
+    judge_task_prompt,
+    get_cur_base_user_prompt,
+    web_judge_task_prompt,
+)
+from utils.tools import unzip_file, get_project_files_with_content
 from utils.bingsearch import bing_search_prompt
 
 
@@ -12,44 +16,36 @@ async def chat_profile():
         cl.ChatProfile(
             name="chat聊天",
             markdown_description="聊天demo:支持多轮对话。",
-            starters = [
+            starters=[
                 cl.Starter(
-                label="请你用python写一个快速排序。",
-                message="请你用python写一个快速排序。",
-               
+                    label="请你用python写一个快速排序。",
+                    message="请你用python写一个快速排序。",
                 ),
-
-            cl.Starter(
-                label="请你介绍一下自己。",
-                message="请你介绍一下自己。",
-               
+                cl.Starter(
+                    label="请你介绍一下自己。",
+                    message="请你介绍一下自己。",
+                ),
+                cl.Starter(
+                    label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
+                    message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
                 ),
-            cl.Starter(
-                label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
-                message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
-                
+                cl.Starter(
+                    label="我是一个python初学者,请你告诉我怎么才能学好python。",
+                    message="我是一个python初学者,请你告诉我怎么才能学好python。",
                 ),
-            cl.Starter(
-                label="我是一个python初学者,请你告诉我怎么才能学好python。",
-                message="我是一个python初学者,请你告诉我怎么才能学好python。",
-                
-                )
-            ]
-      
+            ],
         ),
         cl.ChatProfile(
             name="联网问答",
-            markdown_description="联网能力dome:支持联网回答用户问题。",
-            
+            markdown_description="联网能力demo:支持联网回答用户问题。",
         ),
         cl.ChatProfile(
             name="上传本地项目",
-            markdown_description="项目级能力dome:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
-            
-        )
+            markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
+        ),
     ]
 
-    
+
 @cl.on_chat_start
 async def start():
     settings = await cl.ChatSettings(
@@ -74,33 +70,31 @@ async def start():
     ).send()
     temperature = settings["temperature"]
     top_p = settings["top_p"]
-    cl.user_session.set('temperature',temperature)
-    cl.user_session.set('top_p',top_p)
-    cl.user_session.set(
-        "message_history",
-        []
-    )
+    cl.user_session.set("temperature", temperature)
+    cl.user_session.set("top_p", top_p)
+    cl.user_session.set("message_history", [])
     chat_profile = cl.user_session.get("chat_profile")
-    extract_dir = 'repodata'
+    extract_dir = "repodata"
     if chat_profile == "chat聊天":
         pass
-    elif chat_profile =="上传本地项目":
+    elif chat_profile == "上传本地项目":
         files = None
         while files == None:
             files = await cl.AskFileMessage(
-                content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50
+                content="请上传项目zip压缩文件!",
+                accept={"application/zip": [".zip"]},
+                max_size_mb=50,
             ).send()
 
         text_file = files[0]
-        extracted_path = unzip_file(text_file.path,extract_dir)
+        extracted_path = unzip_file(text_file.path, extract_dir)
         files_list = get_project_files_with_content(extracted_path)
-        cl.user_session.set("project_index",files_list)
-        if len(files_list)>0:
+        cl.user_session.set("project_index", files_list)
+        if len(files_list) > 0:
             await cl.Message(
                 content=f"已成功上传,您可以开始对项目进行提问!",
             ).send()
-    
-    
+
 
 @cl.on_message
 async def main(message: cl.Message):
@@ -109,43 +103,57 @@ async def main(message: cl.Message):
     message_history.append({"role": "user", "content": message.content})
     if chat_profile == "chat聊天":
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
-        
-    elif chat_profile=="联网问答":
-        judge_tmp = codegeex4(web_judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95)
-        judge_context = '\n'.join(judge_tmp)
+
+    elif chat_profile == "联网问答":
+        judge_tmp = codegeex4(
+            web_judge_task_prompt.format(user_input=message.content),
+            temperature=0.2,
+            top_p=0.95,
+        )
+        judge_context = "\n".join(judge_tmp)
         print(judge_context)
         message_history.pop()
 
-        if '是' in judge_context:
+        if "是" in judge_context:
             prompt_tmp = bing_search_prompt(message.content)
             message_history.append({"role": "user", "content": prompt_tmp})
         else:
             message_history.append({"role": "user", "content": message.content})
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
 
-    elif chat_profile =="上传本地项目" :
-        judge_tmp = codegeex4(judge_task_prompt.format(user_input=message.content),temperature=0.2,top_p = 0.95)
-        judge_context = ''
+    elif chat_profile == "上传本地项目":
+        judge_tmp = codegeex4(
+            judge_task_prompt.format(user_input=message.content),
+            temperature=0.2,
+            top_p=0.95,
+        )
+        judge_context = ""
         for part in judge_tmp:
-            judge_context+=part
-      
+            judge_context += part
+
         project_index = cl.user_session.get("project_index")
         index_prompt = ""
         index_tmp = """###PATH:{path}\n{code}\n"""
         for index in project_index:
-            index_prompt+=index_tmp.format(path=index['path'],code=index['content'])
+            index_prompt += index_tmp.format(path=index["path"], code=index["content"])
         print(judge_context)
-        prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history)
-    
-    
+        prompt_content = (
+            get_cur_base_user_prompt(
+                message_history=message_history,
+                index_prompt=index_prompt,
+                judge_context=judge_context,
+            )
+            if "正常" not in judge_context
+            else get_cur_base_user_prompt(message_history=message_history)
+        )
 
     msg = cl.Message(content="")
     await msg.send()
     temperature = cl.user_session.get("temperature")
-    top_p = cl.user_session.get('top_p')
-    
-    if len(prompt_content)/4<120000:
-        stream =  codegeex4(prompt_content,temperature=temperature,top_p = top_p)
+    top_p = cl.user_session.get("top_p")
+
+    if len(prompt_content) / 4 < 120000:
+        stream = codegeex4(prompt_content, temperature=temperature, top_p=top_p)
 
         for part in stream:
             if token := (part or " "):
@@ -154,4 +162,4 @@ async def main(message: cl.Message):
         await msg.stream_token("项目太大了,请换小一点的项目。")
 
     message_history.append({"role": "assistant", "content": msg.content})
-    await msg.update()
+    await msg.update()

+ 68 - 62
repodemo/run_local.py

@@ -1,66 +1,65 @@
 import chainlit as cl
 from chainlit.input_widget import Slider
 from llm.api.codegeex4 import codegeex4
-from prompts.base_prompt import judge_task_prompt,get_cur_base_user_prompt,web_judge_task_prompt
-from utils.tools import unzip_file,get_project_files_with_content
+from prompts.base_prompt import (
+    judge_task_prompt,
+    get_cur_base_user_prompt,
+    web_judge_task_prompt,
+)
+from utils.tools import unzip_file, get_project_files_with_content
 from utils.bingsearch import bing_search_prompt
 from llm.local.codegeex4 import CodegeexChatModel
-local_model_path = '<your_local_model_path>'
+
+local_model_path = "<your_local_model_path>"
 llm = CodegeexChatModel(local_model_path)
 
+
 class StreamProcessor:
     def __init__(self):
         self.previous_str = ""
 
     def get_new_part(self, new_str):
-        new_part = new_str[len(self.previous_str):]
+        new_part = new_str[len(self.previous_str) :]
         self.previous_str = new_str
         return new_part
 
+
 @cl.set_chat_profiles
 async def chat_profile():
     return [
         cl.ChatProfile(
             name="chat聊天",
             markdown_description="聊天demo:支持多轮对话。",
-            starters = [
+            starters=[
                 cl.Starter(
-                label="请你用python写一个快速排序。",
-                message="请你用python写一个快速排序。",
-               
+                    label="请你用python写一个快速排序。",
+                    message="请你用python写一个快速排序。",
                 ),
-
-            cl.Starter(
-                label="请你介绍一下自己。",
-                message="请你介绍一下自己。",
-               
+                cl.Starter(
+                    label="请你介绍一下自己。",
+                    message="请你介绍一下自己。",
                 ),
-            cl.Starter(
-                label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
-                message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
-                
+                cl.Starter(
+                    label="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
+                    message="用 Python 编写一个脚本来自动发送每日电子邮件报告,并指导我如何进行设置。",
+                ),
+                cl.Starter(
+                    label="我是一个python初学者,请你告诉我怎么才能学好python。",
+                    message="我是一个python初学者,请你告诉我怎么才能学好python。",
                 ),
-            cl.Starter(
-                label="我是一个python初学者,请你告诉我怎么才能学好python。",
-                message="我是一个python初学者,请你告诉我怎么才能学好python。",
-                
-                )
-            ]
-      
+            ],
         ),
         cl.ChatProfile(
             name="联网问答",
-            markdown_description="联网能力dome:支持联网回答用户问题。",
-            
+            markdown_description="联网能力demo:支持联网回答用户问题。",
         ),
         cl.ChatProfile(
             name="上传本地项目",
-            markdown_description="项目级能力dome:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
-            
-        )
+            markdown_description="项目级能力demo:支持上传本地zip压缩包项目,可以进行项目问答和对项目进行修改。",
+        ),
     ]
 
-    
+
 @cl.on_chat_start
 async def start():
     settings = await cl.ChatSettings(
@@ -85,33 +84,31 @@ async def start():
     ).send()
     temperature = settings["temperature"]
     top_p = settings["top_p"]
-    cl.user_session.set('temperature',temperature)
-    cl.user_session.set('top_p',top_p)
-    cl.user_session.set(
-        "message_history",
-        []
-    )
+    cl.user_session.set("temperature", temperature)
+    cl.user_session.set("top_p", top_p)
+    cl.user_session.set("message_history", [])
     chat_profile = cl.user_session.get("chat_profile")
-    extract_dir = 'repodata'
+    extract_dir = "repodata"
     if chat_profile == "chat聊天":
         pass
-    elif chat_profile =="上传本地项目":
+    elif chat_profile == "上传本地项目":
         files = None
         while files == None:
             files = await cl.AskFileMessage(
-                content="请上传项目zip压缩文件!", accept={"application/zip": [".zip"]},max_size_mb=50
+                content="请上传项目zip压缩文件!",
+                accept={"application/zip": [".zip"]},
+                max_size_mb=50,
             ).send()
 
         text_file = files[0]
-        extracted_path = unzip_file(text_file.path,extract_dir)
+        extracted_path = unzip_file(text_file.path, extract_dir)
         files_list = get_project_files_with_content(extracted_path)
-        cl.user_session.set("project_index",files_list)
-        if len(files_list)>0:
+        cl.user_session.set("project_index", files_list)
+        if len(files_list) > 0:
             await cl.Message(
                 content=f"已成功上传,您可以开始对项目进行提问!",
             ).send()
-    
-    
+
 
 @cl.on_message
 async def main(message: cl.Message):
@@ -120,50 +117,59 @@ async def main(message: cl.Message):
     message_history.append({"role": "user", "content": message.content})
     if chat_profile == "chat聊天":
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
-        
-    elif chat_profile=="联网问答":
-        judge_context = llm.chat(web_judge_task_prompt.format(user_input=message.content),temperature=0.2)
+
+    elif chat_profile == "联网问答":
+        judge_context = llm.chat(
+            web_judge_task_prompt.format(user_input=message.content), temperature=0.2
+        )
         print(judge_context)
         message_history.pop()
 
-        if '是' in judge_context:
+        if "是" in judge_context:
             prompt_tmp = bing_search_prompt(message.content)
             message_history.append({"role": "user", "content": prompt_tmp})
         else:
             message_history.append({"role": "user", "content": message.content})
         prompt_content = get_cur_base_user_prompt(message_history=message_history)
 
-    elif chat_profile =="上传本地项目" :
-        judge_context = llm.chat(judge_task_prompt.format(user_input=message.content),temperature=0.2)
-        
-      
+    elif chat_profile == "上传本地项目":
+        judge_context = llm.chat(
+            judge_task_prompt.format(user_input=message.content), temperature=0.2
+        )
+
         project_index = cl.user_session.get("project_index")
         index_prompt = ""
         index_tmp = """###PATH:{path}\n{code}\n"""
         for index in project_index:
-            index_prompt+=index_tmp.format(path=index['path'],code=index['content'])
+            index_prompt += index_tmp.format(path=index["path"], code=index["content"])
         print(judge_context)
-        prompt_content = get_cur_base_user_prompt(message_history=message_history,index_prompt=index_prompt,judge_context=judge_context) if '正常' not in judge_context else get_cur_base_user_prompt(message_history=message_history)
-    
-    
+        prompt_content = (
+            get_cur_base_user_prompt(
+                message_history=message_history,
+                index_prompt=index_prompt,
+                judge_context=judge_context,
+            )
+            if "正常" not in judge_context
+            else get_cur_base_user_prompt(message_history=message_history)
+        )
 
     msg = cl.Message(content="")
     await msg.send()
     temperature = cl.user_session.get("temperature")
-    top_p = cl.user_session.get('top_p')
-    
-    if len(prompt_content)/4<120000:
-        stream =  llm.stream_chat(prompt_content,temperature=temperature,top_p = top_p)
+    top_p = cl.user_session.get("top_p")
+
+    if len(prompt_content) / 4 < 120000:
+        stream = llm.stream_chat(prompt_content, temperature=temperature, top_p=top_p)
         stream_processor = StreamProcessor()
         for part in stream:
             if isinstance(part, str):
                 text = stream_processor.get_new_part(part)
             elif isinstance(part, dict):
-                text = stream_processor.get_new_part(part['name']+part['content'])
+                text = stream_processor.get_new_part(part["name"] + part["content"])
             if token := (text or " "):
                 await msg.stream_token(token)
     else:
         await msg.stream_token("项目太大了,请换小一点的项目。")
 
     message_history.append({"role": "assistant", "content": msg.content})
-    await msg.update()
+    await msg.update()

+ 16 - 10
repodemo/utils/bingsearch.py

@@ -2,7 +2,9 @@ import requests
 from bs4 import BeautifulSoup as BS4
 import requests
 
-BING_API_KEY = '<your_bing_api_key>'
+BING_API_KEY = "<your_bing_api_key>"
+
+
 def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
     """
     Search with bing and return the contexts.
@@ -13,9 +15,9 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
         headers={"Ocp-Apim-Subscription-Key": BING_API_KEY},
         params={
             "q": query,
-            "responseFilter": ['webpages'],
-            "freshness": 'month',
-            "mkt": 'zh-CN'
+            "responseFilter": ["webpages"],
+            "freshness": "month",
+            "mkt": "zh-CN",
         },
         timeout=search_timeout,
     )
@@ -23,25 +25,29 @@ def search_with_bing(query: str, search_timeout=30, top_k=6) -> list[dict]:
         json_content = response.json()
         # print(json_content)
         contexts = json_content["webPages"]["value"][:top_k]
-        #logger.info("Web搜索完成")
+        # logger.info("Web搜索完成")
         return contexts
     except Exception as e:
-        #logger.error(f"搜索失败,错误原因: {e}")
+        # logger.error(f"搜索失败,错误原因: {e}")
         print(f"搜索失败,错误原因: {e}")
         return []
 
+
 def fetch_url(url):
     response = requests.get(url)
-    #use beautifulsoup4 to parse html
-    soup = BS4(response.text, 'html.parser')
+    # use beautifulsoup4 to parse html
+    soup = BS4(response.text, "html.parser")
     plain_text = soup.get_text()
     return plain_text
 
+
 def bing_search_prompt(input):
     contents = search_with_bing(input, search_timeout=5, top_k=6)
     citations = "\n\n".join(
-        [f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```" for i, item in enumerate(contents)]
+        [
+            f"[[citation:{i + 1}]]\n```markdown\n{item['snippet']}\n```"
+            for i, item in enumerate(contents)
+        ]
     )
     prompt = f"[引用]\n{citations}\n问:{input}\n"
     return prompt
-

+ 81 - 79
repodemo/utils/tools.py

@@ -2,14 +2,15 @@ import zipfile
 import os
 import json
 
+
 def unzip_file(zip_path, extract_dir):
     """
     解压zip文件到指定目录,并在指定目录下创建一个新的目录存放解压后的文件
-    
+
     参数:
     zip_path (str): zip压缩包的地址
     extract_dir (str): 指定解压的目录
-    
+
     返回:
     str: 解压后的路径
     """
@@ -19,11 +20,11 @@ def unzip_file(zip_path, extract_dir):
     base_name = os.path.basename(zip_path)
     dir_name = os.path.splitext(base_name)[0]
     new_extract_dir = os.path.join(extract_dir, dir_name)
-    
+
     if not os.path.exists(new_extract_dir):
         os.makedirs(new_extract_dir)
 
-    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
         zip_ref.extractall(new_extract_dir)
 
     return new_extract_dir
@@ -32,15 +33,15 @@ def unzip_file(zip_path, extract_dir):
 def get_project_files_with_content(project_dir):
     """
     获取项目目录下所有文件的相对路径和内容
-    
+
     参数:
     project_dir (str): 项目目录地址
-    
+
     返回:
     list: 包含字典的列表,每个字典包含文件的相对路径和内容
     """
     files_list = []
-    
+
     for root, dirs, files in os.walk(project_dir):
         for file in files:
             if filter_data(file):
@@ -48,83 +49,85 @@ def get_project_files_with_content(project_dir):
                 relative_path = os.path.relpath(file_path, project_dir)
                 if "__MACOSX" in relative_path:
                     continue
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+                with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                     content = f.read()
-                files_list.append({'path': relative_path, 'content': content})
+                files_list.append({"path": relative_path, "content": content})
             else:
                 continue
-    
+
     return files_list
 
+
 def filter_data(obj):
     LANGUAGE_TAG = {
-    "c++"          : "// C++",
-    "cpp"          : "// C++",
-    "c"            : "// C",
-    "c#"           : "// C#",
-    "c-sharp"      : "// C#",
-    "css"          : "/* CSS */",
-    "cuda"         : "// Cuda",
-    "fortran"      : "! Fortran",
-    "go"           : "// Go",
-    "html"         : "<!-- HTML -->",
-    "java"         : "// Java",
-    "js"           : "// JavaScript",
-    "javascript"   : "// JavaScript",
-    "kotlin"       : "// Kotlin",
-    "lean"         : "-- Lean",
-    "lua"          : "-- Lua",
-    "objectivec"  : "// Objective-C",
-    "objective-c"  : "// Objective-C",
-    "objective-c++": "// Objective-C++",
-    "pascal"       : "// Pascal",
-    "php"          : "// PHP",
-    "python"       : "# Python",
-    "r"            : "# R",
-    "rust"         : "// Rust",
-    "ruby"         : "# Ruby",
-    "scala"        : "// Scala",
-    "shell"        : "# Shell",
-    "sql"          : "-- SQL",
-    "tex"          : f"% TeX",
-    "typescript"   : "// TypeScript",
-    "vue"          : "<!-- Vue -->",
-
-    "assembly"     : "; Assembly",
-    "dart"         : "// Dart",
-    "perl"         : "# Perl",
-    "prolog"       : f"% Prolog",
-    "swift"        : "// swift",
-    "lisp"         : "; Lisp",
-    "vb"           : "' Visual Basic",
-    "visual basic" : "' Visual Basic",
-    "matlab"       : f"% Matlab",
-    "delphi"       : "{ Delphi }",
-    "scheme"       : "; Scheme",
-    "basic"        : "' Basic",
-    "assembly"     : "; Assembly",
-    "groovy"       : "// Groovy",
-    "abap"         : "* Abap",
-    "gdscript"     : "# GDScript",
-    "haskell"      : "-- Haskell",
-    "julia"        : "# Julia",
-    "elixir"       : "# Elixir",
-    "excel"        : "' Excel",
-    "clojure"      : "; Clojure",
-    "actionscript" : "// ActionScript",
-    "solidity"     : "// Solidity",
-    "powershell"   : "# PowerShell",
-    "erlang"       : f"% Erlang",
-    "cobol"        : "// Cobol",
-    "batchfile"  : ":: Batch file",
-    "makefile"     : "# Makefile",
-    "dockerfile"   : "# Dockerfile",
-    "markdown"     : "<!-- Markdown -->",
-    "cmake"        : "# CMake",
-    "dockerfile"   : "# Dockerfile",
+        "c++": "// C++",
+        "cpp": "// C++",
+        "c": "// C",
+        "c#": "// C#",
+        "c-sharp": "// C#",
+        "css": "/* CSS */",
+        "cuda": "// Cuda",
+        "fortran": "! Fortran",
+        "go": "// Go",
+        "html": "<!-- HTML -->",
+        "java": "// Java",
+        "js": "// JavaScript",
+        "javascript": "// JavaScript",
+        "kotlin": "// Kotlin",
+        "lean": "-- Lean",
+        "lua": "-- Lua",
+        "objectivec": "// Objective-C",
+        "objective-c": "// Objective-C",
+        "objective-c++": "// Objective-C++",
+        "pascal": "// Pascal",
+        "php": "// PHP",
+        "python": "# Python",
+        "r": "# R",
+        "rust": "// Rust",
+        "ruby": "# Ruby",
+        "scala": "// Scala",
+        "shell": "# Shell",
+        "sql": "-- SQL",
+        "tex": f"% TeX",
+        "typescript": "// TypeScript",
+        "vue": "<!-- Vue -->",
+        "assembly": "; Assembly",
+        "dart": "// Dart",
+        "perl": "# Perl",
+        "prolog": f"% Prolog",
+        "swift": "// swift",
+        "lisp": "; Lisp",
+        "vb": "' Visual Basic",
+        "visual basic": "' Visual Basic",
+        "matlab": f"% Matlab",
+        "delphi": "{ Delphi }",
+        "scheme": "; Scheme",
+        "basic": "' Basic",
+        "assembly": "; Assembly",
+        "groovy": "// Groovy",
+        "abap": "* Abap",
+        "gdscript": "# GDScript",
+        "haskell": "-- Haskell",
+        "julia": "# Julia",
+        "elixir": "# Elixir",
+        "excel": "' Excel",
+        "clojure": "; Clojure",
+        "actionscript": "// ActionScript",
+        "solidity": "// Solidity",
+        "powershell": "# PowerShell",
+        "erlang": f"% Erlang",
+        "cobol": "// Cobol",
+        "batchfile": ":: Batch file",
+        "makefile": "# Makefile",
+        "dockerfile": "# Dockerfile",
+        "markdown": "<!-- Markdown -->",
+        "cmake": "# CMake",
+        "dockerfile": "# Dockerfile",
     }
 
-    programming_languages_to_file_extensions = json.load(open('utils/programming-languages-to-file-extensions.json'))
+    programming_languages_to_file_extensions = json.load(
+        open("utils/programming-languages-to-file-extensions.json")
+    )
     need2del = []
     for key in programming_languages_to_file_extensions.keys():
         if key.lower() not in LANGUAGE_TAG:
@@ -140,15 +143,14 @@ def filter_data(obj):
             ext_to_programming_languages[item] = key
             want_languages.append(item)
 
-    ext = '.'+obj.split('.')[-1]
-    with open('utils/keep.txt', 'r') as f:
+    ext = "." + obj.split(".")[-1]
+    with open("utils/keep.txt", "r") as f:
         keep_files = f.readlines()
         keep_files = [l.strip() for l in keep_files]
-    #print(ext)
+    # print(ext)
     if ext not in want_languages:
         if obj in keep_files:
             return True
         return False
     else:
         return True
-