tools.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import json
  2. import os
  3. import zipfile
  4. import git
  5. import urllib.parse
  6. import re
  7. import mermaid as md
  8. from mermaid.graph import Graph
  def get_mermaid_png(mermaid_str:str, png_file_path:str):
      """Render a Mermaid diagram definition to a PNG file.

      A fixed Mermaid ``%%{init: ...}%%`` directive is prepended to the diagram
      text. It selects the 'base' theme with primary colour #ffcc00, a white
      edge-label background, font size 18, a transparent background and black
      lines. The combined text is handed to the ``mermaid`` package for
      rendering.

      Args:
          mermaid_str: Mermaid diagram source, without an init directive.
          png_file_path: Destination passed to ``Mermaid.to_png``.
      """
      init_directive = "%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#ffcc00', 'edgeLabelBackground':'#ffffff', 'fontSize': 18, 'background': 'transparent', 'lineColor': '#000000' }}}%%\n"
      diagram_source = init_directive + mermaid_str
      renderer = md.Mermaid(Graph('example-flowchart', diagram_source))
      renderer.to_png(png_file_path)
  def is_valid_json(json_string):
      """Report whether *json_string* contains parseable JSON.

      If the text contains a ``{ ... }`` span, only that span is parsed. The
      span runs from the first ``{`` to the last ``}``, across newlines. This
      tolerates prose or code fences wrapped around a JSON object. Because the
      match is greedy, two separate objects in one string are parsed as a
      single (invalid) span. If there is no brace span, the whole string is
      parsed, so top-level arrays and scalars are accepted too.

      Args:
          json_string: Text to inspect.

      Returns:
          bool: True if parsing succeeds, otherwise False. Non-string input
          (e.g. None) also returns False instead of raising TypeError.
      """
      try:
          match = re.search(r'\{.*\}', json_string, re.DOTALL)
          json.loads(match.group() if match else json_string)
      except (ValueError, TypeError):
          # ValueError covers json.JSONDecodeError. TypeError comes from
          # re.search when the input is not a str (e.g. None).
          return False
      return True
  def clone_repo(repo_url, clone_to):
      """Clone a Git repository into a subdirectory of *clone_to*.

      The subdirectory is named after the last path component of the URL,
      e.g. ``https://github.com/user/repo.git`` -> ``<clone_to>/repo.git``.
      A trailing slash on the URL is ignored. If that subdirectory already
      exists, its path is returned as-is, with no clone, update, or check
      that it is a repository.

      Args:
          repo_url (str): URL of the repository to clone.
          clone_to (str): Local parent directory; created if missing.

      Returns:
          str | None: Path of the cloned (or already present) repository.
          Returns None if no name can be derived from the URL or if cloning
          fails; the reason is printed.
      """
      try:
          os.makedirs(clone_to, exist_ok=True)
          # Strip trailing "/" first: otherwise ".../repo/" yields an empty
          # name, cloned_path == clone_to (which exists), and the clone is
          # silently skipped.
          repo_name = urllib.parse.urlparse(repo_url).path.rstrip('/').split('/')[-1]
          if not repo_name:
              print(f"Failed to clone repository: cannot derive a name from {repo_url!r}")
              return None
          cloned_path = os.path.join(clone_to, repo_name)
          if os.path.exists(cloned_path):
              return cloned_path
          git.Repo.clone_from(repo_url, cloned_path)
          print(f"Repository cloned to {cloned_path}")
          return cloned_path
      except Exception as e:
          # Best-effort boundary: callers expect None rather than an exception.
          print(f"Failed to clone repository: {e}")
          return None
  def unzip_file(zip_path, extract_dir):
      """Extract a zip archive into a new subdirectory of *extract_dir*.

      The subdirectory is named after the archive file without its extension
      (``bundle.zip`` -> ``<extract_dir>/bundle``). It and any missing parent
      directories are created. Extracting again into an existing subdirectory
      is allowed; existing files are overwritten by ``extractall``.

      Args:
          zip_path (str): Path to the zip archive.
          extract_dir (str): Parent directory for the extraction.

      Returns:
          str: Path of the directory the archive was extracted into.
      """
      dir_name = os.path.splitext(os.path.basename(zip_path))[0]
      new_extract_dir = os.path.join(extract_dir, dir_name)
      # One makedirs call creates extract_dir too. exist_ok avoids the
      # check-then-create race of a separate os.path.exists test.
      os.makedirs(new_extract_dir, exist_ok=True)
      # NOTE(review): if archives may come from untrusted users, inspect members
      # first; zipfile only *attempts* to block absolute/".." member paths.
      with zipfile.ZipFile(zip_path, "r") as zip_ref:
          zip_ref.extractall(new_extract_dir)
      return new_extract_dir
  def get_project_files_with_content(project_dir):
      """Collect the relative path and text of every accepted file in a project.

      Walks *project_dir* recursively. A file is kept when ``filter_data``
      accepts its name and its path relative to *project_dir* does not contain
      ``__MACOSX``. Files are read as UTF-8; undecodable bytes are dropped.

      Args:
          project_dir (str): Root directory of the project.

      Returns:
          list[dict]: One ``{"path": <relative path>, "content": <text>}`` dict
          per kept file, in ``os.walk`` order.
      """
      collected = []
      for dirpath, _subdirs, filenames in os.walk(project_dir):
          for name in filenames:
              if not filter_data(name):
                  continue
              full_path = os.path.join(dirpath, name)
              rel_path = os.path.relpath(full_path, project_dir)
              # Skip macOS resource-fork folders that appear in extracted zips.
              if "__MACOSX" in rel_path:
                  continue
              with open(full_path, "r", encoding="utf-8", errors="ignore") as handle:
                  text = handle.read()
              collected.append({"path": rel_path, "content": text})
      return collected
  def extract_code_text(text, language):
      """Return the body of the first ```<language> fenced block in *text*.

      The body is everything between the opening marker and the next ``` after
      it, with surrounding whitespace stripped. The opening marker is matched
      as a plain prefix, so language "py" would also match a "```python" fence.

      Args:
          text (str): Text that may contain a Markdown code fence.
          language (str): Language tag expected right after the opening ```.

      Returns:
          str: The stripped block body, or "" if the opening marker or its
          closing fence is missing.
      """
      _, opener, after_opener = text.partition(f"```{language}")
      if not opener:
          return ""
      body, closer, _ = after_opener.partition("```")
      if not closer:
          return ""
      return body.strip()
  def get_directory_structure(directory_path):
      """Render the file tree under *directory_path* as an indented string.

      Each directory is listed as ``name/`` and each file as ``name``. Entries
      are indented four spaces per nesting level below *directory_path*.
      ``.git`` directories and everything beneath them are omitted. Entries are
      sorted by name so the output is deterministic.

      :param directory_path: str, root directory to describe
      :return: str, one entry per line joined with newlines ("" if the
               directory does not exist)
      """
      lines = []
      for root, dirs, files in os.walk(directory_path):
          # Prune in place so os.walk never descends into .git. Sort because
          # os.walk otherwise follows arbitrary directory-listing order.
          dirs[:] = sorted(d for d in dirs if d != '.git')
          # relpath gives the correct depth even when directory_path ends with
          # a separator, unlike stripping the prefix with str.replace.
          rel = os.path.relpath(root, directory_path)
          level = 0 if rel == os.curdir else rel.count(os.sep) + 1
          # normpath so a trailing separator does not make basename "".
          lines.append(f"{' ' * 4 * level}{os.path.basename(os.path.normpath(root))}/")
          sub_indent = ' ' * 4 * (level + 1)
          lines.extend(f"{sub_indent}{name}" for name in sorted(files))
      return "\n".join(lines)
  # Languages whose files filter_data keeps. Only the keys (lower-case language
  # names) are consulted; the values are per-language comment prefixes.
  _LANGUAGE_TAG = {
      "c++": "// C++",
      "cpp": "// C++",
      "c": "// C",
      "c#": "// C#",
      "c-sharp": "// C#",
      "css": "/* CSS */",
      "cuda": "// Cuda",
      "fortran": "! Fortran",
      "go": "// Go",
      "html": "<!-- HTML -->",
      "java": "// Java",
      "js": "// JavaScript",
      "javascript": "// JavaScript",
      "kotlin": "// Kotlin",
      "lean": "-- Lean",
      "lua": "-- Lua",
      "objectivec": "// Objective-C",
      "objective-c": "// Objective-C",
      "objective-c++": "// Objective-C++",
      "pascal": "// Pascal",
      "php": "// PHP",
      "python": "# Python",
      "r": "# R",
      "rust": "// Rust",
      "ruby": "# Ruby",
      "scala": "// Scala",
      "shell": "# Shell",
      "sql": "-- SQL",
      "tex": "% TeX",
      "typescript": "// TypeScript",
      "vue": "<!-- Vue -->",
      "assembly": "; Assembly",
      "dart": "// Dart",
      "perl": "# Perl",
      "prolog": "% Prolog",
      "swift": "// swift",
      "lisp": "; Lisp",
      "vb": "' Visual Basic",
      "visual basic": "' Visual Basic",
      "matlab": "% Matlab",
      "delphi": "{ Delphi }",
      "scheme": "; Scheme",
      "basic": "' Basic",
      "groovy": "// Groovy",
      "abap": "* Abap",
      "gdscript": "# GDScript",
      "haskell": "-- Haskell",
      "julia": "# Julia",
      "elixir": "# Elixir",
      "excel": "' Excel",
      "clojure": "; Clojure",
      "actionscript": "// ActionScript",
      "solidity": "// Solidity",
      "powershell": "# PowerShell",
      "erlang": "% Erlang",
      "cobol": "// Cobol",
      "batchfile": ":: Batch file",
      "makefile": "# Makefile",
      "dockerfile": "# Dockerfile",
      "markdown": "<!-- Markdown -->",
      "cmake": "# CMake",
  }

  # Parsed lookup tables for filter_data. Maps
  # (language_map_path, keep_file_path) to (mtimes, (wanted_exts, keep_names)).
  _FILTER_TABLES_CACHE = {}


  def _load_filter_tables(language_map_path, keep_file_path):
      """Return (wanted_exts, keep_names), re-parsing only when a file's mtime changes."""
      paths = (language_map_path, keep_file_path)
      # Raises FileNotFoundError when a file is missing, as the original open() did.
      stamp = tuple(os.path.getmtime(p) for p in paths)
      cached = _FILTER_TABLES_CACHE.get(paths)
      if cached is not None and cached[0] == stamp:
          return cached[1]
      with open(language_map_path, encoding="utf-8") as f:
          language_to_exts = json.load(f)
      wanted_exts = frozenset(
          ext
          for language, exts in language_to_exts.items()
          if language.lower() in _LANGUAGE_TAG
          for ext in exts
      )
      with open(keep_file_path, "r") as f:
          keep_names = frozenset(line.strip() for line in f)
      tables = (wanted_exts, keep_names)
      _FILTER_TABLES_CACHE[paths] = (stamp, tables)
      return tables


  def filter_data(obj, *, language_map_path="utils/programming-languages-to-file-extensions.json", keep_file_path="utils/keep.txt"):
      """Decide whether a file should be included, judging by its name.

      A file is included when either condition holds:

      * Its extension is listed in the language-to-extensions JSON map under a
        language whose lower-cased name is a key of ``_LANGUAGE_TAG``. The
        extension is "." plus the text after the last "."; a name without a
        dot yields "." + the whole name.
      * The exact name appears as a (stripped) line of the keep list.

      Both files are parsed once and cached; they are re-read only when either
      file's modification time changes.

      Args:
          obj (str): File name to test.
          language_map_path (str): JSON object mapping language name to a list
              of extensions. Relative paths resolve against the current working
              directory.
          keep_file_path (str): Text file of names to always include, one per
              line. Relative to the current working directory when relative.

      Returns:
          bool: True if the file should be included.

      Raises:
          FileNotFoundError: If either configuration file is missing.
          json.JSONDecodeError: If the language map is not valid JSON.
      """
      wanted_exts, keep_names = _load_filter_tables(language_map_path, keep_file_path)
      ext = "." + obj.split(".")[-1]
      return ext in wanted_exts or obj in keep_names
  206. return False
  207. else:
  208. return True