Explorar o código

Count tokens worker (#3037)

* Count tokens worker

* Appease knip
Chris Estreich hai 9 meses
pai
achega
65cb9244ee

+ 20 - 11
esbuild.js

@@ -167,7 +167,7 @@ const extensionConfig = {
 		{
 			name: "alias-plugin",
 			setup(build) {
-				build.onResolve({ filter: /^pkce-challenge$/ }, (args) => {
+				build.onResolve({ filter: /^pkce-challenge$/ }, (_args) => {
 					return { path: require.resolve("pkce-challenge/dist/index.browser.js") }
 				})
 			},
@@ -181,22 +181,31 @@ const extensionConfig = {
 	external: ["vscode"],
 }
 
+const workerConfig = {
+	bundle: true,
+	minify: production,
+	sourcemap: !production,
+	logLevel: "silent",
+	entryPoints: ["src/workers/countTokens.ts"],
+	format: "cjs",
+	sourcesContent: false,
+	platform: "node",
+	outdir: "dist/workers",
+}
+
 async function main() {
-	const extensionCtx = await esbuild.context(extensionConfig)
+	const [extensionCtx, workerCtx] = await Promise.all([
+		esbuild.context(extensionConfig),
+		esbuild.context(workerConfig),
+	])
 
 	if (watch) {
-		// Start the esbuild watcher
-		await extensionCtx.watch()
-
-		// Copy and watch locale files
-		console.log("Copying locale files initially...")
+		await Promise.all([extensionCtx.watch(), workerCtx.watch()])
 		copyLocaleFiles()
-
-		// Set up the watcher for locale files
 		setupLocaleWatcher()
 	} else {
-		await extensionCtx.rebuild()
-		await extensionCtx.dispose()
+		await Promise.all([extensionCtx.rebuild(), workerCtx.rebuild()])
+		await Promise.all([extensionCtx.dispose(), workerCtx.dispose()])
 	}
 }
 

+ 1 - 0
knip.json

@@ -17,6 +17,7 @@
 		"evals/**",
 		"src/activate/**",
 		"src/exports/**",
+		"src/workers/**",
 		"src/schemas/ipc.ts",
 		"src/extension.ts",
 		"scripts/**"

+ 7 - 0
package-lock.json

@@ -63,6 +63,7 @@
 				"turndown": "^7.2.0",
 				"vscode-material-icons": "^0.1.1",
 				"web-tree-sitter": "^0.22.6",
+				"workerpool": "^9.2.0",
 				"zod": "^3.23.8"
 			},
 			"devDependencies": {
@@ -22021,6 +22022,12 @@
 				"node": ">=0.10.0"
 			}
 		},
+		"node_modules/workerpool": {
+			"version": "9.2.0",
+			"resolved": "https://registry.npmjs.org/workerpool/-/workerpool-9.2.0.tgz",
+			"integrity": "sha512-PKZqBOCo6CYkVOwAxWxQaSF2Fvb5Iv2fCeTP7buyWI2GiynWr46NcXSgK/idoV6e60dgCBfgYc+Un3HMvmqP8w==",
+			"license": "Apache-2.0"
+		},
 		"node_modules/wrap-ansi": {
 			"version": "8.1.0",
 			"resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz",

+ 1 - 0
package.json

@@ -433,6 +433,7 @@
 		"turndown": "^7.2.0",
 		"vscode-material-icons": "^0.1.1",
 		"web-tree-sitter": "^0.22.6",
+		"workerpool": "^9.2.0",
 		"zod": "^3.23.8"
 	},
 	"devDependencies": {

+ 11 - 47
src/api/providers/base-provider.ts

@@ -1,66 +1,30 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler } from ".."
+
 import { ModelInfo } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
-import { Tiktoken } from "tiktoken/lite"
-import o200kBase from "tiktoken/encoders/o200k_base"
 
-// Reuse the fudge factor used in the original code
-const TOKEN_FUDGE_FACTOR = 1.5
+import { ApiHandler } from "../index"
+import { ApiStream } from "../transform/stream"
+import { countTokens } from "../../utils/countTokens"
 
 /**
- * Base class for API providers that implements common functionality
+ * Base class for API providers that implements common functionality.
  */
 export abstract class BaseProvider implements ApiHandler {
-	// Cache the Tiktoken encoder instance since it's stateless
-	private encoder: Tiktoken | null = null
 	abstract createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
 	abstract getModel(): { id: string; info: ModelInfo }
 
 	/**
-	 * Default token counting implementation using tiktoken
-	 * Providers can override this to use their native token counting endpoints
-	 *
-	 * Uses a cached Tiktoken encoder instance for performance since it's stateless.
-	 * The encoder is created lazily on first use and reused for subsequent calls.
+	 * Default token counting implementation using tiktoken.
+	 * Providers can override this to use their native token counting endpoints.
 	 *
 	 * @param content The content to count tokens for
 	 * @returns A promise resolving to the token count
 	 */
-	async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
-		if (!content || content.length === 0) return 0
-
-		let totalTokens = 0
-
-		// Lazily create and cache the encoder if it doesn't exist
-		if (!this.encoder) {
-			this.encoder = new Tiktoken(o200kBase.bpe_ranks, o200kBase.special_tokens, o200kBase.pat_str)
-		}
-
-		// Process each content block using the cached encoder
-		for (const block of content) {
-			if (block.type === "text") {
-				// Use tiktoken for text token counting
-				const text = block.text || ""
-
-				if (text.length > 0) {
-					const tokens = this.encoder.encode(text)
-					totalTokens += tokens.length
-				}
-			} else if (block.type === "image") {
-				// For images, calculate based on data size
-				const imageSource = block.source
-
-				if (imageSource && typeof imageSource === "object" && "data" in imageSource) {
-					const base64Data = imageSource.data as string
-					totalTokens += Math.ceil(Math.sqrt(base64Data.length))
-				} else {
-					totalTokens += 300 // Conservative estimate for unknown images
-				}
-			}
+	async countTokens(content: Anthropic.Messages.ContentBlockParam[]): Promise<number> {
+		if (content.length === 0) {
+			return 0
 		}
 
-		// Add a fudge factor to account for the fact that tiktoken is not always accurate
-		return Math.ceil(totalTokens * TOKEN_FUDGE_FACTOR)
+		return countTokens(content, { useWorker: true })
 	}
 }

+ 142 - 0
src/utils/__tests__/tiktoken.test.ts

@@ -0,0 +1,142 @@
+// npx jest src/utils/__tests__/tiktoken.test.ts
+
+import { tiktoken } from "../tiktoken"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+describe("tiktoken", () => {
+	it("should return 0 for empty content array", async () => {
+		const result = await tiktoken([])
+		expect(result).toBe(0)
+	})
+
+	it("should correctly count tokens for text content", async () => {
+		const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]
+
+		const result = await tiktoken(content)
+		// "Hello world" encodes to 2 tokens under o200k_base; with the 1.5
+		// fudge factor applied, the expected count is ceil(2 * 1.5) = 3.
+		expect(result).toEqual(3)
+	})
+
+	it("should handle empty text content", async () => {
+		const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "" }]
+
+		const result = await tiktoken(content)
+		expect(result).toBe(0)
+	})
+
+	it("should handle missing text content", async () => {
+		// Using 'as any' to bypass TypeScript's type checking for this test case
+		// since we're specifically testing how the function handles undefined text
+		const content = [{ type: "text" }] as any as Anthropic.Messages.ContentBlockParam[]
+
+		const result = await tiktoken(content)
+		expect(result).toBe(0)
+	})
+
+	it("should correctly count tokens for image content with data", async () => {
+		const base64Data =
+			"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
+		const content: Anthropic.Messages.ContentBlockParam[] = [
+			{
+				type: "image",
+				source: {
+					type: "base64",
+					media_type: "image/png",
+					data: base64Data,
+				},
+			},
+		]
+
+		const result = await tiktoken(content)
+		// For images, we expect a token count based on the square root of the data length
+		// plus the fudge factor
+		const expectedMinTokens = Math.ceil(Math.sqrt(base64Data.length))
+		expect(result).toBeGreaterThanOrEqual(expectedMinTokens)
+	})
+
+	it("should use conservative estimate for image content without data", async () => {
+		// Using 'as any' to bypass TypeScript's type checking for this test case
+		// since we're specifically testing the fallback behavior
+		const content = [
+			{
+				type: "image",
+				source: {
+					type: "base64",
+					media_type: "image/png",
+					// data is intentionally missing to test fallback
+				},
+			},
+		] as any as Anthropic.Messages.ContentBlockParam[]
+
+		const result = await tiktoken(content)
+		// Conservative estimate is 300 tokens, plus the fudge factor
+		const expectedMinTokens = 300
+		expect(result).toBeGreaterThanOrEqual(expectedMinTokens)
+	})
+
+	it("should correctly count tokens for mixed content", async () => {
+		const base64Data =
+			"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
+		const content: Anthropic.Messages.ContentBlockParam[] = [
+			{ type: "text", text: "Hello world" },
+			{
+				type: "image",
+				source: {
+					type: "base64",
+					media_type: "image/png",
+					data: base64Data,
+				},
+			},
+			{ type: "text", text: "Goodbye world" },
+		]
+
+		const result = await tiktoken(content)
+		// We expect a positive token count for mixed content
+		expect(result).toBeGreaterThan(0)
+	})
+
+	it("should apply a fudge factor to the token count", async () => {
+		// The fudge factor is applied internally; here we verify the count is deterministic across calls and positive.
+		const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test" }]
+
+		const result = await tiktoken(content)
+
+		// Run the function again with the same content to get a consistent result
+		const result2 = await tiktoken(content)
+
+		// Both calls should return the same token count
+		expect(result).toBe(result2)
+
+		// The result should be greater than 0
+		expect(result).toBeGreaterThan(0)
+	})
+
+	it("should reuse the encoder for multiple calls", async () => {
+		// We can't directly test the caching behavior without mocking,
+		// but we can test that multiple calls with the same content return the same result
+		// which indirectly verifies the encoder is working consistently
+
+		const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]
+
+		// Time the first call which should create the encoder
+		const startTime1 = performance.now()
+		const result1 = await tiktoken(content)
+		const endTime1 = performance.now()
+		const duration1 = endTime1 - startTime1
+
+		// Time the second call which should reuse the encoder
+		const startTime2 = performance.now()
+		const result2 = await tiktoken(content)
+		const endTime2 = performance.now()
+		const duration2 = endTime2 - startTime2
+
+		// Both calls should return the same token count
+		expect(result1).toBe(result2)
+
+		// This is a loose test and might occasionally fail due to system load,
+		// but generally the second call should be faster or similar in speed
+		// since it reuses the encoder
+		expect(duration2).toBeLessThanOrEqual(duration1 * 1.5)
+	})
+})

