Kaynağa Gözat

Simplify the context truncation math

Matt Rubens 10 ay önce
ebeveyn
işleme
d8cafbc67e

+ 115 - 14
src/core/sliding-window/__tests__/sliding-window.test.ts

@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"
 
+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
 	it("should retain the first message", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
 	})
 })
 
+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+	// We'll test this indirectly through truncateConversationIfNeeded
+	const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache: true, // Not relevant for getMaxTokens
+		maxTokens,
+	})
+
+	// Reuse across tests for consistency
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should use maxTokens as buffer when specified", () => {
+		const modelInfo = createModelInfo(100000, 50000)
+		// Max tokens = 100000 - 50000 = 50000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+		const modelInfo = createModelInfo(100000, undefined)
+		// Max tokens = 100000 - (100000 * 0.2) = 80000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle small context windows appropriately", () => {
+		const modelInfo = createModelInfo(50000, 10000)
+		// Max tokens = 50000 - 10000 = 40000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle large context windows appropriately", () => {
+		const modelInfo = createModelInfo(200000, 30000)
+		// Max tokens = 200000 - 30000 = 170000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
 	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
 		contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
 		{ role: "user", content: "Fifth message" },
 	]
 
-	it("should not truncate if tokens are below threshold for prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, true, 50000)
-		const totalTokens = 100000 // Below threshold
+	it("should not truncate if tokens are below max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 69999 // Below threshold
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(messages) // No truncation occurs
 	})
 
-	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, false)
-		const totalTokens = 100000 // Below threshold
+	it("should truncate if tokens are above max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 70001 // Above threshold
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messages[0], messages[3], messages[4]]
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(expectedResult)
 	})
 
-	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-		const modelInfo = createModelInfo(50000, true) // Small context window
-		const totalTokens = 40001 // Above 80% threshold (40000)
-		const mockResult = [messages[0], messages[3], messages[4]]
-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(mockResult)
+	it("should work with non-prompt caching models the same as prompt caching models", () => {
+		// The implementation no longer differentiates between prompt caching and non-prompt caching models
+		const modelInfo1 = createModelInfo(100000, true, 30000)
+		const modelInfo2 = createModelInfo(100000, false, 30000)
+
+		// Test below threshold
+		const belowThreshold = 69999
+		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		)
+
+		// Test above threshold
+		const aboveThreshold = 70001
+		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		)
 	})
 })

+ 7 - 56
src/core/sliding-window/index.ts

@@ -28,13 +28,9 @@ export function truncateConversation(
 /**
  * Conditionally truncates the conversation messages if the total token count exceeds the model's limit.
  *
- * Depending on whether the model supports prompt caching, different maximum token thresholds
- * and truncation fractions are used. If the current total tokens exceed the threshold,
- * the conversation is truncated using the appropriate fraction.
- *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
  * @param {number} totalTokens - The total number of tokens in the conversation.
- * @param {ModelInfo} modelInfo - Model metadata including context window size and prompt cache support.
+ * @param {ModelInfo} modelInfo - Model metadata including context window size.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
  */
 export function truncateConversationIfNeeded(
@@ -42,61 +38,16 @@ export function truncateConversationIfNeeded(
 	totalTokens: number,
 	modelInfo: ModelInfo,
 ): Anthropic.Messages.MessageParam[] {
-	if (modelInfo.supportsPromptCache) {
-		return totalTokens < getMaxTokensForPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForPromptCachingModels(modelInfo))
-	} else {
-		return totalTokens < getMaxTokensForNonPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForNonPromptCachingModels(modelInfo))
-	}
+	return totalTokens < getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
 }
 
 /**
- * Calculates the maximum allowed tokens for models that support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - buffer) and 80% of the contextWindow.
+ * Calculates the maximum allowed tokens
  *
  * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for prompt caching models.
- */
-function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information (unused in current implementation).
- * @returns {number} The truncation fraction for prompt caching models (fixed at 0.5).
- */
-function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
-	return 0.5
-}
-
-/**
- * Calculates the maximum allowed tokens for models that do not support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow.
- *
- * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
- */
-function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that do not support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information.
- * @returns {number} The truncation fraction for non-prompt caching models (fixed at 0.1).
+ * @returns {number} The maximum number of tokens allowed
  */
-function getTruncFractionForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.min(40_000 / modelInfo.contextWindow, 0.2)
+function getMaxTokens(modelInfo: ModelInfo): number {
+	// The buffer needs to be at least as large as `modelInfo.maxTokens`, or 20% of the context window if for some reason it's not set.
+	return modelInfo.contextWindow - Math.max(modelInfo.maxTokens || modelInfo.contextWindow * 0.2)
 }