data.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. from pathlib import Path
  3. from llama_index.core.node_parser import CodeSplitter
  4. from llama_index.core.schema import BaseNode
  5. from llama_index.readers.file import FlatReader
  6. Languages = {
  7. 'c': "c",
  8. 'cpp': "cpp",
  9. 'go': "go",
  10. 'java': "java",
  11. 'js': "javascript",
  12. 'md': "markdown",
  13. 'py': "python",
  14. 'ts': "typescript",
  15. }
  16. def traverse(repo_path: str) -> list[str]:
  17. """
  18. Traverse the directory, fetch all files
  19. - skip hidden directories
  20. - only keep the supported files
  21. :param repo_path: path to this repo
  22. """
  23. def helper(root):
  24. for entry in os.scandir(root):
  25. if entry.name.startswith('.'):
  26. continue
  27. if entry.is_file():
  28. ext = entry.name.split('.')[-1].lower()
  29. if ext not in Languages.keys():
  30. continue
  31. file_paths.append(entry.path)
  32. elif entry.is_dir():
  33. helper(entry.path)
  34. file_paths = []
  35. helper(repo_path)
  36. return sorted(file_paths)
  37. def split_into_chunks(file_path, lines_per_chunk, lines_overlap, max_chars) -> list[BaseNode]:
  38. """
  39. Split file into chunks
  40. :param file_path: path to the file
  41. :param lines_per_chunk: lines for each chunk
  42. :param lines_overlap: overlap lines between 2 chunks
  43. :param max_chars: max characters for each chunk
  44. """
  45. ext = file_path.split('.')[-1].lower()
  46. lang = Languages.get(ext, None)
  47. if not lang:
  48. return []
  49. try:
  50. documents = FlatReader().load_data(Path(file_path))
  51. splitter = CodeSplitter(
  52. language=lang,
  53. chunk_lines=lines_per_chunk,
  54. chunk_lines_overlap=lines_overlap,
  55. max_chars=max_chars,
  56. )
  57. return splitter.get_nodes_from_documents(documents)
  58. except Exception as e:
  59. print(f'`{file_path}`切分失败: {e}')
  60. return []