+ 45 - 0
src/utils/countTokens.ts

@@ -0,0 +1,45 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import workerpool from "workerpool"
+
+import { countTokensResultSchema } from "../workers/types"
+import { tiktoken } from "./tiktoken"
+
+let pool: workerpool.Pool | null | undefined = undefined
+
+export type CountTokensOptions = {
+	useWorker?: boolean
+}
+
+export async function countTokens(
+	content: Anthropic.Messages.ContentBlockParam[],
+	{ useWorker = true }: CountTokensOptions = {},
+): Promise<number> {
+	// Lazily create the worker pool if it doesn't exist.
+	if (useWorker && typeof pool === "undefined") {
+		pool = workerpool.pool(__dirname + "/workers/countTokens.js", {
+			maxWorkers: 1,
+			maxQueueSize: 10,
+		})
+	}
+
+	// If the worker pool doesn't exist or the caller doesn't want to use it
+	// then, use the non-worker implementation.
+	if (!useWorker || !pool) {
+		return tiktoken(content)
+	}
+
+	try {
+		const data = await pool.exec("countTokens", [content])
+		const result = countTokensResultSchema.parse(data)
+
+		if (!result.success) {
+			throw new Error(result.error)
+		}
+
+		return result.count
+	} catch (error) {
+		pool = null
+		console.error(error)
+		return tiktoken(content)
+	}
+}

