tools.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
import functools
import json
import os
import re
import urllib.parse
import zipfile

import git
  7. def is_valid_json(json_string):
  8. try:
  9. match = re.search(r'\{.*\}', json_string, re.DOTALL)
  10. if match:
  11. dict_str = match.group()
  12. json.loads(dict_str)
  13. else:
  14. json.loads(json_string)
  15. return True
  16. except ValueError:
  17. return False
  18. def clone_repo(repo_url, clone_to):
  19. """
  20. 克隆一个GitHub仓库。
  21. 参数:
  22. repo_url (str): 原始仓库的URL。
  23. clone_to (str): 克隆到的本地目录。
  24. 返回:
  25. str: 成功时返回克隆到的本地目录(包含子目录),不成功时返回空字符串。
  26. """
  27. try:
  28. if not os.path.exists(clone_to):
  29. os.makedirs(clone_to)
  30. # 从URL中提取仓库名称
  31. repo_name = urllib.parse.urlparse(repo_url).path.split('/')[-1]
  32. # 在clone_to目录下创建新的目录
  33. cloned_path = os.path.join(clone_to, repo_name)
  34. if os.path.exists(cloned_path):
  35. return cloned_path
  36. # 克隆仓库
  37. repo = git.Repo.clone_from(repo_url, cloned_path)
  38. print(f"Repository cloned to {cloned_path}")
  39. return cloned_path
  40. except Exception as e:
  41. print(f"Failed to clone repository: {e}")
  42. return None
  43. def unzip_file(zip_path, extract_dir):
  44. """
  45. 解压zip文件到指定目录,并在指定目录下创建一个新的目录存放解压后的文件
  46. 参数:
  47. zip_path (str): zip压缩包的地址
  48. extract_dir (str): 指定解压的目录
  49. 返回:
  50. str: 解压后的路径
  51. """
  52. if not os.path.exists(extract_dir):
  53. os.makedirs(extract_dir)
  54. base_name = os.path.basename(zip_path)
  55. dir_name = os.path.splitext(base_name)[0]
  56. new_extract_dir = os.path.join(extract_dir, dir_name)
  57. if not os.path.exists(new_extract_dir):
  58. os.makedirs(new_extract_dir)
  59. with zipfile.ZipFile(zip_path, "r") as zip_ref:
  60. zip_ref.extractall(new_extract_dir)
  61. return new_extract_dir
  62. def get_project_files_with_content(project_dir):
  63. """
  64. 获取项目目录下所有文件的相对路径和内容
  65. 参数:
  66. project_dir (str): 项目目录地址
  67. 返回:
  68. list: 包含字典的列表,每个字典包含文件的相对路径和内容
  69. """
  70. files_list = []
  71. for root, dirs, files in os.walk(project_dir):
  72. for file in files:
  73. if filter_data(file):
  74. file_path = os.path.join(root, file)
  75. relative_path = os.path.relpath(file_path, project_dir)
  76. if "__MACOSX" in relative_path:
  77. continue
  78. with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
  79. content = f.read()
  80. files_list.append({"path": relative_path, "content": content})
  81. else:
  82. continue
  83. return files_list
  84. def filter_data(obj):
  85. LANGUAGE_TAG = {
  86. "c++": "// C++",
  87. "cpp": "// C++",
  88. "c": "// C",
  89. "c#": "// C#",
  90. "c-sharp": "// C#",
  91. "css": "/* CSS */",
  92. "cuda": "// Cuda",
  93. "fortran": "! Fortran",
  94. "go": "// Go",
  95. "html": "<!-- HTML -->",
  96. "java": "// Java",
  97. "js": "// JavaScript",
  98. "javascript": "// JavaScript",
  99. "kotlin": "// Kotlin",
  100. "lean": "-- Lean",
  101. "lua": "-- Lua",
  102. "objectivec": "// Objective-C",
  103. "objective-c": "// Objective-C",
  104. "objective-c++": "// Objective-C++",
  105. "pascal": "// Pascal",
  106. "php": "// PHP",
  107. "python": "# Python",
  108. "r": "# R",
  109. "rust": "// Rust",
  110. "ruby": "# Ruby",
  111. "scala": "// Scala",
  112. "shell": "# Shell",
  113. "sql": "-- SQL",
  114. "tex": f"% TeX",
  115. "typescript": "// TypeScript",
  116. "vue": "<!-- Vue -->",
  117. "assembly": "; Assembly",
  118. "dart": "// Dart",
  119. "perl": "# Perl",
  120. "prolog": f"% Prolog",
  121. "swift": "// swift",
  122. "lisp": "; Lisp",
  123. "vb": "' Visual Basic",
  124. "visual basic": "' Visual Basic",
  125. "matlab": f"% Matlab",
  126. "delphi": "{ Delphi }",
  127. "scheme": "; Scheme",
  128. "basic": "' Basic",
  129. "groovy": "// Groovy",
  130. "abap": "* Abap",
  131. "gdscript": "# GDScript",
  132. "haskell": "-- Haskell",
  133. "julia": "# Julia",
  134. "elixir": "# Elixir",
  135. "excel": "' Excel",
  136. "clojure": "; Clojure",
  137. "actionscript": "// ActionScript",
  138. "solidity": "// Solidity",
  139. "powershell": "# PowerShell",
  140. "erlang": f"% Erlang",
  141. "cobol": "// Cobol",
  142. "batchfile": ":: Batch file",
  143. "makefile": "# Makefile",
  144. "dockerfile": "# Dockerfile",
  145. "markdown": "<!-- Markdown -->",
  146. "cmake": "# CMake",
  147. }
  148. programming_languages_to_file_extensions = json.load(
  149. open("utils/programming-languages-to-file-extensions.json")
  150. )
  151. need2del = []
  152. for key in programming_languages_to_file_extensions.keys():
  153. if key.lower() not in LANGUAGE_TAG:
  154. need2del.append(key)
  155. for key in need2del:
  156. del programming_languages_to_file_extensions[key]
  157. ext_to_programming_languages = {}
  158. want_languages = []
  159. for key in programming_languages_to_file_extensions:
  160. for item in programming_languages_to_file_extensions[key]:
  161. ext_to_programming_languages[item] = key
  162. want_languages.append(item)
  163. ext = "." + obj.split(".")[-1]
  164. with open("utils/keep.txt", "r") as f:
  165. keep_files = f.readlines()
  166. keep_files = [l.strip() for l in keep_files]
  167. # print(ext)
  168. if ext not in want_languages:
  169. if obj in keep_files:
  170. return True
  171. return False
  172. else:
  173. return True