Merge pull request #1194 from RooVetGit/fix_context_window_truncation_math

Fix context window truncation math
Matt Rubens 10 months ago
parent commit 46576e00aa

+ 2 - 2
src/api/providers/glama.ts

@@ -69,7 +69,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 		let maxTokens: number | undefined
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			maxTokens = 8_192
+			maxTokens = this.getModel().info.maxTokens
 		}
 
 		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
@@ -177,7 +177,7 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			if (this.getModel().id.startsWith("anthropic/")) {
-				requestOptions.max_tokens = 8192
+				requestOptions.max_tokens = this.getModel().info.maxTokens
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
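
The hard-coded 8_192 cap is gone: the request's max token count now follows whatever the Glama model metadata reports. A minimal sketch of the pattern, with a simplified ModelInfo shape and an illustrative helper name (resolveMaxTokens is not part of the codebase):

	// Sketch: take the cap from model metadata for Anthropic-hosted models,
	// leave it undefined otherwise so the provider default applies.
	interface ModelInfoSketch {
		contextWindow: number
		maxTokens?: number
	}

	function resolveMaxTokens(modelId: string, info: ModelInfoSketch): number | undefined {
		return modelId.startsWith("anthropic/") ? info.maxTokens : undefined
	}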

+ 3 - 32
src/api/providers/openrouter.ts

@@ -54,20 +54,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-			case "anthropic/claude-3-haiku":
-			case "anthropic/claude-3-haiku:beta":
-			case "anthropic/claude-3-opus":
-			case "anthropic/claude-3-opus:beta":
+		switch (true) {
+			case this.getModel().id.startsWith("anthropic/"):
 				openAiMessages[0] = {
 					role: "system",
 					content: [
@@ -103,23 +91,6 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				break
 		}
 
-		// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
-		// (models usually default to max tokens allowed)
-		let maxTokens: number | undefined
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-				maxTokens = 8_192
-				break
-		}
-
 		let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
 		let topP: number | undefined = undefined
 
@@ -140,7 +111,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		let fullResponseText = ""
 		const stream = await this.client.chat.completions.create({
 			model: this.getModel().id,
-			max_tokens: maxTokens,
+			max_tokens: this.getModel().info.maxTokens,
 			temperature: this.options.modelTemperature ?? defaultTemperature,
 			top_p: topP,
 			messages: openAiMessages,
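
Two simplifications here: the long per-id case lists collapse into a prefix check, and max_tokens is taken from the model metadata instead of a local switch. The switch (true) form evaluates each case expression as a boolean, which is what lets startsWith drive the branch; written as a plain if, the new prompt-caching branch reads like this (sketch with an example id):

	// Equivalent shape of the new branch, shown outside the class for illustration.
	const modelId = "anthropic/claude-3.5-sonnet" // example id
	if (modelId.startsWith("anthropic/")) {
		// mark the system prompt and recent user turns as cacheable
	}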

+ 2 - 2
src/api/providers/unbound.ts

@@ -71,7 +71,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 		let maxTokens: number | undefined
 
 		if (this.getModel().id.startsWith("anthropic/")) {
-			maxTokens = 8_192
+			maxTokens = this.getModel().info.maxTokens
 		}
 
 		const { data: completion, response } = await this.client.chat.completions
@@ -150,7 +150,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			if (this.getModel().id.startsWith("anthropic/")) {
-				requestOptions.max_tokens = 8192
+				requestOptions.max_tokens = this.getModel().info.maxTokens
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
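
unbound.ts now follows the same rule as glama.ts and openrouter.ts: read the cap from the model metadata whenever the id carries the anthropic/ prefix. A hypothetical shared helper (not part of this change; the name is illustrative) would capture the duplicated check in one place:

	// Hypothetical helper: the Anthropic max-tokens rule used by all three providers.
	function anthropicMaxTokens(modelId: string, info: { maxTokens?: number }): number | undefined {
		return modelId.startsWith("anthropic/") ? info.maxTokens : undefined
	}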

+ 115 - 14
src/core/sliding-window/__tests__/sliding-window.test.ts

@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"
 
+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
 	it("should retain the first message", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,86 @@ describe("truncateConversation", () => {
 	})
 })
 
+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+	// We'll test this indirectly through truncateConversationIfNeeded
+	const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache: true, // Not relevant for getMaxTokens
+		maxTokens,
+	})
+
+	// Reuse across tests for consistency
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should use maxTokens as buffer when specified", () => {
+		const modelInfo = createModelInfo(100000, 50000)
+		// Max tokens = 100000 - 50000 = 50000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+		const modelInfo = createModelInfo(100000, undefined)
+		// Max tokens = 100000 - (100000 * 0.2) = 80000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle small context windows appropriately", () => {
+		const modelInfo = createModelInfo(50000, 10000)
+		// Max tokens = 50000 - 10000 = 40000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle large context windows appropriately", () => {
+		const modelInfo = createModelInfo(200000, 30000)
+		// Max tokens = 200000 - 30000 = 170000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+})
+
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
 	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
 		contextWindow,
@@ -106,25 +189,43 @@ describe("truncateConversationIfNeeded", () => {
 		{ role: "user", content: "Fifth message" },
 	]
 
-	it("should not truncate if tokens are below threshold for prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, true, 50000)
-		const totalTokens = 100000 // Below threshold
+	it("should not truncate if tokens are below max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 69999 // Below threshold
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(messages) // No truncation occurs
 	})
 
-	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, false)
-		const totalTokens = 100000 // Below threshold
+	it("should truncate if tokens are above max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 70001 // Above threshold
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messages[0], messages[3], messages[4]]
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(expectedResult)
 	})
 
-	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-		const modelInfo = createModelInfo(50000, true) // Small context window
-		const totalTokens = 40001 // Above 80% threshold (40000)
-		const mockResult = [messages[0], messages[3], messages[4]]
-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(mockResult)
+	it("should work with non-prompt caching models the same as prompt caching models", () => {
+		// The implementation no longer differentiates between prompt caching and non-prompt caching models
+		const modelInfo1 = createModelInfo(100000, true, 30000)
+		const modelInfo2 = createModelInfo(100000, false, 30000)
+
+		// Test below threshold
+		const belowThreshold = 69999
+		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		)
+
+		// Test above threshold
+		const aboveThreshold = 70001
+		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		)
 	})
 })
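
The expected lengths in these tests follow from truncateConversation itself: the first message is always retained and the 0.5 fraction is applied to the remaining four messages, so two are dropped from the front of that range. Judging by the assertions above, the surviving conversation is:

	// Worked example for the length-3 assertions: 5 messages, fraction 0.5.
	const kept = [messages[0], messages[3], messages[4]] // messages[1] and messages[2] are removed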

+ 7 - 56
src/core/sliding-window/index.ts

@@ -28,13 +28,9 @@ export function truncateConversation(
 /**
  * Conditionally truncates the conversation messages if the total token count exceeds the model's limit.
  *
- * Depending on whether the model supports prompt caching, different maximum token thresholds
- * and truncation fractions are used. If the current total tokens exceed the threshold,
- * the conversation is truncated using the appropriate fraction.
- *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
  * @param {number} totalTokens - The total number of tokens in the conversation.
- * @param {ModelInfo} modelInfo - Model metadata including context window size and prompt cache support.
+ * @param {ModelInfo} modelInfo - Model metadata including context window size.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
  */
 export function truncateConversationIfNeeded(
@@ -42,61 +38,16 @@ export function truncateConversationIfNeeded(
 	totalTokens: number,
 	modelInfo: ModelInfo,
 ): Anthropic.Messages.MessageParam[] {
-	if (modelInfo.supportsPromptCache) {
-		return totalTokens < getMaxTokensForPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForPromptCachingModels(modelInfo))
-	} else {
-		return totalTokens < getMaxTokensForNonPromptCachingModels(modelInfo)
-			? messages
-			: truncateConversation(messages, getTruncFractionForNonPromptCachingModels(modelInfo))
-	}
+	return totalTokens < getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
 }
 
 /**
- * Calculates the maximum allowed tokens for models that support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - buffer) and 80% of the contextWindow.
+ * Calculates the maximum allowed tokens
  *
  * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for prompt caching models.
- */
-function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information (unused in current implementation).
- * @returns {number} The truncation fraction for prompt caching models (fixed at 0.5).
- */
-function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
-	return 0.5
-}
-
-/**
- * Calculates the maximum allowed tokens for models that do not support prompt caching.
- *
- * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow.
- *
- * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
- */
-function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
-	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
-	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
-}
-
-/**
- * Provides the fraction of messages to remove for models that do not support prompt caching.
- *
- * @param {ModelInfo} modelInfo - The model information.
- * @returns {number} The truncation fraction for non-prompt caching models (fixed at 0.1).
+ * @returns {number} The maximum number of tokens allowed
  */
-function getTruncFractionForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.min(40_000 / modelInfo.contextWindow, 0.2)
+function getMaxTokens(modelInfo: ModelInfo): number {
+	// The buffer is `modelInfo.maxTokens`, or 20% of the context window if for some reason it's not set.
+	return modelInfo.contextWindow - (modelInfo.maxTokens || modelInfo.contextWindow * 0.2)
 }
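
A worked example of the new cutoff, using the claude-3-7-sonnet numbers from shared/api.ts below: contextWindow 200_000 with maxTokens 16_384 gives 200_000 - 16_384 = 183_616 tokens of conversation before truncation; with maxTokens unset, the 20% fallback gives 200_000 - 40_000 = 160_000. A sketch of how the cutoff is derived and used (simplified ModelInfo, illustrative values):

	// Sketch: below the cutoff the messages pass through, otherwise half of the
	// non-first messages are dropped via truncateConversation(messages, 0.5).
	const info = { contextWindow: 200_000, maxTokens: 16_384 }
	const cutoff = info.contextWindow - (info.maxTokens || info.contextWindow * 0.2) // 183_616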

+ 56 - 21
src/core/webview/ClineProvider.ts

@@ -1926,6 +1926,17 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						cacheReadsPrice: parsePrice(rawModel.cached_price),
 					}
 
+					switch (true) {
+						case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+							modelInfo.maxTokens = 16384
+							break
+						case rawModel.id.startsWith("anthropic/"):
+							modelInfo.maxTokens = 8192
+							break
+						default:
+							break
+					}
+
 					models[rawModel.id] = modelInfo
 				}
 			} else {
@@ -2076,6 +2087,17 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						cacheReadsPrice: parsePrice(rawModel.pricePerToken?.cacheRead),
 					}
 
+					switch (true) {
+						case rawModel.id.startsWith("anthropic/claude-3-7-sonnet"):
+							modelInfo.maxTokens = 16384
+							break
+						case rawModel.id.startsWith("anthropic/"):
+							modelInfo.maxTokens = 8192
+							break
+						default:
+							break
+					}
+
 					models[rawModel.id] = modelInfo
 				}
 			} else {
@@ -2127,46 +2149,46 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						description: rawModel.description,
 					}
 
-					switch (rawModel.id) {
-						case "anthropic/claude-3.7-sonnet":
-						case "anthropic/claude-3.7-sonnet:beta":
-						case "anthropic/claude-3.5-sonnet":
-						case "anthropic/claude-3.5-sonnet:beta":
-							// NOTE: this needs to be synced with api.ts/openrouter default model info.
+					// NOTE: this needs to be synced with api.ts/openrouter default model info.
+					switch (true) {
+						case rawModel.id.startsWith("anthropic/claude-3.7-sonnet"):
 							modelInfo.supportsComputerUse = true
 							modelInfo.supportsPromptCache = true
 							modelInfo.cacheWritesPrice = 3.75
 							modelInfo.cacheReadsPrice = 0.3
+							modelInfo.maxTokens = 16384
 							break
-						case "anthropic/claude-3.5-sonnet-20240620":
-						case "anthropic/claude-3.5-sonnet-20240620:beta":
+						case rawModel.id.startsWith("anthropic/claude-3.5-sonnet-20240620"):
 							modelInfo.supportsPromptCache = true
 							modelInfo.cacheWritesPrice = 3.75
 							modelInfo.cacheReadsPrice = 0.3
+							modelInfo.maxTokens = 8192
 							break
-						case "anthropic/claude-3-5-haiku":
-						case "anthropic/claude-3-5-haiku:beta":
-						case "anthropic/claude-3-5-haiku-20241022":
-						case "anthropic/claude-3-5-haiku-20241022:beta":
-						case "anthropic/claude-3.5-haiku":
-						case "anthropic/claude-3.5-haiku:beta":
-						case "anthropic/claude-3.5-haiku-20241022":
-						case "anthropic/claude-3.5-haiku-20241022:beta":
+						case rawModel.id.startsWith("anthropic/claude-3.5-sonnet"):
+							modelInfo.supportsComputerUse = true
+							modelInfo.supportsPromptCache = true
+							modelInfo.cacheWritesPrice = 3.75
+							modelInfo.cacheReadsPrice = 0.3
+							modelInfo.maxTokens = 8192
+							break
+						case rawModel.id.startsWith("anthropic/claude-3-5-haiku"):
 							modelInfo.supportsPromptCache = true
 							modelInfo.cacheWritesPrice = 1.25
 							modelInfo.cacheReadsPrice = 0.1
+							modelInfo.maxTokens = 8192
 							break
-						case "anthropic/claude-3-opus":
-						case "anthropic/claude-3-opus:beta":
+						case rawModel.id.startsWith("anthropic/claude-3-opus"):
 							modelInfo.supportsPromptCache = true
 							modelInfo.cacheWritesPrice = 18.75
 							modelInfo.cacheReadsPrice = 1.5
+							modelInfo.maxTokens = 8192
 							break
-						case "anthropic/claude-3-haiku":
-						case "anthropic/claude-3-haiku:beta":
+						case rawModel.id.startsWith("anthropic/claude-3-haiku"):
+						default:
 							modelInfo.supportsPromptCache = true
 							modelInfo.cacheWritesPrice = 0.3
 							modelInfo.cacheReadsPrice = 0.03
+							modelInfo.maxTokens = 8192
 							break
 					}
 
@@ -2200,7 +2222,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			if (response.data) {
 				const rawModels: Record<string, any> = response.data
 				for (const [modelId, model] of Object.entries(rawModels)) {
-					models[modelId] = {
+					const modelInfo: ModelInfo = {
 						maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined,
 						contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
 						supportsImages: model?.supportsImages ?? false,
@@ -2211,6 +2233,19 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
 						cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined,
 					}
+
+					switch (true) {
+						case modelId.startsWith("anthropic/claude-3-7-sonnet"):
+							modelInfo.maxTokens = 16384
+							break
+						case modelId.startsWith("anthropic/"):
+							modelInfo.maxTokens = 8192
+							break
+						default:
+							break
+					}
+
+					models[modelId] = modelInfo
 				}
 			}
 			await fs.writeFile(unboundModelsFilePath, JSON.stringify(models))
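
All three model registries now assign maxTokens by prefix rather than by exact id, so newly published dated variants pick up a default automatically. Order matters: the more specific claude-3-7-sonnet prefix has to be tested before the generic anthropic/ prefix, because switch (true) takes the first case whose expression is true. A minimal sketch of the pattern:

	// Sketch of the prefix-ordered defaulting used above (values as in the diff).
	function defaultAnthropicMaxTokens(modelId: string): number | undefined {
		switch (true) {
			case modelId.startsWith("anthropic/claude-3-7-sonnet"):
				return 16384
			case modelId.startsWith("anthropic/"):
				return 8192
			default:
				return undefined
		}
	}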

+ 1 - 1
src/shared/api.ts

@@ -97,7 +97,7 @@ export type AnthropicModelId = keyof typeof anthropicModels
 export const anthropicDefaultModelId: AnthropicModelId = "claude-3-7-sonnet-20250219"
 export const anthropicModels = {
 	"claude-3-7-sonnet-20250219": {
-		maxTokens: 64_000,
+		maxTokens: 16384,
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsComputerUse: true,
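
With the sliding-window change above, this constant also acts as the context-window buffer: reserving 16384 output tokens against the 200_000-token window leaves 200_000 - 16384 = 183_616 tokens of conversation before truncateConversationIfNeeded starts dropping messages, whereas the previous 64_000 value would have left only 136_000.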