+ 46 - 0
src/utils/tiktoken.ts

@@ -0,0 +1,46 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Tiktoken } from "tiktoken/lite"
+import o200kBase from "tiktoken/encoders/o200k_base"
+
+const TOKEN_FUDGE_FACTOR = 1.5
+
+let encoder: Tiktoken | null = null
+
+export async function tiktoken(content: Anthropic.Messages.ContentBlockParam[]): Promise<number> {
+	if (content.length === 0) {
+		return 0
+	}
+
+	let totalTokens = 0
+
+	// Lazily create and cache the encoder if it doesn't exist.
+	if (!encoder) {
+		encoder = new Tiktoken(o200kBase.bpe_ranks, o200kBase.special_tokens, o200kBase.pat_str)
+	}
+
+	// Process each content block using the cached encoder.
+	for (const block of content) {
+		if (block.type === "text") {
+			const text = block.text || ""
+
+			if (text.length > 0) {
+				const tokens = encoder.encode(text)
+				totalTokens += tokens.length
+			}
+		} else if (block.type === "image") {
+			// For images, calculate based on data size.
+			const imageSource = block.source
+
+			if (imageSource && typeof imageSource === "object" && "data" in imageSource) {
+				const base64Data = imageSource.data as string
+				totalTokens += Math.ceil(Math.sqrt(base64Data.length))
+			} else {
+				totalTokens += 300 // Conservative estimate for unknown images
+			}
+		}
+	}
+
+	// Add a fudge factor to account for the fact that tiktoken is not always
+	// accurate.
+	return Math.ceil(totalTokens * TOKEN_FUDGE_FACTOR)
+}

+ 21 - 0
src/workers/countTokens.ts

@@ -0,0 +1,21 @@
+import workerpool from "workerpool"
+
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { tiktoken } from "../utils/tiktoken"
+
+import { type CountTokensResult } from "./types"
+
+async function countTokens(content: Anthropic.Messages.ContentBlockParam[]): Promise<CountTokensResult> {
+	try {
+		const count = await tiktoken(content)
+		return { success: true, count }
+	} catch (error) {
+		return {
+			success: false,
+			error: error instanceof Error ? error.message : "Unknown error",
+		}
+	}
+}
+
+workerpool.worker({ countTokens })

+ 11 - 0
src/workers/types.ts

@@ -0,0 +1,11 @@
+import { z } from "zod"
+
+export const countTokensResultSchema = z.discriminatedUnion("success", [
+	z.object({
+		success: z.literal(true),
+		count: z.number(),
+	}),
+	z.object({ success: z.literal(false), error: z.string() }),
+])
+
+export type CountTokensResult = z.infer<typeof countTokensResultSchema>