瀏覽代碼

Fix tree-sitter (#4857)

Chris Estreich 6 月之前
父節點
當前提交
72cb248ef8
共有 2 個文件被更改,包括 135 次插入197 次删除
  1. 43 123
      src/services/tree-sitter/__tests__/languageParser.spec.ts
  2. 92 74
      src/services/tree-sitter/languageParser.ts

+ 43 - 123
src/services/tree-sitter/__tests__/languageParser.spec.ts

@@ -1,136 +1,56 @@
 // npx vitest services/tree-sitter/__tests__/languageParser.spec.ts
 
+import * as path from "path"
 import { loadRequiredLanguageParsers } from "../languageParser"
 
-vi.mock("web-tree-sitter", () => {
-	const mockParserInit = vi.fn().mockResolvedValue(undefined)
-	const mockLanguageLoad = vi.fn().mockResolvedValue({
-		query: vi.fn().mockReturnValue({ id: "mock-query" }),
-	})
-	const mockSetLanguage = vi.fn()
-
-	// Create a constructor function that also has static methods
-	function MockParser() {
-		return {
-			setLanguage: mockSetLanguage,
-		}
-	}
-	MockParser.init = mockParserInit
-
-	return {
-		Parser: MockParser,
-		Language: {
-			load: mockLanguageLoad,
-		},
-		// Export the mocks so tests can access them
-		__mocks: {
-			mockParserInit,
-			mockLanguageLoad,
-			mockSetLanguage,
-		},
-	}
-})
-
-// Import the mocked module to get access to the mock functions
-const { __mocks } = (await import("web-tree-sitter")) as any
-const { mockParserInit, mockLanguageLoad, mockSetLanguage } = __mocks
+// Path to the directory containing the WASM files.
+const WASM_DIR = path.join(__dirname, "../../../node_modules/tree-sitter-wasms/out")
 
-describe("Language Parser", () => {
-	beforeEach(() => {
-		vi.clearAllMocks()
+describe("loadRequiredLanguageParsers", () => {
+	it("should load Python parser for .py files", async () => {
+		const files = ["test.py"]
+		const parsers = await loadRequiredLanguageParsers(files, WASM_DIR)
+		expect(parsers.py).toBeDefined()
 	})
 
-	describe("loadRequiredLanguageParsers", () => {
-		it("should initialize parser only once", async () => {
-			const files = ["test.js", "test2.js"]
-			await loadRequiredLanguageParsers(files)
-			await loadRequiredLanguageParsers(files)
-
-			expect(mockParserInit).toHaveBeenCalledTimes(1)
-		})
-
-		it("should load JavaScript parser for .js and .jsx files", async () => {
-			const files = ["test.js", "test.jsx"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-javascript.wasm"))
-			expect(parsers.js).toBeDefined()
-			expect(parsers.jsx).toBeDefined()
-			expect(parsers.js.query).toBeDefined()
-			expect(parsers.jsx.query).toBeDefined()
-		})
-
-		it("should load TypeScript parser for .ts and .tsx files", async () => {
-			const files = ["test.ts", "test.tsx"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-typescript.wasm"))
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-tsx.wasm"))
-			expect(parsers.ts).toBeDefined()
-			expect(parsers.tsx).toBeDefined()
-		})
-
-		it("should load Python parser for .py files", async () => {
-			const files = ["test.py"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-python.wasm"))
-			expect(parsers.py).toBeDefined()
-		})
-
-		it("should load multiple language parsers as needed", async () => {
-			const files = ["test.js", "test.py", "test.rs", "test.go"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledTimes(4)
-			expect(parsers.js).toBeDefined()
-			expect(parsers.py).toBeDefined()
-			expect(parsers.rs).toBeDefined()
-			expect(parsers.go).toBeDefined()
-		})
-
-		it("should handle C/C++ files correctly", async () => {
-			const files = ["test.c", "test.h", "test.cpp", "test.hpp"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-c.wasm"))
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-cpp.wasm"))
-			expect(parsers.c).toBeDefined()
-			expect(parsers.h).toBeDefined()
-			expect(parsers.cpp).toBeDefined()
-			expect(parsers.hpp).toBeDefined()
-		})
-
-		it("should handle Kotlin files correctly", async () => {
-			const files = ["test.kt", "test.kts"]
-			const parsers = await loadRequiredLanguageParsers(files)
-
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-kotlin.wasm"))
-			expect(parsers.kt).toBeDefined()
-			expect(parsers.kts).toBeDefined()
-			expect(parsers.kt.query).toBeDefined()
-			expect(parsers.kts.query).toBeDefined()
-		})
-
-		it("should throw error for unsupported file extensions", async () => {
-			const files = ["test.unsupported"]
-
-			await expect(loadRequiredLanguageParsers(files)).rejects.toThrow("Unsupported language: unsupported")
-		})
+	it("should load JavaScript parser for .js and .jsx files", async () => {
+		const files = ["test.js", "test.jsx"]
+		const parsers = await loadRequiredLanguageParsers(files, WASM_DIR)
+		expect(parsers.js).toBeDefined()
+		expect(parsers.jsx).toBeDefined()
+		expect(parsers.js.query).toBeDefined()
+		expect(parsers.jsx.query).toBeDefined()
+	})
 
-		it("should load each language only once for multiple files", async () => {
-			const files = ["test1.js", "test2.js", "test3.js"]
-			await loadRequiredLanguageParsers(files)
+	it("should load multiple language parsers as needed", async () => {
+		const files = ["test.js", "test.py", "test.rs", "test.go"]
+		const parsers = await loadRequiredLanguageParsers(files, WASM_DIR)
+		expect(parsers.js).toBeDefined()
+		expect(parsers.py).toBeDefined()
+		expect(parsers.rs).toBeDefined()
+		expect(parsers.go).toBeDefined()
+	})
 
-			expect(mockLanguageLoad).toHaveBeenCalledTimes(1)
-			expect(mockLanguageLoad).toHaveBeenCalledWith(expect.stringContaining("tree-sitter-javascript.wasm"))
-		})
+	it("should handle C/C++ files correctly", async () => {
+		const files = ["test.c", "test.h", "test.cpp", "test.hpp"]
+		const parsers = await loadRequiredLanguageParsers(files, WASM_DIR)
+		expect(parsers.c).toBeDefined()
+		expect(parsers.h).toBeDefined()
+		expect(parsers.cpp).toBeDefined()
+		expect(parsers.hpp).toBeDefined()
+	})
 
-		it("should set language for each parser instance", async () => {
-			const files = ["test.js", "test.py"]
-			await loadRequiredLanguageParsers(files)
+	it("should handle Kotlin files correctly", async () => {
+		const files = ["test.kt", "test.kts"]
+		const parsers = await loadRequiredLanguageParsers(files, WASM_DIR)
+		expect(parsers.kt).toBeDefined()
+		expect(parsers.kts).toBeDefined()
+		expect(parsers.kt.query).toBeDefined()
+		expect(parsers.kts.query).toBeDefined()
+	})
 
-			expect(mockSetLanguage).toHaveBeenCalledTimes(2)
-		})
+	it("should throw error for unsupported file extensions", async () => {
+		const files = ["test.unsupported"]
+		await expect(loadRequiredLanguageParsers(files, WASM_DIR)).rejects.toThrow("Unsupported language: unsupported")
 	})
 })

+ 92 - 74
src/services/tree-sitter/languageParser.ts

@@ -1,5 +1,5 @@
 import * as path from "path"
-import { Parser, Query, Language } from "web-tree-sitter"
+import { Parser as ParserT, Language as LanguageT, Query as QueryT } from "web-tree-sitter"
 import {
 	javascriptQuery,
 	typescriptQuery,
@@ -32,30 +32,33 @@ import {
 
 export interface LanguageParser {
 	[key: string]: {
-		parser: Parser
-		query: Query
+		parser: ParserT
+		query: QueryT
 	}
 }
 
-async function loadLanguage(langName: string) {
-	return await Language.load(path.join(__dirname, `tree-sitter-${langName}.wasm`))
-}
-
-let isParserInitialized = false
+async function loadLanguage(langName: string, sourceDirectory?: string) {
+	const baseDir = sourceDirectory || __dirname
+	const wasmPath = path.join(baseDir, `tree-sitter-${langName}.wasm`)
 
-async function initializeParser() {
-	if (!isParserInitialized) {
-		await Parser.init()
-		isParserInitialized = true
+	try {
+		const { Language } = require("web-tree-sitter")
+		return await Language.load(wasmPath)
+	} catch (error) {
+		console.error(`Error loading language: ${wasmPath}: ${error instanceof Error ? error.message : error}`)
+		throw error
 	}
 }
 
+let isParserInitialized = false
+
 /*
 Using node bindings for tree-sitter is problematic in vscode extensions 
 because of incompatibility with electron. Going the .wasm route has the 
 advantage of not having to build for multiple architectures.
 
-We use web-tree-sitter and tree-sitter-wasms which provides auto-updating prebuilt WASM binaries for tree-sitter's language parsers.
+We use web-tree-sitter and tree-sitter-wasms which provides auto-updating
+prebuilt WASM binaries for tree-sitter's language parsers.
 
 This function loads WASM modules for relevant language parsers based on input files:
 1. Extracts unique file extensions
@@ -72,142 +75,157 @@ Sources:
 - https://github.com/tree-sitter/tree-sitter/blob/master/lib/binding_web/README.md
 - https://github.com/tree-sitter/tree-sitter/blob/master/lib/binding_web/test/query-test.js
 */
-export async function loadRequiredLanguageParsers(filesToParse: string[]): Promise<LanguageParser> {
-	await initializeParser()
+export async function loadRequiredLanguageParsers(filesToParse: string[], sourceDirectory?: string) {
+	const { Parser, Query } = require("web-tree-sitter")
+
+	if (!isParserInitialized) {
+		try {
+			await Parser.init()
+			isParserInitialized = true
+		} catch (error) {
+			console.error(`Error initializing parser: ${error instanceof Error ? error.message : error}`)
+			throw error
+		}
+	}
+
 	const extensionsToLoad = new Set(filesToParse.map((file) => path.extname(file).toLowerCase().slice(1)))
 	const parsers: LanguageParser = {}
+
 	for (const ext of extensionsToLoad) {
-		let language: Language
-		let query: Query
+		let language: LanguageT
+		let query: QueryT
 		let parserKey = ext // Default to using extension as key
+
 		switch (ext) {
 			case "js":
 			case "jsx":
 			case "json":
-				language = await loadLanguage("javascript")
-				query = language.query(javascriptQuery)
+				language = await loadLanguage("javascript", sourceDirectory)
+				query = new Query(language, javascriptQuery)
 				break
 			case "ts":
-				language = await loadLanguage("typescript")
-				query = language.query(typescriptQuery)
+				language = await loadLanguage("typescript", sourceDirectory)
+				query = new Query(language, typescriptQuery)
 				break
 			case "tsx":
-				language = await loadLanguage("tsx")
-				query = language.query(tsxQuery)
+				language = await loadLanguage("tsx", sourceDirectory)
+				query = new Query(language, tsxQuery)
 				break
 			case "py":
-				language = await loadLanguage("python")
-				query = language.query(pythonQuery)
+				language = await loadLanguage("python", sourceDirectory)
+				query = new Query(language, pythonQuery)
 				break
 			case "rs":
-				language = await loadLanguage("rust")
-				query = language.query(rustQuery)
+				language = await loadLanguage("rust", sourceDirectory)
+				query = new Query(language, rustQuery)
 				break
 			case "go":
-				language = await loadLanguage("go")
-				query = language.query(goQuery)
+				language = await loadLanguage("go", sourceDirectory)
+				query = new Query(language, goQuery)
 				break
 			case "cpp":
 			case "hpp":
-				language = await loadLanguage("cpp")
-				query = language.query(cppQuery)
+				language = await loadLanguage("cpp", sourceDirectory)
+				query = new Query(language, cppQuery)
 				break
 			case "c":
 			case "h":
-				language = await loadLanguage("c")
-				query = language.query(cQuery)
+				language = await loadLanguage("c", sourceDirectory)
+				query = new Query(language, cQuery)
 				break
 			case "cs":
-				language = await loadLanguage("c_sharp")
-				query = language.query(csharpQuery)
+				language = await loadLanguage("c_sharp", sourceDirectory)
+				query = new Query(language, csharpQuery)
 				break
 			case "rb":
-				language = await loadLanguage("ruby")
-				query = language.query(rubyQuery)
+				language = await loadLanguage("ruby", sourceDirectory)
+				query = new Query(language, rubyQuery)
 				break
 			case "java":
-				language = await loadLanguage("java")
-				query = language.query(javaQuery)
+				language = await loadLanguage("java", sourceDirectory)
+				query = new Query(language, javaQuery)
 				break
 			case "php":
-				language = await loadLanguage("php")
-				query = language.query(phpQuery)
+				language = await loadLanguage("php", sourceDirectory)
+				query = new Query(language, phpQuery)
 				break
 			case "swift":
-				language = await loadLanguage("swift")
-				query = language.query(swiftQuery)
+				language = await loadLanguage("swift", sourceDirectory)
+				query = new Query(language, swiftQuery)
 				break
 			case "kt":
 			case "kts":
-				language = await loadLanguage("kotlin")
-				query = language.query(kotlinQuery)
+				language = await loadLanguage("kotlin", sourceDirectory)
+				query = new Query(language, kotlinQuery)
 				break
 			case "css":
-				language = await loadLanguage("css")
-				query = language.query(cssQuery)
+				language = await loadLanguage("css", sourceDirectory)
+				query = new Query(language, cssQuery)
 				break
 			case "html":
-				language = await loadLanguage("html")
-				query = language.query(htmlQuery)
+				language = await loadLanguage("html", sourceDirectory)
+				query = new Query(language, htmlQuery)
 				break
 			case "ml":
 			case "mli":
-				language = await loadLanguage("ocaml")
-				query = language.query(ocamlQuery)
+				language = await loadLanguage("ocaml", sourceDirectory)
+				query = new Query(language, ocamlQuery)
 				break
 			case "scala":
-				language = await loadLanguage("scala")
-				query = language.query(luaQuery) // Temporarily use Lua query until Scala is implemented
+				language = await loadLanguage("scala", sourceDirectory)
+				query = new Query(language, luaQuery) // Temporarily use Lua query until Scala is implemented
 				break
 			case "sol":
-				language = await loadLanguage("solidity")
-				query = language.query(solidityQuery)
+				language = await loadLanguage("solidity", sourceDirectory)
+				query = new Query(language, solidityQuery)
 				break
 			case "toml":
-				language = await loadLanguage("toml")
-				query = language.query(tomlQuery)
+				language = await loadLanguage("toml", sourceDirectory)
+				query = new Query(language, tomlQuery)
 				break
 			case "vue":
-				language = await loadLanguage("vue")
-				query = language.query(vueQuery)
+				language = await loadLanguage("vue", sourceDirectory)
+				query = new Query(language, vueQuery)
 				break
 			case "lua":
-				language = await loadLanguage("lua")
-				query = language.query(luaQuery)
+				language = await loadLanguage("lua", sourceDirectory)
+				query = new Query(language, luaQuery)
 				break
 			case "rdl":
-				language = await loadLanguage("systemrdl")
-				query = language.query(systemrdlQuery)
+				language = await loadLanguage("systemrdl", sourceDirectory)
+				query = new Query(language, systemrdlQuery)
 				break
 			case "tla":
-				language = await loadLanguage("tlaplus")
-				query = language.query(tlaPlusQuery)
+				language = await loadLanguage("tlaplus", sourceDirectory)
+				query = new Query(language, tlaPlusQuery)
 				break
 			case "zig":
-				language = await loadLanguage("zig")
-				query = language.query(zigQuery)
+				language = await loadLanguage("zig", sourceDirectory)
+				query = new Query(language, zigQuery)
 				break
 			case "ejs":
 			case "erb":
-				language = await loadLanguage("embedded_template")
-				parserKey = "embedded_template" // Use same key for both extensions
-				query = language.query(embeddedTemplateQuery)
+				parserKey = "embedded_template" // Use same key for both extensions.
+				language = await loadLanguage("embedded_template", sourceDirectory)
+				query = new Query(language, embeddedTemplateQuery)
 				break
 			case "el":
-				language = await loadLanguage("elisp")
-				query = language.query(elispQuery)
+				language = await loadLanguage("elisp", sourceDirectory)
+				query = new Query(language, elispQuery)
 				break
 			case "ex":
 			case "exs":
-				language = await loadLanguage("elixir")
-				query = language.query(elixirQuery)
+				language = await loadLanguage("elixir", sourceDirectory)
+				query = new Query(language, elixirQuery)
 				break
 			default:
 				throw new Error(`Unsupported language: ${ext}`)
 		}
+
 		const parser = new Parser()
 		parser.setLanguage(language)
 		parsers[parserKey] = { parser, query }
 	}
+
 	return parsers
 }