Estimate the number of tokens in the last message for sliding window math

Chris Estreich, 10 months ago
parent
commit 481d613bc8

+ 1 - 1
e2e/tsconfig.json

@@ -11,6 +11,6 @@
 		"useUnknownInCatchVariables": false,
 		"outDir": "out"
 	},
-	"include": ["src", "../src/exports"],
+	"include": ["src", "../src/exports/cline.d.ts"],
 	"exclude": [".vscode-test", "**/node_modules/**", "out"]
 }

+ 9 - 0
package-lock.json

@@ -34,6 +34,7 @@
 				"get-folder-size": "^5.0.0",
 				"globby": "^14.0.2",
 				"isbinaryfile": "^5.0.2",
+				"js-tiktoken": "^1.0.19",
 				"mammoth": "^1.8.0",
 				"monaco-vscode-textmate-theme-converter": "^0.1.7",
 				"openai": "^4.78.1",
@@ -10909,6 +10910,14 @@
 				"jiti": "lib/jiti-cli.mjs"
 			}
 		},
+		"node_modules/js-tiktoken": {
+			"version": "1.0.19",
+			"resolved": "https://registry.npmjs.org/js-tiktoken/-/js-tiktoken-1.0.19.tgz",
+			"integrity": "sha512-XC63YQeEcS47Y53gg950xiZ4IWmkfMe4p2V9OSaBt26q+p47WHn18izuXzSclCI73B7yGqtfRsT6jcZQI0y08g==",
+			"dependencies": {
+				"base64-js": "^1.5.1"
+			}
+		},
 		"node_modules/js-tokens": {
 			"version": "4.0.0",
 			"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",

+ 1 - 0
package.json

@@ -348,6 +348,7 @@
 		"sound-play": "^1.1.0",
 		"string-similarity": "^4.0.4",
 		"strip-ansi": "^7.1.0",
+		"js-tiktoken": "^1.0.19",
 		"tmp": "^0.2.3",
 		"tree-sitter-wasms": "^0.1.11",
 		"turndown": "^7.2.0",

+ 199 - 25
src/core/sliding-window/__tests__/sliding-window.test.ts

@@ -3,7 +3,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 
 import { ModelInfo } from "../../../shared/api"
-import { truncateConversation, truncateConversationIfNeeded } from "../index"
+import { estimateTokenCount, truncateConversation, truncateConversationIfNeeded } from "../index"
 
 /**
  * Tests for the truncateConversation function
@@ -118,23 +118,26 @@ describe("getMaxTokens", () => {
 		const modelInfo = createModelInfo(100000, 50000)
 		// Max tokens = 100000 - 50000 = 50000
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// Below max tokens - no truncation
 		const result1 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 49999,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result1).toEqual(messages)
+		expect(result1).toEqual(messagesWithSmallContent)
 
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 50001,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result2).not.toEqual(messages)
+		expect(result2).not.toEqual(messagesWithSmallContent)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})
 
@@ -142,23 +145,26 @@ describe("getMaxTokens", () => {
 		const modelInfo = createModelInfo(100000, undefined)
 		// Max tokens = 100000 - (100000 * 0.2) = 80000
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// Below max tokens - no truncation
 		const result1 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 79999,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result1).toEqual(messages)
+		expect(result1).toEqual(messagesWithSmallContent)
 
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 80001,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result2).not.toEqual(messages)
+		expect(result2).not.toEqual(messagesWithSmallContent)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})
 
@@ -166,23 +172,26 @@ describe("getMaxTokens", () => {
 		const modelInfo = createModelInfo(50000, 10000)
 		// Max tokens = 50000 - 10000 = 40000
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// Below max tokens - no truncation
 		const result1 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 39999,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result1).toEqual(messages)
+		expect(result1).toEqual(messagesWithSmallContent)
 
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 40001,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result2).not.toEqual(messages)
+		expect(result2).not.toEqual(messagesWithSmallContent)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})
 
@@ -190,23 +199,26 @@ describe("getMaxTokens", () => {
 		const modelInfo = createModelInfo(200000, 30000)
 		// Max tokens = 200000 - 30000 = 170000
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// Below max tokens - no truncation
 		const result1 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 169999,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result1).toEqual(messages)
+		expect(result1).toEqual(messagesWithSmallContent)
 
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens: 170001,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result2).not.toEqual(messages)
+		expect(result2).not.toEqual(messagesWithSmallContent)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})
 })
@@ -234,13 +246,16 @@ describe("truncateConversationIfNeeded", () => {
 		const maxTokens = 100000 - 30000 // 70000
 		const totalTokens = 69999 // Below threshold
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		const result = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
-		expect(result).toEqual(messages) // No truncation occurs
+		expect(result).toEqual(messagesWithSmallContent) // No truncation occurs
 	})
 
 	it("should truncate if tokens are above max tokens threshold", () => {
@@ -248,12 +263,15 @@ describe("truncateConversationIfNeeded", () => {
 		const maxTokens = 100000 - 30000 // 70000
 		const totalTokens = 70001 // Above threshold
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// When truncating, always uses 0.5 fraction
 		// With 4 messages after the first, 0.5 fraction means remove 2 messages
-		const expectedResult = [messages[0], messages[3], messages[4]]
+		const expectedResult = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]]
 
 		const result = truncateConversationIfNeeded({
-			messages,
+			messages: messagesWithSmallContent,
 			totalTokens,
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
@@ -266,18 +284,21 @@ describe("truncateConversationIfNeeded", () => {
 		const modelInfo1 = createModelInfo(100000, true, 30000)
 		const modelInfo2 = createModelInfo(100000, false, 30000)
 
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
 		// Test below threshold
 		const belowThreshold = 69999
 		expect(
 			truncateConversationIfNeeded({
-				messages,
+				messages: messagesWithSmallContent,
 				totalTokens: belowThreshold,
 				contextWindow: modelInfo1.contextWindow,
 				maxTokens: modelInfo1.maxTokens,
 			}),
 		).toEqual(
 			truncateConversationIfNeeded({
-				messages,
+				messages: messagesWithSmallContent,
 				totalTokens: belowThreshold,
 				contextWindow: modelInfo2.contextWindow,
 				maxTokens: modelInfo2.maxTokens,
@@ -288,18 +309,171 @@ describe("truncateConversationIfNeeded", () => {
 		const aboveThreshold = 70001
 		expect(
 			truncateConversationIfNeeded({
-				messages,
+				messages: messagesWithSmallContent,
 				totalTokens: aboveThreshold,
 				contextWindow: modelInfo1.contextWindow,
 				maxTokens: modelInfo1.maxTokens,
 			}),
 		).toEqual(
 			truncateConversationIfNeeded({
-				messages,
+				messages: messagesWithSmallContent,
 				totalTokens: aboveThreshold,
 				contextWindow: modelInfo2.contextWindow,
 				maxTokens: modelInfo2.maxTokens,
 			}),
 		)
 	})
+
+	it("should consider incoming content when deciding to truncate", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 30000
+		const availableTokens = modelInfo.contextWindow - maxTokens
+
+		// Test case 1: Small content that won't push us over the threshold
+		const smallContent = [{ type: "text" as const, text: "Small content" }]
+		const smallContentTokens = estimateTokenCount(smallContent)
+		const messagesWithSmallContent: Anthropic.Messages.MessageParam[] = [
+			...messages.slice(0, -1),
+			{ role: messages[messages.length - 1].role, content: smallContent },
+		]
+
+		// Set base tokens so total is below threshold even with small content added
+		const baseTokensForSmall = availableTokens - smallContentTokens - 10
+		const resultWithSmall = truncateConversationIfNeeded({
+			messages: messagesWithSmallContent,
+			totalTokens: baseTokensForSmall,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens,
+		})
+		expect(resultWithSmall).toEqual(messagesWithSmallContent) // No truncation
+
+		// Test case 2: Large content that will push us over the threshold
+		const largeContent = [
+			{
+				type: "text" as const,
+				text: "A very large incoming message that would consume a significant number of tokens and push us over the threshold",
+			},
+		]
+		const largeContentTokens = estimateTokenCount(largeContent)
+		const messagesWithLargeContent: Anthropic.Messages.MessageParam[] = [
+			...messages.slice(0, -1),
+			{ role: messages[messages.length - 1].role, content: largeContent },
+		]
+
+		// Set base tokens so we're just below threshold without content, but over with content
+		const baseTokensForLarge = availableTokens - Math.floor(largeContentTokens / 2)
+		const resultWithLarge = truncateConversationIfNeeded({
+			messages: messagesWithLargeContent,
+			totalTokens: baseTokensForLarge,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens,
+		})
+		expect(resultWithLarge).not.toEqual(messagesWithLargeContent) // Should truncate
+
+		// Test case 3: Very large content that will definitely exceed threshold
+		const veryLargeContent = [{ type: "text" as const, text: "X".repeat(1000) }]
+		const veryLargeContentTokens = estimateTokenCount(veryLargeContent)
+		const messagesWithVeryLargeContent: Anthropic.Messages.MessageParam[] = [
+			...messages.slice(0, -1),
+			{ role: messages[messages.length - 1].role, content: veryLargeContent },
+		]
+
+		// Set base tokens so we're just below threshold without content
+		const baseTokensForVeryLarge = availableTokens - Math.floor(veryLargeContentTokens / 2)
+		const resultWithVeryLarge = truncateConversationIfNeeded({
+			messages: messagesWithVeryLargeContent,
+			totalTokens: baseTokensForVeryLarge,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens,
+		})
+		expect(resultWithVeryLarge).not.toEqual(messagesWithVeryLargeContent) // Should truncate
+	})
+})
+/**
+ * Tests for the estimateTokenCount function
+ */
+describe("estimateTokenCount", () => {
+	it("should return 0 for empty or undefined content", () => {
+		expect(estimateTokenCount([])).toBe(0)
+		// @ts-ignore - Testing with undefined
+		expect(estimateTokenCount(undefined)).toBe(0)
+	})
+
+	it("should estimate tokens for text blocks", () => {
+		const content: Array<Anthropic.Messages.ContentBlockParam> = [
+			{ type: "text", text: "This is a text block with 36 characters" },
+		]
+
+		// With tiktoken, the exact token count may differ from character-based estimation
+		// Instead of expecting an exact number, we verify it's a reasonable positive number
+		const result = estimateTokenCount(content)
+		expect(result).toBeGreaterThan(0)
+
+		// We can also verify that longer text results in more tokens
+		const longerContent: Array<Anthropic.Messages.ContentBlockParam> = [
+			{
+				type: "text",
+				text: "This is a longer text block with significantly more characters to encode into tokens",
+			},
+		]
+		const longerResult = estimateTokenCount(longerContent)
+		expect(longerResult).toBeGreaterThan(result)
+	})
+
+	it("should estimate tokens for image blocks based on data size", () => {
+		// Small image
+		const smallImage: Array<Anthropic.Messages.ContentBlockParam> = [
+			{ type: "image", source: { type: "base64", media_type: "image/jpeg", data: "small_dummy_data" } },
+		]
+		// Larger image with more data
+		const largerImage: Array<Anthropic.Messages.ContentBlockParam> = [
+			{ type: "image", source: { type: "base64", media_type: "image/png", data: "X".repeat(1000) } },
+		]
+
+		// Verify the token count scales with the size of the image data
+		const smallImageTokens = estimateTokenCount(smallImage)
+		const largerImageTokens = estimateTokenCount(largerImage)
+
+		// Small image should have some tokens
+		expect(smallImageTokens).toBeGreaterThan(0)
+
+		// Larger image should have proportionally more tokens
+		expect(largerImageTokens).toBeGreaterThan(smallImageTokens)
+
+		// Verify the larger image calculation matches our formula including the 50% fudge factor
+		expect(largerImageTokens).toBe(48)
+	})
+
+	it("should estimate tokens for mixed content blocks", () => {
+		const content: Array<Anthropic.Messages.ContentBlockParam> = [
+			{ type: "text", text: "A text block with 30 characters" },
+			{ type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } },
+			{ type: "text", text: "Another text with 24 chars" },
+		]
+
+		// We know image tokens calculation should be consistent
+		const imageTokens = Math.ceil(Math.sqrt("dummy_data".length)) * 1.5
+
+		// With tiktoken, we can't predict exact text token counts,
+		// but we can verify the total is greater than just the image tokens
+		const result = estimateTokenCount(content)
+		expect(result).toBeGreaterThan(imageTokens)
+
+		// Also test against a version with only the image to verify text adds tokens
+		const imageOnlyContent: Array<Anthropic.Messages.ContentBlockParam> = [
+			{ type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } },
+		]
+		const imageOnlyResult = estimateTokenCount(imageOnlyContent)
+		expect(result).toBeGreaterThan(imageOnlyResult)
+	})
+
+	it("should handle empty text blocks", () => {
+		const content: Array<Anthropic.Messages.ContentBlockParam> = [{ type: "text", text: "" }]
+		expect(estimateTokenCount(content)).toBe(0)
+	})
+
+	it("should handle plain string messages", () => {
+		const content = "This is a plain text message"
+		expect(estimateTokenCount([{ type: "text", text: content }])).toBeGreaterThan(0)
+	})
 })
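
To make the hard-coded expectation in the image test easy to verify by hand, here is the arithmetic it relies on, following the formula added in src/core/sliding-window/index.ts below (ceil of the square root of the base64 payload length, then the 1.5 fudge factor applied to the total):

```ts
// Larger image in the test: "X".repeat(1000) → a 1000-character base64 payload.
const base = Math.ceil(Math.sqrt(1000)) // ceil(31.62…) = 32
const withFudge = Math.ceil(base * 1.5) // ceil(48) = 48 → expect(largerImageTokens).toBe(48)

// Mixed-content test: "dummy_data" is 10 characters.
const smallImage = Math.ceil(Math.sqrt("dummy_data".length)) * 1.5 // 4 * 1.5 = 6, used only as a lower bound
```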

+ 66 - 4
src/core/sliding-window/index.ts

@@ -1,5 +1,51 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 
+import { Tiktoken } from "js-tiktoken/lite"
+import o200kBase from "js-tiktoken/ranks/o200k_base"
+
+const TOKEN_FUDGE_FACTOR = 1.5
+
+/**
+ * Counts tokens for user content using tiktoken for text
+ * and a size-based calculation for images.
+ *
+ * @param {Array<Anthropic.Messages.ContentBlockParam>} content - The content to count tokens for
+ * @returns {number} The token count
+ */
+export function estimateTokenCount(content: Array<Anthropic.Messages.ContentBlockParam>): number {
+	if (!content || content.length === 0) return 0
+
+	let totalTokens = 0
+	let encoder = null
+
+	// Create encoder
+	encoder = new Tiktoken(o200kBase)
+
+	// Process each content block
+	for (const block of content) {
+		if (block.type === "text") {
+			// Use tiktoken for text token counting
+			const text = block.text || ""
+			if (text.length > 0) {
+				const tokens = encoder.encode(text)
+				totalTokens += tokens.length
+			}
+		} else if (block.type === "image") {
+			// For images, calculate based on data size
+			const imageSource = block.source
+			if (imageSource && typeof imageSource === "object" && "data" in imageSource) {
+				const base64Data = imageSource.data as string
+				totalTokens += Math.ceil(Math.sqrt(base64Data.length))
+			} else {
+				totalTokens += 300 // Conservative estimate for unknown images
+			}
+		}
+	}
+
+	// Add a fudge factor to account for the fact that tiktoken is not always accurate
+	return Math.ceil(totalTokens * TOKEN_FUDGE_FACTOR)
+}
+
 /**
  * Truncates a conversation by removing a fraction of the messages.
  *
@@ -25,10 +71,10 @@ export function truncateConversation(
 
 /**
  * Conditionally truncates the conversation messages if the total token count
- * exceeds the model's limit.
+ * exceeds the model's limit, considering the size of incoming content.
  *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
- * @param {number} totalTokens - The total number of tokens in the conversation.
+ * @param {number} totalTokens - The total number of tokens in the conversation (excluding the last user message).
  * @param {number} contextWindow - The context window size.
  * @param {number} maxTokens - The maximum number of tokens allowed.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
@@ -47,6 +93,22 @@ export function truncateConversationIfNeeded({
 	contextWindow,
 	maxTokens,
 }: TruncateOptions): Anthropic.Messages.MessageParam[] {
-	const allowedTokens = contextWindow - (maxTokens || contextWindow * 0.2)
-	return totalTokens < allowedTokens ? messages : truncateConversation(messages, 0.5)
+	// Calculate the maximum tokens reserved for response
+	const reservedTokens = maxTokens || contextWindow * 0.2
+
+	// Estimate tokens for the last message (which is always a user message)
+	const lastMessage = messages[messages.length - 1]
+	const lastMessageContent = lastMessage.content
+	const lastMessageTokens = Array.isArray(lastMessageContent)
+		? estimateTokenCount(lastMessageContent)
+		: estimateTokenCount([{ type: "text", text: lastMessageContent as string }])
+
+	// Calculate total effective tokens (totalTokens never includes the last message)
+	const effectiveTokens = totalTokens + lastMessageTokens
+
+	// Calculate available tokens for conversation history
+	const allowedTokens = contextWindow - reservedTokens
+
+	// Determine if truncation is needed and apply if necessary
+	return effectiveTokens < allowedTokens ? messages : truncateConversation(messages, 0.5)
 }
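
A hedged usage sketch of the updated behavior: the caller passes totalTokens for the conversation excluding the final user message, and truncateConversationIfNeeded now estimates that last message with estimateTokenCount before comparing against the context window minus the reserved maxTokens. The model numbers and the relative import path are illustrative, not taken from the repository:

```ts
import { Anthropic } from "@anthropic-ai/sdk"
import { truncateConversationIfNeeded } from "./index"

const messages: Anthropic.Messages.MessageParam[] = [
	{ role: "user", content: "First task description" },
	{ role: "assistant", content: "Earlier assistant reply" },
	{ role: "user", content: [{ type: "text", text: "Latest user message, not yet counted upstream" }] },
]

// totalTokens covers everything except the last user message;
// the function adds estimateTokenCount(lastMessage) on top before
// comparing against contextWindow - maxTokens.
const kept = truncateConversationIfNeeded({
	messages,
	totalTokens: 150_000, // illustrative running count from prior API usage
	contextWindow: 200_000,
	maxTokens: 30_000,
})

// If 150_000 plus the estimated last-message tokens reaches 170_000,
// half of the non-first messages are dropped via truncateConversation(messages, 0.5).
console.log(kept.length)
```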