Sfoglia il codice sorgente

Improve sliding window algorithm to not break prompt cache so often

Saoud Rizwan 1 anno fa
parent
commit
a160e8d67b
2 ha cambiato i file con 49 aggiunte e 61 eliminazioni
  1. 11 3
      src/ClaudeDev.ts
  2. 38 58
      src/utils/context-management.ts

+ 11 - 3
src/ClaudeDev.ts

@@ -25,7 +25,7 @@ import { HistoryItem } from "./shared/HistoryItem"
 import { combineApiRequests } from "./shared/combineApiRequests"
 import { combineCommandSequences } from "./shared/combineCommandSequences"
 import { findLastIndex } from "./utils"
-import { slidingWindowContextManagement } from "./utils/context-management"
+import { isWithinContextWindow, truncateHalfConversation } from "./utils/context-management"
 
 const SYSTEM_PROMPT =
 	() => `You are Claude Dev, a highly skilled software developer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
@@ -1253,13 +1253,21 @@ The following additional instructions are provided by the user. They should be f
 ${this.customInstructions.trim()}
 `
 			}
-			const adjustedMessages = slidingWindowContextManagement(
+			const isPromptWithinContextWindow = isWithinContextWindow(
 				this.api.getModel().info.contextWindow,
+				systemPrompt,
+				tools,
+				this.apiConversationHistory
+			)
+			if (!isPromptWithinContextWindow) {
+				const truncatedMessages = truncateHalfConversation(this.apiConversationHistory)
+				await this.overwriteApiConversationHistory(truncatedMessages)
+			}
+			const { message, userCredits } = await this.api.createMessage(
 				systemPrompt,
 				this.apiConversationHistory,
 				tools
 			)
-			const { message, userCredits } = await this.api.createMessage(systemPrompt, adjustedMessages, tools)
 			if (userCredits !== undefined) {
 				console.log("Updating credits", userCredits)
 				// TODO: update credits

+ 38 - 58
src/utils/context-management.ts

@@ -2,74 +2,54 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { countTokens } from "@anthropic-ai/tokenizer"
 import { Buffer } from "buffer"
 import sizeOf from "image-size"
-import cloneDeep from "clone-deep"
 
-export function slidingWindowContextManagement(
+export function isWithinContextWindow(
 	contextWindow: number,
 	systemPrompt: string,
-	messages: Anthropic.Messages.MessageParam[],
-	tools: Anthropic.Messages.Tool[]
-): Anthropic.Messages.MessageParam[] {
+	tools: Anthropic.Messages.Tool[],
+	messages: Anthropic.Messages.MessageParam[]
+): boolean {
 	const adjustedContextWindow = contextWindow - 10_000 // Buffer to account for tokenizer differences
+	// counting tokens is expensive, so we first try to estimate before doing a more accurate calculation
+	const estimatedTotalMessageTokens = countTokens(systemPrompt + JSON.stringify(tools) + JSON.stringify(messages))
+	if (estimatedTotalMessageTokens <= adjustedContextWindow) {
+		return true
+	}
 	const systemPromptTokens = countTokens(systemPrompt)
 	const toolsTokens = countTokens(JSON.stringify(tools))
 	let availableTokens = adjustedContextWindow - systemPromptTokens - toolsTokens
-	let totalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
+	let accurateTotalMessageTokens = messages.reduce((sum, message) => sum + countMessageTokens(message), 0)
+	return accurateTotalMessageTokens <= availableTokens
+}
 
-	if (totalMessageTokens <= availableTokens) {
-		return messages
-	}
+/*
+We can't implement a dynamically updating sliding window as it would break prompt cache
+every time. To maintain the benefits of caching, we need to keep conversation history
+static. This operation should be performed as infrequently as possible. If a user reaches
+a 200k context, we can assume that the first half is likely irrelevant to their current task.
+Therefore, this function should only be called when absolutely necessary to fit within
+context limits, not as a continuous process.
+*/
+export function truncateHalfConversation(
+	messages: Anthropic.Messages.MessageParam[]
+): Anthropic.Messages.MessageParam[] {
+	// Anthropic expects messages to be in user-assistant order, and tool use messages must be followed by tool results. We need to maintain this structure while truncating.
 
-	// If over limit, remove messages starting from the third message onwards (task and claude's step-by-step thought process are important to keep in context)
-	const newMessages = cloneDeep(messages) // since we're manipulating nested objects and arrays, need to deep clone to prevent mutating original history
-	let index = 2
-	while (totalMessageTokens > availableTokens && index < newMessages.length) {
-		const messageToEmpty = newMessages[index]
-		const originalTokens = countMessageTokens(messageToEmpty)
-		// Empty the content of the message (messages must be in a specific order so we can't just remove)
-		if (typeof messageToEmpty.content === "string") {
-			messageToEmpty.content = "(truncated due to context limits)"
-		} else if (Array.isArray(messageToEmpty.content)) {
-			messageToEmpty.content = messageToEmpty.content.map((item) => {
-				if (typeof item === "string") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "text") {
-					return {
-						type: "text",
-						text: "(truncated due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "image") {
-					return {
-						type: "text",
-						text: "(image removed due to context limits)",
-					} as Anthropic.Messages.TextBlockParam
-				} else if (item.type === "tool_use") {
-					return { ...item, input: {} } as Anthropic.Messages.ToolUseBlockParam
-				} else if (item.type === "tool_result") {
-					return {
-						...item,
-						content: Array.isArray(item.content)
-							? item.content.map((contentItem) =>
-									contentItem.type === "text"
-										? { type: "text", text: "(truncated due to context limits)" }
-										: contentItem.type === "image"
-										? { type: "text", text: "(image removed due to context limits)" }
-										: contentItem
-							  )
-							: "(truncated due to context limits)",
-					} as Anthropic.Messages.ToolResultBlockParam
-				}
-				return item
-			})
-		}
-		const newTokens = countMessageTokens(messageToEmpty)
-		totalMessageTokens -= originalTokens - newTokens
-		index++
+	// Keep the first Task message (likely the most important)
+	const truncatedMessages = [messages[0]]
+
+	// Remove half of user-assistant pairs
+	const messagesToRemove = Math.floor(messages.length / 4) * 2 // has to be even number
+	const summaryMessage: Anthropic.Messages.MessageParam = {
+		role: "assistant",
+		content: `(${messagesToRemove} messages were truncated to fit within context limits)`,
 	}
-	return newMessages
+	truncatedMessages.push(summaryMessage)
+
+	const remainingMessages = messages.slice(messagesToRemove)
+	truncatedMessages.push(...remainingMessages)
+
+	return truncatedMessages
 }
 
 function countMessageTokens(message: Anthropic.Messages.MessageParam): number {