fix: correct token counting for context truncation display (#9961)

Hannes Rudolph 2 weeks ago
parent
commit
f414ba41a6

+ 93 - 0
src/core/context-management/__tests__/context-management.spec.ts

@@ -1407,4 +1407,97 @@ describe("Context Management", () => {
 			expect(resultWithLastMessage).toBe(true)
 		})
 	})
+
+	/**
+	 * Tests for newContextTokensAfterTruncation including system prompt
+	 */
+	describe("newContextTokensAfterTruncation", () => {
+		const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+			contextWindow,
+			supportsPromptCache: true,
+			maxTokens,
+		})
+
+		it("should include system prompt tokens in newContextTokensAfterTruncation", async () => {
+			const modelInfo = createModelInfo(100000, 30000)
+			const totalTokens = 70001 // Above threshold to trigger truncation
+
+			const messages: ApiMessage[] = [
+				{ role: "user", content: "First message" },
+				{ role: "assistant", content: "Second message" },
+				{ role: "user", content: "Third message" },
+				{ role: "assistant", content: "Fourth message" },
+				{ role: "user", content: "" }, // Empty content in last message
+			]
+
+			const systemPrompt = "You are a helpful assistant. Follow these rules carefully."
+
+			const result = await manageContext({
+				messages,
+				totalTokens,
+				contextWindow: modelInfo.contextWindow,
+				maxTokens: modelInfo.maxTokens,
+				apiHandler: mockApiHandler,
+				autoCondenseContext: false,
+				autoCondenseContextPercent: 100,
+				systemPrompt,
+				taskId,
+				profileThresholds: {},
+				currentProfileId: "default",
+			})
+
+			// Should have truncation
+			expect(result.truncationId).toBeDefined()
+			expect(result.newContextTokensAfterTruncation).toBeDefined()
+
+			// The newContextTokensAfterTruncation should include system prompt tokens
+			// Count system prompt tokens to verify
+			const systemPromptTokens = await estimateTokenCount([{ type: "text", text: systemPrompt }], mockApiHandler)
+			expect(systemPromptTokens).toBeGreaterThan(0)
+
+			// newContextTokensAfterTruncation should be >= system prompt tokens
+			// (since it includes system prompt + remaining message tokens)
+			expect(result.newContextTokensAfterTruncation).toBeGreaterThanOrEqual(systemPromptTokens)
+		})
+
+		it("should produce consistent prev vs new token comparison (both including system prompt)", async () => {
+			const modelInfo = createModelInfo(100000, 30000)
+			const totalTokens = 70001 // Above threshold to trigger truncation
+
+			const messages: ApiMessage[] = [
+				{ role: "user", content: "First message" },
+				{ role: "assistant", content: "Second message" },
+				{ role: "user", content: "Third message" },
+				{ role: "assistant", content: "Fourth message" },
+				{ role: "user", content: "" }, // Empty content in last message
+			]
+
+			const systemPrompt = "System prompt for testing"
+
+			const result = await manageContext({
+				messages,
+				totalTokens,
+				contextWindow: modelInfo.contextWindow,
+				maxTokens: modelInfo.maxTokens,
+				apiHandler: mockApiHandler,
+				autoCondenseContext: false,
+				autoCondenseContextPercent: 100,
+				systemPrompt,
+				taskId,
+				profileThresholds: {},
+				currentProfileId: "default",
+			})
+
+			// After truncation, newContextTokensAfterTruncation should be less than prevContextTokens
+			// because we removed some messages
+			expect(result.newContextTokensAfterTruncation).toBeDefined()
+			expect(result.newContextTokensAfterTruncation).toBeLessThan(result.prevContextTokens)
+
+			// But newContextTokensAfterTruncation should still be positive rather
+			// than near-zero as in the reported bug: with the system prompt counted
+			// and roughly half of the messages remaining after truncation, a
+			// near-zero value would mean tokens are being dropped again
+			expect(result.newContextTokensAfterTruncation).toBeGreaterThan(0)
+		})
+	})
 })
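
For orientation, these are the result fields the new assertions rely on, sketched as a partial TypeScript type. The field names come straight from the tests; the authoritative shape lives in src/core/context-management/index.ts and is not reproduced in this diff.

// Partial sketch inferred from the assertions above; not the full return type.
interface ManageContextResultSketch {
	truncationId?: string // defined when truncation occurred
	prevContextTokens: number // context size before truncation, system prompt included
	newContextTokensAfterTruncation?: number // post-truncation size; the fix makes this include the system prompt as well
}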

+ 8 - 1
src/core/context-management/index.ts

@@ -323,7 +323,14 @@ export async function manageContext({
 		const effectiveMessages = truncationResult.messages.filter(
 			(msg) => !msg.truncationParent && !msg.isTruncationMarker,
 		)
-		let newContextTokensAfterTruncation = 0
+
+		// Include system prompt tokens so this value matches what we actually send
+		// to the API. `prevContextTokens` (computed locally as totalTokens +
+		// lastMessageTokens) already covers the full request, so both sides of the
+		// prev/new comparison now account for the system prompt consistently.
+		let newContextTokensAfterTruncation = await estimateTokenCount(
+			[{ type: "text", text: systemPrompt }],
+			apiHandler,
+		)
+
 		for (const msg of effectiveMessages) {
 			const content = msg.content
 			if (Array.isArray(content)) {
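
The hunk is clipped just as the per-message loop begins, but the core of the fix is visible: the accumulator now starts at the system prompt's token count instead of 0. Below is a minimal sketch of the whole corrected computation, assuming the clipped loop body sums per-message tokens with the same estimateTokenCount helper used in the tests above; the loop body and the string-content normalization here are assumptions for illustration, not the repository code.

// Sketch only: mirrors the corrected seeding, with an assumed loop body.
async function countPostTruncationTokens(
	systemPrompt: string,
	effectiveMessages: ApiMessage[],
	apiHandler: ApiHandler,
): Promise<number> {
	// Seed with the system prompt so the total matches the real API payload.
	let total = await estimateTokenCount([{ type: "text", text: systemPrompt }], apiHandler)
	for (const msg of effectiveMessages) {
		// Normalize plain-string content into a text block before counting (assumed).
		const blocks = Array.isArray(msg.content) ? msg.content : [{ type: "text" as const, text: msg.content ?? "" }]
		total += await estimateTokenCount(blocks, apiHandler)
	}
	return total
}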

+ 159 - 0
src/utils/__tests__/tiktoken.spec.ts

@@ -134,4 +134,163 @@ describe("tiktoken", () => {
 		// Both calls should return the same token count
 		expect(result1).toBe(result2)
 	})
+
+	describe("tool_use blocks", () => {
+		it("should count tokens for tool_use blocks with simple arguments", async () => {
+			const content = [
+				{
+					type: "tool_use",
+					id: "tool_123",
+					name: "read_file",
+					input: { path: "/src/main.ts" },
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should return a positive token count for the serialized tool call
+			expect(result).toBeGreaterThan(0)
+		})
+
+		it("should count tokens for tool_use blocks with complex arguments", async () => {
+			const content = [
+				{
+					type: "tool_use",
+					id: "tool_456",
+					name: "write_to_file",
+					input: {
+						path: "/src/components/Button.tsx",
+						content:
+							"import React from 'react';\n\nexport const Button = ({ children, onClick }) => {\n  return <button onClick={onClick}>{children}</button>;\n};",
+					},
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should return a token count reflecting the larger content
+			expect(result).toBeGreaterThan(10)
+		})
+
+		it("should handle tool_use blocks with empty input", async () => {
+			const content = [
+				{
+					type: "tool_use",
+					id: "tool_789",
+					name: "list_files",
+					input: {},
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should still count the tool name (and empty args)
+			expect(result).toBeGreaterThan(0)
+		})
+	})
+
+	describe("tool_result blocks", () => {
+		it("should count tokens for tool_result blocks with string content", async () => {
+			const content = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_123",
+					content: "File content: export const foo = 'bar';",
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should return a positive token count
+			expect(result).toBeGreaterThan(0)
+		})
+
+		it("should count tokens for tool_result blocks with array content", async () => {
+			const content = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_456",
+					content: [
+						{ type: "text", text: "First part of the result" },
+						{ type: "text", text: "Second part of the result" },
+					],
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should count tokens from all text parts
+			expect(result).toBeGreaterThan(0)
+		})
+
+		it("should count tokens for tool_result blocks with error flag", async () => {
+			const content = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_789",
+					is_error: true,
+					content: "Error: File not found",
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should include the error indicator and content
+			expect(result).toBeGreaterThan(0)
+		})
+
+		it("should handle tool_result blocks with image content in array", async () => {
+			const content = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_abc",
+					content: [
+						{ type: "text", text: "Screenshot captured" },
+						{ type: "image", source: { type: "base64", media_type: "image/png", data: "abc123" } },
+					],
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should count text and include placeholder for images
+			expect(result).toBeGreaterThan(0)
+		})
+	})
+
+	describe("mixed content with tools", () => {
+		it("should count tokens for conversation with tool_use and tool_result", async () => {
+			const content = [
+				{ type: "text", text: "Let me read that file for you." },
+				{
+					type: "tool_use",
+					id: "tool_123",
+					name: "read_file",
+					input: { path: "/src/index.ts" },
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const result = await tiktoken(content)
+			// Should count both text and tool_use tokens
+			expect(result).toBeGreaterThan(5)
+		})
+
+		it("should produce larger count for tool_result with large content vs small content", async () => {
+			const smallContent = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_1",
+					content: "OK",
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const largeContent = [
+				{
+					type: "tool_result",
+					tool_use_id: "tool_2",
+					content:
+						"This is a much longer result that contains a lot more text and should therefore have a significantly higher token count than the small content.",
+				},
+			] as Anthropic.Messages.ContentBlockParam[]
+
+			const smallResult = await tiktoken(smallContent)
+			const largeResult = await tiktoken(largeContent)
+
+			// Large content should have more tokens
+			expect(largeResult).toBeGreaterThan(smallResult)
+		})
+	})
 })
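
These cases also pin down why the display went near-zero: judging by the hunk below, tiktoken previously had branches only for text and image blocks, so tool_use and tool_result blocks fell through and contributed zero tokens. Illustratively (casts as in the tests above):

// Under the old implementation both of these apparently counted as 0 tokens;
// with the serializers below they count the serialized text instead.
await tiktoken([{ type: "tool_use", id: "t1", name: "read_file", input: { path: "a.ts" } }] as Anthropic.Messages.ContentBlockParam[])
await tiktoken([{ type: "tool_result", tool_use_id: "t1", content: "export {}" }] as Anthropic.Messages.ContentBlockParam[])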

+ 60 - 0
src/utils/tiktoken.ts

@@ -6,6 +6,52 @@ const TOKEN_FUDGE_FACTOR = 1.5
 
 let encoder: Tiktoken | null = null
 
+/**
+ * Serializes a tool_use block to text for token counting.
+ * Approximates how the API sees the tool call.
+ */
+function serializeToolUse(block: Anthropic.Messages.ToolUseBlockParam): string {
+	const parts = [`Tool: ${block.name}`]
+	if (block.input !== undefined) {
+		try {
+			parts.push(`Arguments: ${JSON.stringify(block.input)}`)
+		} catch {
+			parts.push(`Arguments: [serialization error]`)
+		}
+	}
+	return parts.join("\n")
+}
+
+/**
+ * Serializes a tool_result block to text for token counting.
+ * Handles both string content and array content.
+ */
+function serializeToolResult(block: Anthropic.Messages.ToolResultBlockParam): string {
+	const parts = [`Tool Result (${block.tool_use_id})`]
+
+	if (block.is_error) {
+		parts.push(`[Error]`)
+	}
+
+	const content = block.content
+	if (typeof content === "string") {
+		parts.push(content)
+	} else if (Array.isArray(content)) {
+		// Flatten each content block in the array to text
+		for (const item of content) {
+			if (item.type === "text") {
+				parts.push(item.text || "")
+			} else if (item.type === "image") {
+				parts.push("[Image content]")
+			} else {
+				parts.push(`[Unsupported content block: ${String((item as { type?: unknown }).type)}]`)
+			}
+		}
+	}
+
+	return parts.join("\n")
+}
+
 export async function tiktoken(content: Anthropic.Messages.ContentBlockParam[]): Promise<number> {
 	if (content.length === 0) {
 		return 0
@@ -37,6 +83,20 @@ export async function tiktoken(content: Anthropic.Messages.ContentBlockParam[]):
 			} else {
 				totalTokens += 300 // Conservative estimate for unknown images
 			}
+		} else if (block.type === "tool_use") {
+			// Serialize tool_use block to text and count tokens
+			const serialized = serializeToolUse(block as Anthropic.Messages.ToolUseBlockParam)
+			if (serialized.length > 0) {
+				const tokens = encoder.encode(serialized, undefined, [])
+				totalTokens += tokens.length
+			}
+		} else if (block.type === "tool_result") {
+			// Serialize tool_result block to text and count tokens
+			const serialized = serializeToolResult(block as Anthropic.Messages.ToolResultBlockParam)
+			if (serialized.length > 0) {
+				const tokens = encoder.encode(serialized, undefined, [])
+				totalTokens += tokens.length
+			}
 		}
 	}
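
For a concrete sense of what gets encoded, the serializers flatten each tool block to plain text before counting. An illustrative example; the serialized string follows directly from serializeToolUse above:

const toolUse = {
	type: "tool_use",
	id: "tool_123",
	name: "read_file",
	input: { path: "/src/main.ts" },
} as Anthropic.Messages.ToolUseBlockParam

// serializeToolUse(toolUse) produces:
//   "Tool: read_file\nArguments: {\"path\":\"/src/main.ts\"}"
// tiktoken([toolUse]) then encodes that string, so the count scales with the
// tool name plus the JSON-serialized arguments.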