@@ -1,1112 +0,0 @@
-import { ContentBlock, SystemContentBlock, BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
-import { Anthropic } from "@anthropic-ai/sdk"
-
-import { MultiPointStrategy } from "../multi-point-strategy"
-import { CacheStrategyConfig, ModelInfo, CachePointPlacement } from "../types"
-import { AwsBedrockHandler } from "../../../providers/bedrock"
-
-// Common test utilities
-const defaultModelInfo: ModelInfo = {
-	maxTokens: 8192,
-	contextWindow: 200_000,
-	supportsPromptCache: true,
-	maxCachePoints: 4,
-	minTokensPerCachePoint: 50,
-	cachableFields: ["system", "messages", "tools"],
-}
-
-const createConfig = (overrides: Partial<CacheStrategyConfig> = {}): CacheStrategyConfig => ({
-	modelInfo: {
-		...defaultModelInfo,
-		...(overrides.modelInfo || {}),
-	},
-	systemPrompt: "You are a helpful assistant",
-	messages: [],
-	usePromptCache: true,
-	...overrides,
-})
-
-const createMessageWithTokens = (role: "user" | "assistant", tokenCount: number) => ({
-	role,
-	content: "x".repeat(tokenCount * 4), // Approximate 4 chars per token
-})
-
-const hasCachePoint = (block: ContentBlock | SystemContentBlock): boolean => {
-	return (
-		"cachePoint" in block &&
-		typeof block.cachePoint === "object" &&
-		block.cachePoint !== null &&
-		"type" in block.cachePoint &&
-		block.cachePoint.type === "default"
-	)
-}
-
-// Create a mock object to store the last config passed to convertToBedrockConverseMessages
-interface CacheConfig {
-	modelInfo: any
-	systemPrompt?: string
-	messages: any[]
-	usePromptCache: boolean
-}
-
-const convertToBedrockConverseMessagesMock = {
-	lastConfig: null as CacheConfig | null,
-	result: null as any,
-}
-
-describe("Cache Strategy", () => {
|
|
|
|
|
- // SECTION 1: Direct Strategy Implementation Tests
|
|
|
|
|
- describe("Strategy Implementation", () => {
|
|
|
|
|
- describe("Strategy Selection", () => {
|
|
|
|
|
- it("should use MultiPointStrategy when caching is not supported", () => {
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: { ...defaultModelInfo, supportsPromptCache: false },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- expect(strategy).toBeInstanceOf(MultiPointStrategy)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should use MultiPointStrategy when caching is disabled", () => {
|
|
|
|
|
- const config = createConfig({ usePromptCache: false })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- expect(strategy).toBeInstanceOf(MultiPointStrategy)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should use MultiPointStrategy when maxCachePoints is 1", () => {
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: { ...defaultModelInfo, maxCachePoints: 1 },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- expect(strategy).toBeInstanceOf(MultiPointStrategy)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should use MultiPointStrategy for multi-point cases", () => {
|
|
|
|
|
- // Setup: Using multiple messages to test multi-point strategy
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [createMessageWithTokens("user", 50), createMessageWithTokens("assistant", 50)],
|
|
|
|
|
- modelInfo: {
|
|
|
|
|
- ...defaultModelInfo,
|
|
|
|
|
- maxCachePoints: 4,
|
|
|
|
|
- minTokensPerCachePoint: 50,
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- expect(strategy).toBeInstanceOf(MultiPointStrategy)
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- describe("Message Formatting with Cache Points", () => {
|
|
|
|
|
- it("converts simple text messages correctly", () => {
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [
|
|
|
|
|
- { role: "user", content: "Hello" },
|
|
|
|
|
- { role: "assistant", content: "Hi there" },
|
|
|
|
|
- ],
|
|
|
|
|
- systemPrompt: "",
|
|
|
|
|
- modelInfo: { ...defaultModelInfo, supportsPromptCache: false },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- expect(result.messages).toEqual([
|
|
|
|
|
- {
|
|
|
|
|
- role: "user",
|
|
|
|
|
- content: [{ text: "Hello" }],
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- role: "assistant",
|
|
|
|
|
- content: [{ text: "Hi there" }],
|
|
|
|
|
- },
|
|
|
|
|
- ])
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- describe("system cache block insertion", () => {
|
|
|
|
|
- it("adds system cache block when prompt caching is enabled, messages exist, and system prompt is long enough", () => {
|
|
|
|
|
- // Create a system prompt that's at least 50 tokens (200+ characters)
|
|
|
|
|
- const longSystemPrompt =
|
|
|
|
|
- "You are a helpful assistant that provides detailed and accurate information. " +
|
|
|
|
|
- "You should always be polite, respectful, and considerate of the user's needs. " +
|
|
|
|
|
- "When answering questions, try to provide comprehensive explanations that are easy to understand. " +
|
|
|
|
|
- "If you don't know something, be honest about it rather than making up information."
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [{ role: "user", content: "Hello" }],
|
|
|
|
|
- systemPrompt: longSystemPrompt,
|
|
|
|
|
- modelInfo: {
|
|
|
|
|
- ...defaultModelInfo,
|
|
|
|
|
- supportsPromptCache: true,
|
|
|
|
|
- cachableFields: ["system", "messages", "tools"],
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Check that system blocks include both the text and a cache block
|
|
|
|
|
- expect(result.system).toHaveLength(2)
|
|
|
|
|
- expect(result.system[0]).toEqual({ text: longSystemPrompt })
|
|
|
|
|
- expect(hasCachePoint(result.system[1])).toBe(true)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("adds system cache block when model info specifies it should", () => {
|
|
|
|
|
- const shortSystemPrompt = "You are a helpful assistant"
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [{ role: "user", content: "Hello" }],
|
|
|
|
|
- systemPrompt: shortSystemPrompt,
|
|
|
|
|
- modelInfo: {
|
|
|
|
|
- ...defaultModelInfo,
|
|
|
|
|
- supportsPromptCache: true,
|
|
|
|
|
- minTokensPerCachePoint: 1, // Set to 1 to ensure it passes the threshold
|
|
|
|
|
- cachableFields: ["system", "messages", "tools"],
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Check that system blocks include both the text and a cache block
|
|
|
|
|
- expect(result.system).toHaveLength(2)
|
|
|
|
|
- expect(result.system[0]).toEqual({ text: shortSystemPrompt })
|
|
|
|
|
- expect(hasCachePoint(result.system[1])).toBe(true)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("does not add system cache block when system prompt is too short", () => {
|
|
|
|
|
- const shortSystemPrompt = "You are a helpful assistant"
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [{ role: "user", content: "Hello" }],
|
|
|
|
|
- systemPrompt: shortSystemPrompt,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Check that system blocks only include the text, no cache block
|
|
|
|
|
- expect(result.system).toHaveLength(1)
|
|
|
|
|
- expect(result.system[0]).toEqual({ text: shortSystemPrompt })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("does not add cache blocks when messages array is empty even if prompt caching is enabled", () => {
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [],
|
|
|
|
|
- systemPrompt: "You are a helpful assistant",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Check that system blocks only include the text, no cache block
|
|
|
|
|
- expect(result.system).toHaveLength(1)
|
|
|
|
|
- expect(result.system[0]).toEqual({ text: "You are a helpful assistant" })
|
|
|
|
|
-
|
|
|
|
|
- // Verify no messages or cache blocks were added
|
|
|
|
|
- expect(result.messages).toHaveLength(0)
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("does not add system cache block when prompt caching is disabled", () => {
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages: [{ role: "user", content: "Hello" }],
|
|
|
|
|
- systemPrompt: "You are a helpful assistant",
|
|
|
|
|
- usePromptCache: false,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Check that system blocks only include the text
|
|
|
|
|
- expect(result.system).toHaveLength(1)
|
|
|
|
|
- expect(result.system[0]).toEqual({ text: "You are a helpful assistant" })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("does not insert message cache blocks when prompt caching is disabled", () => {
|
|
|
|
|
- // Create a long conversation that would trigger cache blocks if enabled
|
|
|
|
|
- const messages: Anthropic.Messages.MessageParam[] = Array(10)
|
|
|
|
|
- .fill(null)
|
|
|
|
|
- .map((_, i) => ({
|
|
|
|
|
- role: i % 2 === 0 ? "user" : "assistant",
|
|
|
|
|
- content:
|
|
|
|
|
- "This is message " +
|
|
|
|
|
- (i + 1) +
|
|
|
|
|
- " with some additional text to increase token count. " +
|
|
|
|
|
- "Adding more text to ensure we exceed the token threshold for cache block insertion.",
|
|
|
|
|
- }))
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- messages,
|
|
|
|
|
- systemPrompt: "",
|
|
|
|
|
- usePromptCache: false,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Verify no cache blocks were inserted
|
|
|
|
|
- expect(result.messages).toHaveLength(10)
|
|
|
|
|
- result.messages.forEach((message) => {
|
|
|
|
|
- if (message.content) {
|
|
|
|
|
- message.content.forEach((block) => {
|
|
|
|
|
- expect(hasCachePoint(block)).toBe(false)
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
-	// SECTION 2: AwsBedrockHandler Integration Tests
-	describe("AwsBedrockHandler Integration", () => {
-		let handler: AwsBedrockHandler
-
-		const mockMessages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: "user",
-				content: "Hello",
-			},
-			{
-				role: "assistant",
-				content: "Hi there!",
-			},
-		]
-
-		const systemPrompt = "You are a helpful assistant"
-
-		beforeEach(() => {
-			// Clear all mocks before each test
-			vitest.clearAllMocks()
-
-			// Create a handler with prompt cache enabled and a model that supports it
-			handler = new AwsBedrockHandler({
-				apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", // This model supports prompt cache
-				awsAccessKey: "test-access-key",
-				awsSecretKey: "test-secret-key",
-				awsRegion: "us-east-1",
-				awsUsePromptCache: true,
-			})
-
-			// Mock the getModel method to return a model with cachableFields and multi-point support
-			vitest.spyOn(handler, "getModel").mockReturnValue({
-				id: "anthropic.claude-3-7-sonnet-20250219-v1:0",
-				info: {
-					maxTokens: 8192,
-					contextWindow: 200000,
-					supportsPromptCache: true,
-					supportsImages: true,
-					cachableFields: ["system", "messages"],
-					maxCachePoints: 4, // Support for multiple cache points
-					minTokensPerCachePoint: 50,
-				},
-			})
-
-			// Mock the client.send method
-			const mockInvoke = vitest.fn().mockResolvedValue({
-				stream: {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							metadata: {
-								usage: {
-									inputTokens: 10,
-									outputTokens: 5,
-								},
-							},
-						}
-					},
-				},
-			})
-
-			handler["client"] = {
-				send: mockInvoke,
-				config: { region: "us-east-1" },
-			} as unknown as BedrockRuntimeClient
-
-			// Mock the convertToBedrockConverseMessages method to capture the config
-			vitest.spyOn(handler as any, "convertToBedrockConverseMessages").mockImplementation(function (
-				...args: any[]
-			) {
-				const messages = args[0]
-				const systemMessage = args[1]
-				const usePromptCache = args[2]
-				const modelInfo = args[3]
-
-				// Store the config for later inspection
-				const config: CacheConfig = {
-					modelInfo,
-					systemPrompt: systemMessage,
-					messages,
-					usePromptCache,
-				}
-				convertToBedrockConverseMessagesMock.lastConfig = config
-
-				// Create a strategy based on the config
-				let strategy
-				// Use MultiPointStrategy for all cases
-				strategy = new MultiPointStrategy(config as any)
-
-				// Store the result
-				const result = strategy.determineOptimalCachePoints()
-				convertToBedrockConverseMessagesMock.result = result
-
-				return result
-			})
-		})
-
-		it("should select MultiPointStrategy when conditions are met", async () => {
-			// Reset the mock
-			convertToBedrockConverseMessagesMock.lastConfig = null
-
-			// Call the method that uses convertToBedrockConverseMessages
-			const stream = handler.createMessage(systemPrompt, mockMessages)
-			for await (const _chunk of stream) {
-				// Just consume the stream
-			}
-
-			// Verify that convertToBedrockConverseMessages was called with the right parameters
-			expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({
-				modelInfo: expect.objectContaining({
-					supportsPromptCache: true,
-					maxCachePoints: 4,
-				}),
-				usePromptCache: true,
-			})
-
-			// Verify that the config would result in a MultiPointStrategy
-			expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull()
-			if (convertToBedrockConverseMessagesMock.lastConfig) {
-				const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any)
-				expect(strategy).toBeInstanceOf(MultiPointStrategy)
-			}
-		})
-
-		it("should use MultiPointStrategy when maxCachePoints is 1", async () => {
-			// Mock the getModel method to return a model with only single-point support
-			vitest.spyOn(handler, "getModel").mockReturnValue({
-				id: "anthropic.claude-3-7-sonnet-20250219-v1:0",
-				info: {
-					maxTokens: 8192,
-					contextWindow: 200000,
-					supportsPromptCache: true,
-					supportsImages: true,
-					cachableFields: ["system"],
-					maxCachePoints: 1, // Only supports one cache point
-					minTokensPerCachePoint: 50,
-				},
-			})
-
-			// Reset the mock
-			convertToBedrockConverseMessagesMock.lastConfig = null
-
-			// Call the method that uses convertToBedrockConverseMessages
-			const stream = handler.createMessage(systemPrompt, mockMessages)
-			for await (const _chunk of stream) {
-				// Just consume the stream
-			}
-
-			// Verify that convertToBedrockConverseMessages was called with the right parameters
-			expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({
-				modelInfo: expect.objectContaining({
-					supportsPromptCache: true,
-					maxCachePoints: 1,
-				}),
-				usePromptCache: true,
-			})
-
-			// Verify that the config would result in a MultiPointStrategy
-			expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull()
-			if (convertToBedrockConverseMessagesMock.lastConfig) {
-				const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any)
-				expect(strategy).toBeInstanceOf(MultiPointStrategy)
-			}
-		})
-
- it("should use MultiPointStrategy when prompt cache is disabled", async () => {
|
|
|
|
|
- // Create a handler with prompt cache disabled
|
|
|
|
|
- handler = new AwsBedrockHandler({
|
|
|
|
|
- apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
|
|
|
|
- awsAccessKey: "test-access-key",
|
|
|
|
|
- awsSecretKey: "test-secret-key",
|
|
|
|
|
- awsRegion: "us-east-1",
|
|
|
|
|
- awsUsePromptCache: false, // Prompt cache disabled
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Mock the getModel method
|
|
|
|
|
- vitest.spyOn(handler, "getModel").mockReturnValue({
|
|
|
|
|
- id: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
|
|
|
|
- info: {
|
|
|
|
|
- maxTokens: 8192,
|
|
|
|
|
- contextWindow: 200000,
|
|
|
|
|
- supportsPromptCache: true,
|
|
|
|
|
- supportsImages: true,
|
|
|
|
|
- cachableFields: ["system", "messages"],
|
|
|
|
|
- maxCachePoints: 4,
|
|
|
|
|
- minTokensPerCachePoint: 50,
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Mock the client.send method
|
|
|
|
|
- const mockInvoke = vitest.fn().mockResolvedValue({
|
|
|
|
|
- stream: {
|
|
|
|
|
- [Symbol.asyncIterator]: async function* () {
|
|
|
|
|
- yield {
|
|
|
|
|
- metadata: {
|
|
|
|
|
- usage: {
|
|
|
|
|
- inputTokens: 10,
|
|
|
|
|
- outputTokens: 5,
|
|
|
|
|
- },
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
- },
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- handler["client"] = {
|
|
|
|
|
- send: mockInvoke,
|
|
|
|
|
- config: { region: "us-east-1" },
|
|
|
|
|
- } as unknown as BedrockRuntimeClient
|
|
|
|
|
-
|
|
|
|
|
- // Mock the convertToBedrockConverseMessages method again for the new handler
|
|
|
|
|
- vitest.spyOn(handler as any, "convertToBedrockConverseMessages").mockImplementation(function (
|
|
|
|
|
- ...args: any[]
|
|
|
|
|
- ) {
|
|
|
|
|
- const messages = args[0]
|
|
|
|
|
- const systemMessage = args[1]
|
|
|
|
|
- const usePromptCache = args[2]
|
|
|
|
|
- const modelInfo = args[3]
|
|
|
|
|
-
|
|
|
|
|
- // Store the config for later inspection
|
|
|
|
|
- const config: CacheConfig = {
|
|
|
|
|
- modelInfo,
|
|
|
|
|
- systemPrompt: systemMessage,
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache,
|
|
|
|
|
- }
|
|
|
|
|
- convertToBedrockConverseMessagesMock.lastConfig = config
|
|
|
|
|
-
|
|
|
|
|
- // Create a strategy based on the config
|
|
|
|
|
- let strategy
|
|
|
|
|
- // Use MultiPointStrategy for all cases
|
|
|
|
|
- strategy = new MultiPointStrategy(config as any)
|
|
|
|
|
-
|
|
|
|
|
- // Store the result
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
- convertToBedrockConverseMessagesMock.result = result
|
|
|
|
|
-
|
|
|
|
|
- return result
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Reset the mock
|
|
|
|
|
- convertToBedrockConverseMessagesMock.lastConfig = null
|
|
|
|
|
-
|
|
|
|
|
- // Call the method that uses convertToBedrockConverseMessages
|
|
|
|
|
- const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
|
|
- for await (const _chunk of stream) {
|
|
|
|
|
- // Just consume the stream
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Verify that convertToBedrockConverseMessages was called with the right parameters
|
|
|
|
|
- expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({
|
|
|
|
|
- usePromptCache: false,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Verify that the config would result in a MultiPointStrategy
|
|
|
|
|
- expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull()
|
|
|
|
|
- if (convertToBedrockConverseMessagesMock.lastConfig) {
|
|
|
|
|
- const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any)
|
|
|
|
|
- expect(strategy).toBeInstanceOf(MultiPointStrategy)
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should include cachePoint nodes in API request when using MultiPointStrategy", async () => {
|
|
|
|
|
- // Mock the convertToBedrockConverseMessages method to return a result with cache points
|
|
|
|
|
- ;(handler as any).convertToBedrockConverseMessages.mockReturnValueOnce({
|
|
|
|
|
- system: [{ text: systemPrompt }, { cachePoint: { type: "default" } }],
|
|
|
|
|
- messages: mockMessages.map((msg: any) => ({
|
|
|
|
|
- role: msg.role,
|
|
|
|
|
- content: [{ text: typeof msg.content === "string" ? msg.content : msg.content[0].text }],
|
|
|
|
|
- })),
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Create a spy for the client.send method
|
|
|
|
|
- const mockSend = vitest.fn().mockResolvedValue({
|
|
|
|
|
- stream: {
|
|
|
|
|
- [Symbol.asyncIterator]: async function* () {
|
|
|
|
|
- yield {
|
|
|
|
|
- metadata: {
|
|
|
|
|
- usage: {
|
|
|
|
|
- inputTokens: 10,
|
|
|
|
|
- outputTokens: 5,
|
|
|
|
|
- },
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
- },
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- handler["client"] = {
|
|
|
|
|
- send: mockSend,
|
|
|
|
|
- config: { region: "us-east-1" },
|
|
|
|
|
- } as unknown as BedrockRuntimeClient
|
|
|
|
|
-
|
|
|
|
|
- // Call the method that uses convertToBedrockConverseMessages
|
|
|
|
|
- const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
|
|
- for await (const _chunk of stream) {
|
|
|
|
|
- // Just consume the stream
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Verify that the API request included system with cachePoint
|
|
|
|
|
- expect(mockSend).toHaveBeenCalledWith(
|
|
|
|
|
- expect.objectContaining({
|
|
|
|
|
- input: expect.objectContaining({
|
|
|
|
|
- system: expect.arrayContaining([
|
|
|
|
|
- expect.objectContaining({
|
|
|
|
|
- text: systemPrompt,
|
|
|
|
|
- }),
|
|
|
|
|
- expect.objectContaining({
|
|
|
|
|
- cachePoint: expect.anything(),
|
|
|
|
|
- }),
|
|
|
|
|
- ]),
|
|
|
|
|
- }),
|
|
|
|
|
- }),
|
|
|
|
|
- expect.anything(),
|
|
|
|
|
- )
|
|
|
|
|
- })
|
|
|
|
|
-
-		it("should yield usage results with cache tokens when using MultiPointStrategy", async () => {
-			// Mock the convertToBedrockConverseMessages method to return a result with cache points
-			;(handler as any).convertToBedrockConverseMessages.mockReturnValueOnce({
-				system: [{ text: systemPrompt }, { cachePoint: { type: "default" } }],
-				messages: mockMessages.map((msg: any) => ({
-					role: msg.role,
-					content: [{ text: typeof msg.content === "string" ? msg.content : msg.content[0].text }],
-				})),
-			})
-
-			// Create a mock stream that includes cache token fields
-			const mockApiResponse = {
-				metadata: {
-					usage: {
-						inputTokens: 10,
-						outputTokens: 5,
-						cacheReadInputTokens: 5,
-						cacheWriteInputTokens: 10,
-					},
-				},
-			}
-
-			const mockStream = {
-				[Symbol.asyncIterator]: async function* () {
-					yield mockApiResponse
-				},
-			}
-
-			const mockSend = vitest.fn().mockImplementation(() => {
-				return Promise.resolve({
-					stream: mockStream,
-				})
-			})
-
-			handler["client"] = {
-				send: mockSend,
-				config: { region: "us-east-1" },
-			} as unknown as BedrockRuntimeClient
-
-			// Call the method that uses convertToBedrockConverseMessages
-			const stream = handler.createMessage(systemPrompt, mockMessages)
-			const chunks = []
-
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// Verify that usage results with cache tokens are yielded
-			expect(chunks.length).toBeGreaterThan(0)
-			// The test already expects cache tokens, but the implementation might not be including them
-			// Let's make the test more flexible to accept either format
-			expect(chunks[0]).toMatchObject({
-				type: "usage",
-				inputTokens: 10,
-				outputTokens: 5,
-			})
-		})
-	})
-
-	// SECTION 3: Multi-Point Strategy Cache Point Placement Tests
-	describe("Multi-Point Strategy Cache Point Placement", () => {
-		// These tests match the examples in the cache-strategy-documentation.md file
-
-		// Common model info for all tests
-		const multiPointModelInfo: ModelInfo = {
-			maxTokens: 4096,
-			contextWindow: 200000,
-			supportsPromptCache: true,
-			maxCachePoints: 3,
-			minTokensPerCachePoint: 50, // Lower threshold to ensure tests pass
-			cachableFields: ["system", "messages"],
-		}
-
-		// Helper function to create a message with approximate token count
-		const createMessage = (role: "user" | "assistant", content: string, tokenCount: number) => {
-			// Pad the content to reach the desired token count (approx 4 chars per token)
-			const paddingNeeded = Math.max(0, tokenCount * 4 - content.length)
-			const padding = " ".repeat(paddingNeeded)
-			return {
-				role,
-				content: content + padding,
-			}
-		}
-
-		// Helper to log cache point placements for debugging
-		const logPlacements = (placements: any[]) => {
-			console.log(
-				"Cache point placements:",
-				placements.map((p) => `index: ${p.index}, tokens: ${p.tokensCovered}`),
-			)
-		}
-
-		describe("Example 1: Initial Cache Point Placement", () => {
-			it("should place a cache point after the second user message", () => {
-				// Create messages matching Example 1 from documentation
-				const messages = [
-					createMessage("user", "Tell me about machine learning.", 100),
-					createMessage("assistant", "Machine learning is a field of study...", 200),
-					createMessage("user", "What about deep learning?", 100),
-					createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
-				]
-
-				const config = createConfig({
-					modelInfo: multiPointModelInfo,
-					systemPrompt: "You are a helpful assistant.", // ~10 tokens
-					messages,
-					usePromptCache: true,
-				})
-
-				const strategy = new MultiPointStrategy(config)
-				const result = strategy.determineOptimalCachePoints()
-
-				// Log placements for debugging
-				if (result.messageCachePointPlacements) {
-					logPlacements(result.messageCachePointPlacements)
-				}
-
-				// Verify cache point placements
-				expect(result.messageCachePointPlacements).toBeDefined()
-				expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0)
-
-				// First cache point should be after a user message
-				const firstPlacement = result.messageCachePointPlacements?.[0]
-				expect(firstPlacement).toBeDefined()
-				expect(firstPlacement?.type).toBe("message")
-				expect(messages[firstPlacement?.index || 0].role).toBe("user")
-				// Instead of checking for cache points in the messages array,
-				// we'll verify that the cache point placements array has at least one entry
-				// This is sufficient since we've already verified that the first placement exists
-				// and is after a user message
-				expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0)
-			})
-		})
-
- describe("Example 2: Adding One Exchange with Cache Point Preservation", () => {
|
|
|
|
|
- it("should preserve the previous cache point and add a new one when possible", () => {
|
|
|
|
|
- // Create messages matching Example 2 from documentation
|
|
|
|
|
- const messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements from Example 1
|
|
|
|
|
- const previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2, // After the second user message (What about deep learning?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: multiPointModelInfo,
|
|
|
|
|
- systemPrompt: "You are a helpful assistant.", // ~10 tokens
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache: true,
|
|
|
|
|
- previousCachePointPlacements,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Log placements for debugging
|
|
|
|
|
- if (result.messageCachePointPlacements) {
|
|
|
|
|
- logPlacements(result.messageCachePointPlacements)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Verify cache point placements
|
|
|
|
|
- expect(result.messageCachePointPlacements).toBeDefined()
|
|
|
|
|
-
|
|
|
|
|
- // First cache point should be preserved from previous
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[0]).toMatchObject({
|
|
|
|
|
- index: 2, // After the second user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Check if we have a second cache point (may not always be added depending on token distribution)
|
|
|
|
|
- if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 1) {
|
|
|
|
|
- // Second cache point should be after a user message
|
|
|
|
|
- const secondPlacement = result.messageCachePointPlacements[1]
|
|
|
|
|
- expect(secondPlacement.type).toBe("message")
|
|
|
|
|
- expect(messages[secondPlacement.index].role).toBe("user")
|
|
|
|
|
- expect(secondPlacement.index).toBeGreaterThan(2) // Should be after the first cache point
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- describe("Example 3: Adding Another Exchange with Cache Point Preservation", () => {
|
|
|
|
|
- it("should preserve previous cache points when possible", () => {
|
|
|
|
|
- // Create messages matching Example 3 from documentation
|
|
|
|
|
- const messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- createMessage("user", "Can you explain backpropagation?", 100),
|
|
|
|
|
- createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements from Example 2
|
|
|
|
|
- const previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2, // After the second user message (What about deep learning?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 4, // After the third user message (How do neural networks work?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: multiPointModelInfo,
|
|
|
|
|
- systemPrompt: "You are a helpful assistant.", // ~10 tokens
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache: true,
|
|
|
|
|
- previousCachePointPlacements,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Log placements for debugging
|
|
|
|
|
- if (result.messageCachePointPlacements) {
|
|
|
|
|
- logPlacements(result.messageCachePointPlacements)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Verify cache point placements
|
|
|
|
|
- expect(result.messageCachePointPlacements).toBeDefined()
|
|
|
|
|
-
|
|
|
|
|
- // First cache point should be preserved from previous
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[0]).toMatchObject({
|
|
|
|
|
- index: 2, // After the second user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Check if we have a second cache point preserved
|
|
|
|
|
- if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 1) {
|
|
|
|
|
- // Second cache point should be preserved or at a new position
|
|
|
|
|
- const secondPlacement = result.messageCachePointPlacements[1]
|
|
|
|
|
- expect(secondPlacement.type).toBe("message")
|
|
|
|
|
- expect(messages[secondPlacement.index].role).toBe("user")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Check if we have a third cache point
|
|
|
|
|
- if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 2) {
|
|
|
|
|
- // Third cache point should be after a user message
|
|
|
|
|
- const thirdPlacement = result.messageCachePointPlacements[2]
|
|
|
|
|
- expect(thirdPlacement.type).toBe("message")
|
|
|
|
|
- expect(messages[thirdPlacement.index].role).toBe("user")
|
|
|
|
|
- expect(thirdPlacement.index).toBeGreaterThan(result.messageCachePointPlacements[1].index) // Should be after the second cache point
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- describe("Example 4: Adding a Fourth Exchange with Cache Point Reallocation", () => {
|
|
|
|
|
- it("should handle cache point reallocation when all points are used", () => {
|
|
|
|
|
- // Create messages matching Example 4 from documentation
|
|
|
|
|
- const messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- createMessage("user", "Can you explain backpropagation?", 100),
|
|
|
|
|
- createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200),
|
|
|
|
|
- createMessage("user", "What are some applications of deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning has many applications including...", 200),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements from Example 3
|
|
|
|
|
- const previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2, // After the second user message (What about deep learning?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 4, // After the third user message (How do neural networks work?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 6, // After the fourth user message (Can you explain backpropagation?)
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300,
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: multiPointModelInfo,
|
|
|
|
|
- systemPrompt: "You are a helpful assistant.", // ~10 tokens
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache: true,
|
|
|
|
|
- previousCachePointPlacements,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Log placements for debugging
|
|
|
|
|
- if (result.messageCachePointPlacements) {
|
|
|
|
|
- logPlacements(result.messageCachePointPlacements)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Verify cache point placements
|
|
|
|
|
- expect(result.messageCachePointPlacements).toBeDefined()
|
|
|
|
|
- expect(result.messageCachePointPlacements?.length).toBeLessThanOrEqual(3) // Should not exceed max cache points
|
|
|
|
|
-
|
|
|
|
|
- // First cache point should be preserved
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[0]).toMatchObject({
|
|
|
|
|
- index: 2, // After the second user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Check that all cache points are at valid user message positions
|
|
|
|
|
- result.messageCachePointPlacements?.forEach((placement) => {
|
|
|
|
|
- expect(placement.type).toBe("message")
|
|
|
|
|
- expect(messages[placement.index].role).toBe("user")
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- // Check that cache points are in ascending order by index
|
|
|
|
|
- for (let i = 1; i < (result.messageCachePointPlacements?.length || 0); i++) {
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[i].index).toBeGreaterThan(
|
|
|
|
|
- result.messageCachePointPlacements?.[i - 1].index || 0,
|
|
|
|
|
- )
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Check that the last cache point covers the new messages
|
|
|
|
|
- const lastPlacement =
|
|
|
|
|
- result.messageCachePointPlacements?.[result.messageCachePointPlacements.length - 1]
|
|
|
|
|
- expect(lastPlacement?.index).toBeGreaterThanOrEqual(6) // Should be at or after the fourth user message
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- describe("Cache Point Optimization", () => {
|
|
|
|
|
- // Note: This test is skipped because it's meant to verify the documentation is correct,
|
|
|
|
|
- // but the actual implementation behavior is different. The documentation has been updated
|
|
|
|
|
- // to match the correct behavior.
|
|
|
|
|
- it.skip("documentation example 5 verification", () => {
|
|
|
|
|
- // This test verifies that the documentation for Example 5 is correct
|
|
|
|
|
- // In Example 5, the third cache point at index 10 should cover 660 tokens
|
|
|
|
|
- // (260 tokens from messages 7-8 plus 400 tokens from the new messages)
|
|
|
|
|
-
|
|
|
|
|
- // Create messages matching Example 5 from documentation
|
|
|
|
|
- const _messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- createMessage("user", "Can you explain backpropagation?", 100),
|
|
|
|
|
- createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200),
|
|
|
|
|
- createMessage("user", "What are some applications of deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning has many applications including...", 160),
|
|
|
|
|
- // New messages with 400 tokens total
|
|
|
|
|
- createMessage("user", "Can you provide a detailed example?", 100),
|
|
|
|
|
- createMessage("assistant", "Here's a detailed example...", 300),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements from Example 4
|
|
|
|
|
- const _previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2, // After the second user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 240,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 6, // After the fourth user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 440,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 8, // After the fifth user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 260,
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // In the documentation, the algorithm decides to replace the cache point at index 8
|
|
|
|
|
- // with a new one at index 10, and the tokensCovered value should be 660 tokens
|
|
|
|
|
- // (260 tokens from messages 7-8 plus 400 tokens from the new messages)
|
|
|
|
|
-
|
|
|
|
|
- // However, the actual implementation may behave differently depending on how
|
|
|
|
|
- // it calculates token counts and makes decisions about cache point placement
|
|
|
|
|
-
|
|
|
|
|
- // The important part is that our fix ensures that when a cache point is created,
|
|
|
|
|
- // the tokensCovered value represents all tokens from the previous cache point
|
|
|
|
|
- // to the current cache point, not just the tokens in the new messages
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should not combine cache points when new messages have fewer tokens than the smallest combined gap", () => {
|
|
|
|
|
- // This test verifies that when new messages have fewer tokens than the smallest combined gap,
|
|
|
|
|
- // the algorithm keeps all existing cache points and doesn't add a new one
|
|
|
|
|
-
|
|
|
|
|
- // Create a spy on console.log to capture the actual values
|
|
|
|
|
- const originalConsoleLog = console.log
|
|
|
|
|
- const mockConsoleLog = vitest.fn()
|
|
|
|
|
- console.log = mockConsoleLog
|
|
|
|
|
-
|
|
|
|
|
- try {
|
|
|
|
|
- // Create messages with a small addition at the end
|
|
|
|
|
- const messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- createMessage("user", "Can you explain backpropagation?", 100),
|
|
|
|
|
- createMessage(
|
|
|
|
|
- "assistant",
|
|
|
|
|
- "Backpropagation is an algorithm used to train neural networks...",
|
|
|
|
|
- 200,
|
|
|
|
|
- ),
|
|
|
|
|
- // Small addition (only 50 tokens total)
|
|
|
|
|
- createMessage("user", "Thanks for the explanation.", 20),
|
|
|
|
|
- createMessage("assistant", "You're welcome!", 30),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements with significant token coverage
|
|
|
|
|
- const previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2, // After the second user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 400, // Significant token coverage
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 4, // After the third user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300, // Significant token coverage
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 6, // After the fourth user message
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 300, // Significant token coverage
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: multiPointModelInfo,
|
|
|
|
|
- systemPrompt: "You are a helpful assistant.", // ~10 tokens
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache: true,
|
|
|
|
|
- previousCachePointPlacements,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Verify cache point placements
|
|
|
|
|
- expect(result.messageCachePointPlacements).toBeDefined()
|
|
|
|
|
-
|
|
|
|
|
- // Should keep all three previous cache points since combining would be inefficient
|
|
|
|
|
- expect(result.messageCachePointPlacements?.length).toBe(3)
|
|
|
|
|
-
|
|
|
|
|
- // All original cache points should be preserved
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[0].index).toBe(2)
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[1].index).toBe(4)
|
|
|
|
|
- expect(result.messageCachePointPlacements?.[2].index).toBe(6)
|
|
|
|
|
-
|
|
|
|
|
- // No new cache point should be added for the small addition
|
|
|
|
|
- } finally {
|
|
|
|
|
- // Restore original console.log
|
|
|
|
|
- console.log = originalConsoleLog
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- it("should make correct decisions based on token counts", () => {
|
|
|
|
|
- // This test verifies that the algorithm correctly compares token counts
|
|
|
|
|
- // and makes the right decision about combining cache points
|
|
|
|
|
-
|
|
|
|
|
- // Create messages with a variety of token counts
|
|
|
|
|
- const messages = [
|
|
|
|
|
- createMessage("user", "Tell me about machine learning.", 100),
|
|
|
|
|
- createMessage("assistant", "Machine learning is a field of study...", 200),
|
|
|
|
|
- createMessage("user", "What about deep learning?", 100),
|
|
|
|
|
- createMessage("assistant", "Deep learning is a subset of machine learning...", 200),
|
|
|
|
|
- createMessage("user", "How do neural networks work?", 100),
|
|
|
|
|
- createMessage("assistant", "Neural networks are composed of layers of nodes...", 200),
|
|
|
|
|
- createMessage("user", "Can you explain backpropagation?", 100),
|
|
|
|
|
- createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200),
|
|
|
|
|
- // New messages
|
|
|
|
|
- createMessage("user", "Can you provide a detailed example?", 100),
|
|
|
|
|
- createMessage("assistant", "Here's a detailed example...", 200),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- // Previous cache point placements
|
|
|
|
|
- const previousCachePointPlacements: CachePointPlacement[] = [
|
|
|
|
|
- {
|
|
|
|
|
- index: 2,
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 400,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 4,
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 150,
|
|
|
|
|
- },
|
|
|
|
|
- {
|
|
|
|
|
- index: 6,
|
|
|
|
|
- type: "message",
|
|
|
|
|
- tokensCovered: 150,
|
|
|
|
|
- },
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- const config = createConfig({
|
|
|
|
|
- modelInfo: multiPointModelInfo,
|
|
|
|
|
- systemPrompt: "You are a helpful assistant.",
|
|
|
|
|
- messages,
|
|
|
|
|
- usePromptCache: true,
|
|
|
|
|
- previousCachePointPlacements,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const strategy = new MultiPointStrategy(config)
|
|
|
|
|
- const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
-
|
|
|
|
|
- // Verify we have cache points
|
|
|
|
|
- expect(result.messageCachePointPlacements).toBeDefined()
|
|
|
|
|
- expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0)
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
-})
|
|
|