Răsfoiți Sursa

feat: enhance token counting by extracting text from messages using VSCode LM API (#6424)

NaccOll 5 luni în urmă
părinte
comite
181993f639

+ 6 - 10
src/api/providers/vscode-lm.ts

@@ -7,7 +7,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
 import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
 
 import { ApiStream } from "../transform/stream"
-import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
+import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"
 
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -231,7 +231,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 					console.debug("Roo Code <Language Model API>: Empty chat message content")
 					return 0
 				}
-				tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
+				const countMessage = extractTextCountFromMessage(text)
+				tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
 			} else {
 				console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
 				return 0
@@ -268,15 +269,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 
-	private async calculateTotalInputTokens(
-		systemPrompt: string,
-		vsCodeLmMessages: vscode.LanguageModelChatMessage[],
-	): Promise<number> {
-		const systemTokens: number = await this.internalCountTokens(systemPrompt)
-
+	private async calculateTotalInputTokens(vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
 		const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))
 
-		return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
+		return messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
 	}
 
 	private ensureCleanState(): void {
@@ -359,7 +355,7 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		this.currentRequestCancellation = new vscode.CancellationTokenSource()
 
 		// Calculate input tokens before starting the stream
-		const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
+		const totalInputTokens: number = await this.calculateTotalInputTokens(vsCodeLmMessages)
 
 		// Accumulate the text and count at the end of the stream to reduce token counting overhead.
 		let accumulatedText: string = ""

+ 160 - 5
src/api/transform/__tests__/vscode-lm-format.spec.ts

@@ -1,8 +1,9 @@
 // npx vitest run src/api/transform/__tests__/vscode-lm-format.spec.ts
 
 import { Anthropic } from "@anthropic-ai/sdk"
+import * as vscode from "vscode"
 
-import { convertToVsCodeLmMessages, convertToAnthropicRole } from "../vscode-lm-format"
+import { convertToVsCodeLmMessages, convertToAnthropicRole, extractTextCountFromMessage } from "../vscode-lm-format"
 
 // Mock crypto using Vitest
 vitest.stubGlobal("crypto", {
@@ -24,8 +25,8 @@ interface MockLanguageModelToolCallPart {
 
 interface MockLanguageModelToolResultPart {
 	type: "tool_result"
-	toolUseId: string
-	parts: MockLanguageModelTextPart[]
+	callId: string
+	content: MockLanguageModelTextPart[]
 }
 
 // Mock vscode namespace
@@ -52,8 +53,8 @@ vitest.mock("vscode", () => {
 	class MockLanguageModelToolResultPart {
 		type = "tool_result"
 		constructor(
-			public toolUseId: string,
-			public parts: MockLanguageModelTextPart[],
+			public callId: string,
+			public content: MockLanguageModelTextPart[],
 		) {}
 	}
 
@@ -189,3 +190,157 @@ describe("convertToAnthropicRole", () => {
 		expect(result).toBeNull()
 	})
 })
+
+describe("extractTextCountFromMessage", () => {
+	it("should extract text from simple string content", () => {
+		const message = {
+			role: "user",
+			content: "Hello world",
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("Hello world")
+	})
+
+	it("should extract text from LanguageModelTextPart", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
+		const message = {
+			role: "user",
+			content: [mockTextPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("Text content")
+	})
+
+	it("should extract text from multiple LanguageModelTextParts", () => {
+		const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("First part")
+		const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Second part")
+		const message = {
+			role: "user",
+			content: [mockTextPart1, mockTextPart2],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("First partSecond part")
+	})
+
+	it("should extract text from LanguageModelToolResultPart", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result content")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("tool-result-id", [
+			mockTextPart,
+		])
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-result-idTool result content")
+	})
+
+	it("should extract text from LanguageModelToolCallPart without input", () => {
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-namecall-id")
+	})
+
+	it("should extract text from LanguageModelToolCallPart with input", () => {
+		const mockInput = { operation: "add", numbers: [1, 2, 3] }
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)(
+			"call-id",
+			"calculator",
+			mockInput,
+		)
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe(`calculatorcall-id${JSON.stringify(mockInput)}`)
+	})
+
+	it("should extract text from LanguageModelToolCallPart with empty input", () => {
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-namecall-id")
+	})
+
+	it("should extract text from mixed content types", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
+		const mockToolResultTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
+			mockToolResultTextPart,
+		])
+		const mockInput = { param: "value" }
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool", mockInput)
+
+		const message = {
+			role: "assistant",
+			content: [mockTextPart, mockToolResultPart, mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe(`Text contentresult-idTool resulttoolcall-id${JSON.stringify(mockInput)}`)
+	})
+
+	it("should handle empty array content", () => {
+		const message = {
+			role: "user",
+			content: [],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("")
+	})
+
+	it("should handle undefined content", () => {
+		const message = {
+			role: "user",
+			content: undefined,
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("")
+	})
+
+	it("should handle ToolResultPart with multiple text parts", () => {
+		const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 1")
+		const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 2")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
+			mockTextPart1,
+			mockTextPart2,
+		])
+
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("result-idPart 1Part 2")
+	})
+
+	it("should handle ToolResultPart with empty parts array", () => {
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [])
+
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("result-id")
+	})
+})

+ 38 - 0
src/api/transform/vscode-lm-format.ts

@@ -155,3 +155,41 @@ export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModel
 			return null
 	}
 }
+
+/**
+ * Extracts the text content from a VS Code Language Model chat message.
+ * @param message A VS Code Language Model chat message.
+ * @returns The extracted text content.
+ */
+export function extractTextCountFromMessage(message: vscode.LanguageModelChatMessage): string {
+	let text = ""
+	if (Array.isArray(message.content)) {
+		for (const item of message.content) {
+			if (item instanceof vscode.LanguageModelTextPart) {
+				text += item.value
+			}
+			if (item instanceof vscode.LanguageModelToolResultPart) {
+				text += item.callId
+				for (const part of item.content) {
+					if (part instanceof vscode.LanguageModelTextPart) {
+						text += part.value
+					}
+				}
+			}
+			if (item instanceof vscode.LanguageModelToolCallPart) {
+				text += item.name
+				text += item.callId
+				if (item.input && Object.keys(item.input).length > 0) {
+					try {
+						text += JSON.stringify(item.input)
+					} catch (error) {
+						console.error("Roo Code <Language Model API>: Failed to stringify tool call input:", error)
+					}
+				}
+			}
+		}
+	} else if (typeof message.content === "string") {
+		text += message.content
+	}
+	return text
+}