@@ -5,6 +5,9 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ModelInfo } from "../../../shared/api"
 import { truncateConversation, truncateConversationIfNeeded } from "../index"
 
+/**
+ * Tests for the truncateConversation function
+ */
 describe("truncateConversation", () => {
 	it("should retain the first message", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
@@ -91,6 +94,97 @@ describe("truncateConversation", () => {
 	})
 })
 
+/**
+ * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded)
+ */
+describe("getMaxTokens", () => {
+	// We'll test this indirectly through truncateConversationIfNeeded
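+	// For reference, a sketch of the presumed buffer logic these tests pin down
+	// (an illustration inferred from the expectations below, not the actual
+	// implementation in ../index): reserve the model's maxTokens as an output
+	// buffer when it is defined, otherwise reserve 20% of the context window:
+	//   allowed = contextWindow - (maxTokens ?? contextWindow * 0.2)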
+	const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache: true, // Not relevant for getMaxTokens
+		maxTokens,
+	})
+
+	// Reuse the same messages across all tests for consistency
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should use maxTokens as buffer when specified", () => {
+		const modelInfo = createModelInfo(100000, 50000)
+		// Max tokens = 100000 - 50000 = 50000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should use 20% of context window as buffer when maxTokens is undefined", () => {
+		const modelInfo = createModelInfo(100000, undefined)
+		// Max tokens = 100000 - (100000 * 0.2) = 80000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle small context windows appropriately", () => {
+		const modelInfo = createModelInfo(50000, 10000)
+		// Max tokens = 50000 - 10000 = 40000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+
+	it("should handle large context windows appropriately", () => {
+		const modelInfo = createModelInfo(200000, 30000)
+		// Max tokens = 200000 - 30000 = 170000
+
+		// Below max tokens - no truncation
+		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		expect(result1).toEqual(messages)
+
+		// Above max tokens - truncate
+		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		expect(result2).not.toEqual(messages)
+		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
+	})
+})
+
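+// For reference, the presumed shape of the function under test (a sketch only,
+// inferred from the expectations in this file; the real code lives in ../index):
+//   if (totalTokens < contextWindow - (maxTokens ?? contextWindow * 0.2)) return messages
+//   return truncateConversation(messages, 0.5)
+// The tests probe 1 token either side, so exact-threshold behaviour is unspecified.
+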
+/**
+ * Tests for the truncateConversationIfNeeded function
+ */
 describe("truncateConversationIfNeeded", () => {
 	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
 		contextWindow,
@@ -106,25 +200,48 @@ describe("truncateConversationIfNeeded", () => {
 		{ role: "user", content: "Fifth message" },
 	]
 
-	it("should not truncate if tokens are below threshold for prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, true, 50000)
-		const totalTokens = 100000 // Below threshold
+	it("should not truncate if tokens are below max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		// Max tokens = 100000 - 30000 = 70000
+		const totalTokens = 69999 // Below threshold
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(messages) // No truncation occurs
 	})
 
-	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
-		const modelInfo = createModelInfo(200000, false)
-		const totalTokens = 100000 // Below threshold
+	it("should truncate if tokens are above max tokens threshold", () => {
+		const modelInfo = createModelInfo(100000, true, 30000)
+		// Max tokens = 100000 - 30000 = 70000
+		const totalTokens = 70001 // Above threshold
+
+		// When truncating, always uses 0.5 fraction
+		// With 4 messages after the first, 0.5 fraction means remove 2 messages
+		const expectedResult = [messages[0], messages[3], messages[4]]
+
 		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(messages)
+		expect(result).toEqual(expectedResult)
 	})
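+
+	// A sketch of the arithmetic behind the expectation above (assuming the
+	// truncation keeps the first message and drops half of the rest):
+	// Math.floor((5 - 1) * 0.5) = 2 messages removed, so messages[1] and
+	// messages[2] go, leaving [messages[0], messages[3], messages[4]].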
 
-	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
-		const modelInfo = createModelInfo(50000, true) // Small context window
-		const totalTokens = 40001 // Above 80% threshold (40000)
-		const mockResult = [messages[0], messages[3], messages[4]]
-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
-		expect(result).toEqual(mockResult)
+	it("should work with non-prompt caching models the same as prompt caching models", () => {
+		// The implementation no longer differentiates between prompt caching and non-prompt caching models
+		const modelInfo1 = createModelInfo(100000, true, 30000)
+		const modelInfo2 = createModelInfo(100000, false, 30000)
+
+		// Test below threshold
+		const belowThreshold = 69999
+		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		)
+
+		// Test above threshold
+		const aboveThreshold = 70001
+		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
+			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		)
 	})
 })