Просмотр исходного кода

Merge pull request #1289 from RooVetGit/truncation_token_buffer

Add a 5k token buffer before the end of the context window
Matt Rubens 10 месяцев назад
Родитель
Сommit
8df6bdf0a7

+ 38 - 16
src/core/sliding-window/__tests__/sliding-window.test.ts

@@ -3,7 +3,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 
 import { ModelInfo } from "../../../shared/api"
-import { estimateTokenCount, truncateConversation, truncateConversationIfNeeded } from "../index"
+import { TOKEN_BUFFER, estimateTokenCount, truncateConversation, truncateConversationIfNeeded } from "../index"
 
 /**
  * Tests for the truncateConversation function
@@ -121,10 +121,10 @@ describe("getMaxTokens", () => {
 		// Create messages with very small content in the last one to avoid token overflow
 		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
 
-		// Below max tokens - no truncation
+		// Below max tokens and buffer - no truncation
 		const result1 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 49999,
+			totalTokens: 44999, // Well below threshold + buffer
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -133,7 +133,7 @@ describe("getMaxTokens", () => {
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 50001,
+			totalTokens: 50001, // Above threshold
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -148,10 +148,10 @@ describe("getMaxTokens", () => {
 		// Create messages with very small content in the last one to avoid token overflow
 		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
 
-		// Below max tokens - no truncation
+		// Below max tokens and buffer - no truncation
 		const result1 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 79999,
+			totalTokens: 74999, // Well below threshold + buffer
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -160,7 +160,7 @@ describe("getMaxTokens", () => {
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 80001,
+			totalTokens: 80001, // Above threshold
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -175,10 +175,10 @@ describe("getMaxTokens", () => {
 		// Create messages with very small content in the last one to avoid token overflow
 		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
 
-		// Below max tokens - no truncation
+		// Below max tokens and buffer - no truncation
 		const result1 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 39999,
+			totalTokens: 34999, // Well below threshold + buffer
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -187,7 +187,7 @@ describe("getMaxTokens", () => {
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 40001,
+			totalTokens: 40001, // Above threshold
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -202,10 +202,10 @@ describe("getMaxTokens", () => {
 		// Create messages with very small content in the last one to avoid token overflow
 		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
 
-		// Below max tokens - no truncation
+		// Below max tokens and buffer - no truncation
 		const result1 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 169999,
+			totalTokens: 164999, // Well below threshold + buffer
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -214,7 +214,7 @@ describe("getMaxTokens", () => {
 		// Above max tokens - truncate
 		const result2 = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
-			totalTokens: 170001,
+			totalTokens: 170001, // Above threshold
 			contextWindow: modelInfo.contextWindow,
 			maxTokens: modelInfo.maxTokens,
 		})
@@ -244,7 +244,7 @@ describe("truncateConversationIfNeeded", () => {
 	it("should not truncate if tokens are below max tokens threshold", () => {
 		const modelInfo = createModelInfo(100000, true, 30000)
 		const maxTokens = 100000 - 30000 // 70000
-		const totalTokens = 69999 // Below threshold
+		const totalTokens = 64999 // Well below threshold + buffer
 
 		// Create messages with very small content in the last one to avoid token overflow
 		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
@@ -337,8 +337,8 @@ describe("truncateConversationIfNeeded", () => {
 			{ role: messages[messages.length - 1].role, content: smallContent },
 		]
 
-		// Set base tokens so total is below threshold even with small content added
-		const baseTokensForSmall = availableTokens - smallContentTokens - 10
+		// Set base tokens so total is well below threshold + buffer even with small content added
+		const baseTokensForSmall = availableTokens - smallContentTokens - TOKEN_BUFFER - 10
 		const resultWithSmall = truncateConversationIfNeeded({
 			messages: messagesWithSmallContent,
 			totalTokens: baseTokensForSmall,
@@ -388,7 +388,29 @@ describe("truncateConversationIfNeeded", () => {
 		})
 		expect(resultWithVeryLarge).not.toEqual(messagesWithVeryLargeContent) // Should truncate
 	})
+
+	it("should truncate if tokens are within TOKEN_BUFFER of the threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		const maxTokens = 100000 - 30000 // 70000
+		const totalTokens = 66000 // Within 5000 of threshold (70000)
+
+		// Create messages with very small content in the last one to avoid token overflow
+		const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }]
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]]
+
+		const result = truncateConversationIfNeeded({
+			messages: messagesWithSmallContent,
+			totalTokens,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
+		expect(result).toEqual(expectedResult)
+	})
 })
+
 /**
  * Tests for the estimateTokenCount function
  */

+ 4 - 2
src/core/sliding-window/index.ts

@@ -3,7 +3,8 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { Tiktoken } from "js-tiktoken/lite"
 import o200kBase from "js-tiktoken/ranks/o200k_base"
 
-const TOKEN_FUDGE_FACTOR = 1.5
+export const TOKEN_FUDGE_FACTOR = 1.5
+export const TOKEN_BUFFER = 5000
 
 /**
  * Counts tokens for user content using tiktoken for text
@@ -110,5 +111,6 @@ export function truncateConversationIfNeeded({
 	const allowedTokens = contextWindow - reservedTokens
 
 	// Determine if truncation is needed and apply if necessary
-	return effectiveTokens < allowedTokens ? messages : truncateConversation(messages, 0.5)
+	// Truncate if we're within TOKEN_BUFFER of the limit
+	return effectiveTokens > allowedTokens - TOKEN_BUFFER ? truncateConversation(messages, 0.5) : messages
 }