Gemini caching fixes (#3096)

Chris Estreich · 8 months ago · commit 06d8dd2bcf

+ 5 - 19
src/api/providers/anthropic-vertex.ts

@@ -3,13 +3,14 @@ import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { GoogleAuth, JWTInput } from "google-auth-library"
 
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
 import { safeJsonParse } from "../../shared/safeJsonParse"
 
+import { ApiStream } from "../transform/stream"
+import { addCacheBreakpoints } from "../transform/caching/vertex"
+
 import { getModelParams, SingleCompletionHandler } from "../index"
-import { BaseProvider } from "./base-provider"
 import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
-import { formatMessageForCache } from "../transform/vertex-caching"
+import { BaseProvider } from "./base-provider"
 
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
 export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
@@ -57,16 +58,6 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 			thinking,
 		} = this.getModel()
 
-		// Find indices of user messages that we want to cache
-		// We only cache the last two user messages to stay within the 4-block limit
-		// (1 block for system + 1 block each for last two user messages = 3 total)
-		const userMsgIndices = supportsPromptCache
-			? messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])
-			: []
-
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
 		/**
 		 * Vertex API has specific limitations for prompt caching:
 		 * 1. Maximum of 4 blocks can have cache_control
@@ -89,12 +80,7 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 			system: supportsPromptCache
 				? [{ text: systemPrompt, type: "text" as const, cache_control: { type: "ephemeral" } }]
 				: systemPrompt,
-			messages: messages.map((message, index) => {
-				// Only cache the last two user messages.
-				const shouldCache =
-					supportsPromptCache && (index === lastUserMsgIndex || index === secondLastMsgUserIndex)
-				return formatMessageForCache(message, shouldCache)
-			}),
+			messages: supportsPromptCache ? addCacheBreakpoints(messages) : messages,
 			stream: true,
 		}
 
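The index bookkeeping removed above moves into addCacheBreakpoints in src/api/transform/caching/vertex.ts, added later in this commit. As a sketch, the cache-block budget that strategy preserves, restating the arithmetic from the removed comment (the 4-block ceiling is the Vertex cache_control limit cited in the code):

const systemBlocks = 1 // system prompt always gets cache_control when caching is supported
const userBlocks = 2 // one breakpoint on each of the last two user messages
const totalBlocks = systemBlocks + userBlocks // = 3 of the 4 allowed blocks; one block of headroom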

+ 4 - 2
src/api/providers/glama.ts

@@ -3,9 +3,11 @@ import axios from "axios"
 import OpenAI from "openai"
 
 import { ApiHandlerOptions, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
+
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { addCacheControlDirectives } from "../transform/caching"
+import { addCacheBreakpoints } from "../transform/caching/anthropic"
+
 import { SingleCompletionHandler } from "../index"
 import { RouterProvider } from "./router-provider"
 
@@ -37,7 +39,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHand
 		]
 
 		if (modelId.startsWith("anthropic/claude-3")) {
-			addCacheControlDirectives(systemPrompt, openAiMessages)
+			addCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// Required by Anthropic; other providers default to max tokens allowed.

+ 7 - 35
src/api/providers/openrouter.ts

@@ -11,9 +11,12 @@ import {
 	OPTIONAL_PROMPT_CACHING_MODELS,
 	REASONING_MODELS,
 } from "../../shared/api"
+
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStreamChunk } from "../transform/stream"
 import { convertToR1Format } from "../transform/r1-format"
+import { addCacheBreakpoints as addAnthropicCacheBreakpoints } from "../transform/caching/anthropic"
+import { addCacheBreakpoints as addGeminiCacheBreakpoints } from "../transform/caching/gemini"
 
 import { getModelParams, SingleCompletionHandler } from "../index"
 import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
@@ -93,42 +96,11 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 
 		const isCacheAvailable = promptCache.supported && (!promptCache.optional || this.options.promptCachingEnabled)
 
-		// Prompt caching: https://openrouter.ai/docs/prompt-caching
-		// Now with Gemini support: https://openrouter.ai/docs/features/prompt-caching
-		// Note that we don't check the `ModelInfo` object because it is cached
-		// in the settings for OpenRouter and the value could be stale.
+		// https://openrouter.ai/docs/features/prompt-caching
 		if (isCacheAvailable) {
-			openAiMessages[0] = {
-				role: "system",
-				// @ts-ignore-next-line
-				content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
-			}
-
-			// Add cache_control to the last two user messages
-			// (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message)
-			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
-
-			lastTwoUserMessages.forEach((msg) => {
-				if (typeof msg.content === "string") {
-					msg.content = [{ type: "text", text: msg.content }]
-				}
-
-				if (Array.isArray(msg.content)) {
-					// NOTE: This is fine since env details will always be added
-					// at the end. But if it wasn't there, and the user added a
-					// image_url type message, it would pop a text part before
-					// it and then move it after to the end.
-					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
-
-					if (!lastTextPart) {
-						lastTextPart = { type: "text", text: "..." }
-						msg.content.push(lastTextPart)
-					}
-
-					// @ts-ignore-next-line
-					lastTextPart["cache_control"] = { type: "ephemeral" }
-				}
-			})
+			modelId.startsWith("google")
+				? addGeminiCacheBreakpoints(systemPrompt, openAiMessages)
+				: addAnthropicCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// https://openrouter.ai/docs/transforms
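
The 30-line inline caching block above collapses to a two-way dispatch on the model's vendor prefix. A minimal sketch of the routing rule; the model IDs are illustrative OpenRouter slugs, not taken from this commit:

function cacheStrategyFor(modelId: string): "gemini" | "anthropic" {
	return modelId.startsWith("google") ? "gemini" : "anthropic"
}

console.log(cacheStrategyFor("google/gemini-2.5-pro-preview")) // "gemini" -> addGeminiCacheBreakpoints
console.log(cacheStrategyFor("anthropic/claude-3.7-sonnet")) // "anthropic" -> addAnthropicCacheBreakpoints
console.log(cacheStrategyFor("deepseek/deepseek-chat")) // "anthropic" -> default branch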

+ 4 - 2
src/api/providers/unbound.ts

@@ -2,9 +2,11 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
 import { ApiHandlerOptions, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api"
+
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { addCacheControlDirectives } from "../transform/caching"
+import { addCacheBreakpoints } from "../transform/caching/anthropic"
+
 import { SingleCompletionHandler } from "../index"
 import { RouterProvider } from "./router-provider"
 
@@ -39,7 +41,7 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
 		]
 
 		if (modelId.startsWith("anthropic/claude-3")) {
-			addCacheControlDirectives(systemPrompt, openAiMessages)
+			addCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// Required by Anthropic; other providers default to max tokens allowed.

+ 0 - 36
src/api/transform/caching.ts

@@ -1,36 +0,0 @@
-import OpenAI from "openai"
-
-export const addCacheControlDirectives = (systemPrompt: string, messages: OpenAI.Chat.ChatCompletionMessageParam[]) => {
-	messages[0] = {
-		role: "system",
-		content: [
-			{
-				type: "text",
-				text: systemPrompt,
-				// @ts-ignore-next-line
-				cache_control: { type: "ephemeral" },
-			},
-		],
-	}
-
-	messages
-		.filter((msg) => msg.role === "user")
-		.slice(-2)
-		.forEach((msg) => {
-			if (typeof msg.content === "string") {
-				msg.content = [{ type: "text", text: msg.content }]
-			}
-
-			if (Array.isArray(msg.content)) {
-				let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
-
-				if (!lastTextPart) {
-					lastTextPart = { type: "text", text: "..." }
-					msg.content.push(lastTextPart)
-				}
-
-				// @ts-ignore-next-line
-				lastTextPart["cache_control"] = { type: "ephemeral" }
-			}
-		})
-}

+ 181 - 0
src/api/transform/caching/__tests__/anthropic.test.ts

@@ -0,0 +1,181 @@
+// npx jest src/api/transform/caching/__tests__/anthropic.test.ts
+
+import OpenAI from "openai"
+
+import { addCacheBreakpoints } from "../anthropic"
+
+describe("addCacheBreakpoints (Anthropic)", () => {
+	const systemPrompt = "You are a helpful assistant."
+
+	it("should always add a cache breakpoint to the system prompt", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "Hello" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should not add breakpoints to user messages if there are none", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: systemPrompt }]
+		const originalMessages = JSON.parse(JSON.stringify(messages))
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages.length).toBe(originalMessages.length)
+	})
+
+	it("should add a breakpoint to the only user message if only one exists", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoints to both user messages if only two exist", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{ role: "user", content: "User message 2" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoints to the last two user messages when more than two exist", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "user", content: "User message 2" }, // Should get breakpoint.
+			{ role: "user", content: "User message 3" }, // Should get breakpoint.
+		]
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[3].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should handle assistant messages correctly when finding last two user messages", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "assistant", content: "Assistant response 1" },
+			{ role: "user", content: "User message 2" }, // Should get breakpoint (second to last user).
+			{ role: "assistant", content: "Assistant response 2" },
+			{ role: "user", content: "User message 3" }, // Should get breakpoint (last user).
+			{ role: "assistant", content: "Assistant response 3" },
+		]
+		addCacheBreakpoints(systemPrompt, messages)
+
+		const userMessages = messages.filter((m) => m.role === "user")
+
+		expect(userMessages[0].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(userMessages[1].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(userMessages[2].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoint to the last text part if content is an array", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{
+				role: "user",
+				content: [
+					{ type: "text", text: "This is the last user message." },
+					{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+					{ type: "text", text: "This part should get the breakpoint." },
+				],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "This is the last user message." },
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "This part should get the breakpoint.", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add a placeholder text part if the target message has no text parts", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{
+				role: "user",
+				content: [{ type: "image_url", image_url: { url: "data:image/png;base64,..." } }],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "...", cache_control: { type: "ephemeral" } }, // Placeholder added.
+		])
+	})
+
+	it("should ensure content is array format even if no breakpoint added", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // String content, no breakpoint.
+			{ role: "user", content: "User message 2" }, // Gets breakpoint.
+			{ role: "user", content: "User message 3" }, // Gets breakpoint.
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[3].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+})

+ 266 - 0
src/api/transform/caching/__tests__/gemini.test.ts

@@ -0,0 +1,266 @@
+// npx jest src/api/transform/caching/__tests__/gemini.test.ts
+
+import OpenAI from "openai"
+
+import { addCacheBreakpoints } from "../gemini"
+
+describe("addCacheBreakpoints", () => {
+	const systemPrompt = "You are a helpful assistant."
+
+	it("should always add a cache breakpoint to the system prompt", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "Hello" },
+		]
+		addCacheBreakpoints(systemPrompt, messages, 10) // Pass frequency
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should not add breakpoints for fewer than N user messages", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: frequency - 1 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+		]
+
+		const originalMessages = JSON.parse(JSON.stringify(messages))
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+
+		for (let i = 1; i < messages.length; i++) {
+			const originalContent = originalMessages[i].content
+
+			const expectedContent =
+				typeof originalContent === "string" ? [{ type: "text", text: originalContent }] : originalContent
+
+			expect(messages[i].content).toEqual(expectedContent)
+		}
+	})
+
+	it("should add a breakpoint to the Nth user message", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: frequency }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+		]
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		// Check Nth user message (index 'frequency' in the full array).
+		expect(messages[frequency].content).toEqual([
+			{ type: "text", text: `User message ${frequency}`, cache_control: { type: "ephemeral" } },
+		])
+
+		// Check (N-1)th user message (index frequency-1) - should be unchanged.
+		expect(messages[frequency - 1].content).toEqual([{ type: "text", text: `User message ${frequency - 1}` }])
+	})
+
+	it("should add breakpoints to the Nth and 2*Nth user messages", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: frequency * 2 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+		]
+
+		expect(messages.length).toEqual(frequency * 2 + 1)
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		const indices = []
+
+		for (let i = 0; i < messages.length; i++) {
+			const content = messages[i].content?.[0]
+
+			if (typeof content === "object" && "cache_control" in content) {
+				indices.push(i)
+			}
+		}
+
+		expect(indices).toEqual([0, 5, 10])
+
+		// Check Nth user message (index frequency)
+		expect(messages[frequency].content).toEqual([
+			{ type: "text", text: `User message ${frequency}`, cache_control: { type: "ephemeral" } },
+		])
+
+		// Check (2*N-1)th user message (index 2*frequency-1) - unchanged
+		expect(messages[frequency * 2 - 1].content).toEqual([
+			{ type: "text", text: `User message ${frequency * 2 - 1}` },
+		])
+
+		// Check 2*Nth user message (index 2*frequency)
+		expect(messages[frequency * 2].content).toEqual([
+			{ type: "text", text: `User message ${frequency * 2}`, cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should handle assistant messages correctly when counting user messages", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			// N-1 user messages
+			...Array.from({ length: frequency - 1 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+			{ role: "assistant", content: "Assistant response" },
+			{ role: "user", content: `User message ${frequency}` }, // This is the Nth user message.
+			{ role: "assistant", content: "Another response" },
+			{ role: "user", content: `User message ${frequency + 1}` },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		// Find the Nth user message.
+		const nthUserMessage = messages.filter((m) => m.role === "user")[frequency - 1]
+		expect(nthUserMessage.content).toEqual([
+			{ type: "text", text: `User message ${frequency}`, cache_control: { type: "ephemeral" } },
+		])
+
+		// Check the (N+1)th user message is unchanged.
+		const nPlusOneUserMessage = messages.filter((m) => m.role === "user")[frequency]
+		expect(nPlusOneUserMessage.content).toEqual([{ type: "text", text: `User message ${frequency + 1}` }])
+	})
+
+	it("should add breakpoint to the last text part if content is an array", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: frequency - 1 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+			{
+				role: "user", // Nth user message
+				content: [
+					{ type: "text", text: `This is the ${frequency}th user message.` },
+					{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+					{ type: "text", text: "This part should get the breakpoint." },
+				],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		expect(messages[frequency].content).toEqual([
+			{ type: "text", text: `This is the ${frequency}th user message.` },
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "This part should get the breakpoint.", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add a placeholder text part if the target message has no text parts", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: frequency - 1 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+			{
+				role: "user", // Nth user message.
+				content: [{ type: "image_url", image_url: { url: "data:image/png;base64,..." } }],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		expect(messages[frequency].content).toEqual([
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "...", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoints correctly with frequency 5", () => {
+		const frequency = 5
+
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: 12 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+		]
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		// Check 5th user message (index 5).
+		expect(messages[5].content).toEqual([
+			{ type: "text", text: "User message 5", cache_control: { type: "ephemeral" } },
+		])
+
+		// Check 9th user message (index 9) - unchanged
+		expect(messages[9].content).toEqual([{ type: "text", text: "User message 9" }])
+
+		// Check 10th user message (index 10).
+		expect(messages[10].content).toEqual([
+			{ type: "text", text: "User message 10", cache_control: { type: "ephemeral" } },
+		])
+
+		// Check 11th user message (index 11) - unchanged
+		expect(messages[11].content).toEqual([{ type: "text", text: "User message 11" }])
+	})
+
+	it("should not add breakpoints (except system) if frequency is 0", () => {
+		const frequency = 0
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...Array.from({ length: 15 }, (_, i) => ({
+				role: "user" as const,
+				content: `User message ${i + 1}`,
+			})),
+		]
+		const originalMessages = JSON.parse(JSON.stringify(messages))
+
+		addCacheBreakpoints(systemPrompt, messages, frequency)
+
+		// Check system prompt.
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+
+		// Check all user messages - none should have cache_control
+		for (let i = 1; i < messages.length; i++) {
+			const originalContent = originalMessages[i].content
+
+			const expectedContent =
+				typeof originalContent === "string" ? [{ type: "text", text: originalContent }] : originalContent
+
+			expect(messages[i].content).toEqual(expectedContent) // Should match original (after string->array conversion).
+
+			// Ensure no cache_control was added to user messages.
+			const content = messages[i].content
+
+			if (Array.isArray(content)) {
+				// Assign to new variable after type check.
+				const contentParts = content
+
+				contentParts.forEach((part: any) => {
+					// Iterate over the correctly typed variable.
+					expect(part).not.toHaveProperty("cache_control")
+				})
+			}
+		}
+	})
+})

+ 178 - 0
src/api/transform/caching/__tests__/vertex.test.ts

@@ -0,0 +1,178 @@
+// npx jest src/api/transform/caching/__tests__/vertex.test.ts
+
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { addCacheBreakpoints } from "../vertex"
+
+describe("addCacheBreakpoints (Vertex)", () => {
+	it("should return an empty array if input is empty", () => {
+		const messages: Anthropic.Messages.MessageParam[] = []
+		const result = addCacheBreakpoints(messages)
+		expect(result).toEqual([])
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should not add breakpoints if there are no user messages", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "assistant", content: "Hello" }]
+		const originalMessages = JSON.parse(JSON.stringify(messages))
+		const result = addCacheBreakpoints(messages)
+		expect(result).toEqual(originalMessages) // Should be unchanged.
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should add a breakpoint to the only user message if only one exists", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "User message 1" }]
+		const result = addCacheBreakpoints(messages)
+
+		expect(result).toHaveLength(1)
+
+		expect(result[0].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should add breakpoints to both user messages if only two exist", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" },
+			{ role: "user", content: "User message 2" },
+		]
+
+		const result = addCacheBreakpoints(messages)
+		expect(result).toHaveLength(2)
+
+		expect(result[0].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result[1].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should add breakpoints only to the last two user messages when more than two exist", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "user", content: "User message 2" }, // Should get breakpoint.
+			{ role: "user", content: "User message 3" }, // Should get breakpoint.
+		]
+
+		const originalMessage1 = JSON.parse(JSON.stringify(messages[0]))
+		const result = addCacheBreakpoints(messages)
+
+		expect(result).toHaveLength(3)
+		expect(result[0]).toEqual(originalMessage1)
+
+		expect(result[1].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result[2].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should handle assistant messages correctly when finding last two user messages", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "assistant", content: "Assistant response 1" }, // Should be unchanged.
+			{ role: "user", content: "User message 2" }, // Should get breakpoint (second to last user).
+			{ role: "assistant", content: "Assistant response 2" }, // Should be unchanged.
+			{ role: "user", content: "User message 3" }, // Should get breakpoint (last user).
+			{ role: "assistant", content: "Assistant response 3" }, // Should be unchanged.
+		]
+		const originalMessage1 = JSON.parse(JSON.stringify(messages[0]))
+		const originalAssistant1 = JSON.parse(JSON.stringify(messages[1]))
+		const originalAssistant2 = JSON.parse(JSON.stringify(messages[3]))
+		const originalAssistant3 = JSON.parse(JSON.stringify(messages[5]))
+
+		const result = addCacheBreakpoints(messages)
+		expect(result).toHaveLength(6)
+
+		expect(result[0]).toEqual(originalMessage1)
+		expect(result[1]).toEqual(originalAssistant1)
+
+		expect(result[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result[3]).toEqual(originalAssistant2)
+
+		expect(result[4].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result[5]).toEqual(originalAssistant3)
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should add breakpoint only to the last text part if content is an array", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" }, // Gets breakpoint.
+			{
+				role: "user", // Gets breakpoint.
+				content: [
+					{ type: "text", text: "First text part." }, // No breakpoint.
+					{ type: "image", source: { type: "base64", media_type: "image/png", data: "..." } },
+					{ type: "text", text: "Last text part." }, // Gets breakpoint.
+				],
+			},
+		]
+
+		const result = addCacheBreakpoints(messages)
+		expect(result).toHaveLength(2)
+
+		expect(result[0].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(result[1].content).toEqual([
+			{ type: "text", text: "First text part." }, // Unchanged.
+			{ type: "image", source: { type: "base64", media_type: "image/png", data: "..." } }, // Unchanged.
+			{ type: "text", text: "Last text part.", cache_control: { type: "ephemeral" } }, // Breakpoint added.
+		])
+
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should handle array content with no text parts gracefully", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" }, // Gets breakpoint.
+			{
+				role: "user", // Gets breakpoint, but has no text part to add it to.
+				content: [{ type: "image", source: { type: "base64", media_type: "image/png", data: "..." } }],
+			},
+		]
+
+		const originalMessage2 = JSON.parse(JSON.stringify(messages[1]))
+
+		const result = addCacheBreakpoints(messages)
+		expect(result).toHaveLength(2)
+
+		expect(result[0].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		// Check second user message - should be unchanged as no text part found.
+		expect(result[1]).toEqual(originalMessage2)
+		expect(result).not.toBe(messages) // Ensure new array.
+	})
+
+	it("should not modify the original messages array", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "User message 1" },
+			{ role: "user", content: "User message 2" },
+		]
+		const originalMessagesCopy = JSON.parse(JSON.stringify(messages))
+
+		addCacheBreakpoints(messages)
+
+		// Verify original array is untouched.
+		expect(messages).toEqual(originalMessagesCopy)
+	})
+})

+ 41 - 0
src/api/transform/caching/anthropic.ts

@@ -0,0 +1,41 @@
+import OpenAI from "openai"
+
+export function addCacheBreakpoints(systemPrompt: string, messages: OpenAI.Chat.ChatCompletionMessageParam[]) {
+	messages[0] = {
+		role: "system",
+		// @ts-ignore-next-line
+		content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
+	}
+
+	// Ensure all user messages have content in array format first
+	for (const msg of messages) {
+		if (msg.role === "user" && typeof msg.content === "string") {
+			msg.content = [{ type: "text", text: msg.content }]
+		}
+	}
+
+	// Add `cache_control: ephemeral` to the last two user messages.
+	// (Note: this works because we only ever add one user message at a
+	// time, but if we added multiple we'd need to mark the user message
+	// before the last assistant message.)
+	messages
+		.filter((msg) => msg.role === "user")
+		.slice(-2)
+		.forEach((msg) => {
+			if (Array.isArray(msg.content)) {
+				// NOTE: This is fine because env details are always appended
+				// at the end. If they weren't, and the user's last part was an
+				// image_url, this would pull an earlier text part out of
+				// position and move it to the end.
+				let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
+
+				if (!lastTextPart) {
+					lastTextPart = { type: "text", text: "..." }
+					msg.content.push(lastTextPart)
+				}
+
+				// @ts-ignore-next-line
+				lastTextPart["cache_control"] = { type: "ephemeral" }
+			}
+		})
+}
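
A minimal usage sketch for this helper; the conversation is hypothetical, and the relative import path assumes a sibling module:

import OpenAI from "openai"

import { addCacheBreakpoints } from "./anthropic"

const systemPrompt = "You are a helpful assistant."

const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
	{ role: "system", content: systemPrompt },
	{ role: "user", content: "Turn 1" }, // normalized to array form, no breakpoint
	{ role: "assistant", content: "Reply 1" },
	{ role: "user", content: "Turn 2" }, // second-to-last user message: breakpoint
	{ role: "user", content: "Turn 3" }, // last user message: breakpoint
]

addCacheBreakpoints(systemPrompt, messages) // mutates messages in place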

+ 47 - 0
src/api/transform/caching/gemini.ts

@@ -0,0 +1,47 @@
+import OpenAI from "openai"
+
+export function addCacheBreakpoints(
+	systemPrompt: string,
+	messages: OpenAI.Chat.ChatCompletionMessageParam[],
+	frequency: number = 10,
+) {
+	// *Always* cache the system prompt.
+	messages[0] = {
+		role: "system",
+		// @ts-ignore-next-line
+		content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
+	}
+
+	// Add breakpoints every N user messages based on frequency.
+	let count = 0
+
+	for (const msg of messages) {
+		if (msg.role !== "user") {
+			continue
+		}
+
+		// Ensure content is in array format for potential modification.
+		if (typeof msg.content === "string") {
+			msg.content = [{ type: "text", text: msg.content }]
+		}
+
+		const isNthMessage = count % frequency === frequency - 1
+
+		if (isNthMessage) {
+			if (Array.isArray(msg.content)) {
+				// Find the last text part to add the cache control to.
+				let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
+
+				if (!lastTextPart) {
+					lastTextPart = { type: "text", text: "..." } // Add a placeholder if no text part exists.
+					msg.content.push(lastTextPart)
+				}
+
+				// @ts-ignore-next-line - Add cache control property
+				lastTextPart["cache_control"] = { type: "ephemeral" }
+			}
+		}
+
+		count++
+	}
+}
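
A usage sketch for the frequency-based Gemini variant; the conversation is hypothetical, and the behavior (a breakpoint on every Nth user message, plus the system prompt always) matches the tests above:

import OpenAI from "openai"

import { addCacheBreakpoints } from "./gemini"

const systemPrompt = "You are a helpful assistant."

const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
	{ role: "system", content: systemPrompt },
	...Array.from({ length: 12 }, (_, i) => ({
		role: "user" as const,
		content: `User message ${i + 1}`,
	})),
]

addCacheBreakpoints(systemPrompt, messages, 5) // mutates in place; marks user messages 5 and 10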

+ 49 - 0
src/api/transform/caching/vertex.ts

@@ -0,0 +1,49 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+
+export function addCacheBreakpoints(messages: Anthropic.Messages.MessageParam[]) {
+	// Find indices of user messages that we want to cache.
+	// We only cache the last two user messages to stay within the 4-block limit
+	// (1 block for system + 1 block each for last two user messages = 3 total).
+	const indices = messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])
+
+	// Only cache the last two user messages.
+	const lastIndex = indices[indices.length - 1] ?? -1
+	const secondLastIndex = indices[indices.length - 2] ?? -1
+
+	return messages.map((message, index) =>
+		message.role !== "assistant" && (index === lastIndex || index === secondLastIndex)
+			? cachedMessage(message)
+			: message,
+	)
+}
+
+function cachedMessage(message: Anthropic.Messages.MessageParam): Anthropic.Messages.MessageParam {
+	// For string content, convert to array format and add cache control.
+	if (typeof message.content === "string") {
+		return {
+			...message,
+			// For string content, we only have one block so it's always the last block.
+			content: [{ type: "text" as const, text: message.content, cache_control: { type: "ephemeral" } }],
+		}
+	}
+
+	// For array content, find the last text block index once before mapping.
+	const lastTextBlockIndex = message.content.reduce(
+		(lastIndex, content, index) => (content.type === "text" ? index : lastIndex),
+		-1,
+	)
+
+	// Then use this pre-calculated index in the map function.
+	return {
+		...message,
+		content: message.content.map((content, index) =>
+			content.type === "text"
+				? {
+						...content,
+						// Check if this is the last text block using our pre-calculated index.
+						...(index === lastTextBlockIndex && { cache_control: { type: "ephemeral" } }),
+					}
+				: content,
+		),
+	}
+}
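
Unlike the two OpenAI-format helpers, this one is pure: it returns a new array and leaves its input untouched. A usage sketch with a hypothetical history:

import { Anthropic } from "@anthropic-ai/sdk"

import { addCacheBreakpoints } from "./vertex"

const history: Anthropic.Messages.MessageParam[] = [
	{ role: "user", content: "First question" }, // unchanged: not among the last two user messages
	{ role: "assistant", content: "First answer" }, // unchanged: assistant turns are never cached
	{ role: "user", content: "Second question" }, // gets an ephemeral breakpoint
	{ role: "user", content: "Third question" }, // gets an ephemeral breakpoint
]

const annotated = addCacheBreakpoints(history) // new array; history itself is not mutated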

+ 0 - 70
src/api/transform/vertex-caching.ts

@@ -1,70 +0,0 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-
-interface VertexTextBlock {
-	type: "text"
-	text: string
-	cache_control?: { type: "ephemeral" }
-}
-
-interface VertexImageBlock {
-	type: "image"
-	source: {
-		type: "base64"
-		media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp"
-		data: string
-	}
-}
-
-type VertexContentBlock = VertexTextBlock | VertexImageBlock
-
-interface VertexMessage extends Omit<Anthropic.Messages.MessageParam, "content"> {
-	content: string | VertexContentBlock[]
-}
-
-export function formatMessageForCache(message: Anthropic.Messages.MessageParam, shouldCache: boolean): VertexMessage {
-	// Assistant messages are kept as-is since they can't be cached
-	if (message.role === "assistant") {
-		return message as VertexMessage
-	}
-
-	// For string content, we convert to array format with optional cache control
-	if (typeof message.content === "string") {
-		return {
-			...message,
-			content: [
-				{
-					type: "text" as const,
-					text: message.content,
-					// For string content, we only have one block so it's always the last
-					...(shouldCache && { cache_control: { type: "ephemeral" } }),
-				},
-			],
-		}
-	}
-
-	// For array content, find the last text block index once before mapping
-	const lastTextBlockIndex = message.content.reduce(
-		(lastIndex, content, index) => (content.type === "text" ? index : lastIndex),
-		-1,
-	)
-
-	// Then use this pre-calculated index in the map function.
-	return {
-		...message,
-		content: message.content.map((content, contentIndex) => {
-			// Images and other non-text content are passed through unchanged.
-			if (content.type === "image") {
-				return content as VertexImageBlock
-			}
-
-			// Check if this is the last text block using our pre-calculated index.
-			const isLastTextBlock = contentIndex === lastTextBlockIndex
-
-			return {
-				type: "text" as const,
-				text: (content as { text: string }).text,
-				...(shouldCache && isLastTextBlock && { cache_control: { type: "ephemeral" } }),
-			}
-		}),
-	}
-}

+ 1 - 2
webview-ui/src/components/settings/ExperimentalSettings.tsx

@@ -6,13 +6,12 @@ import { EXPERIMENT_IDS, experimentConfigsMap, ExperimentId } from "@roo/shared/
 
 import { cn } from "@/lib/utils"
 
-import { SetCachedStateField, SetExperimentEnabled } from "./types"
+import { SetExperimentEnabled } from "./types"
 import { SectionHeader } from "./SectionHeader"
 import { Section } from "./Section"
 import { ExperimentalFeature } from "./ExperimentalFeature"
 
 type ExperimentalSettingsProps = HTMLAttributes<HTMLDivElement> & {
-	setCachedStateField: SetCachedStateField<"terminalOutputLineLimit" | "maxOpenTabsContext">
 	experiments: Record<ExperimentId, boolean>
 	setExperimentEnabled: SetExperimentEnabled
 }

+ 1 - 5
webview-ui/src/components/settings/SettingsView.tsx

@@ -501,11 +501,7 @@ const SettingsView = forwardRef<SettingsViewRef, SettingsViewProps>(({ onDone, t
 				</div>
 
 				<div ref={experimentalRef}>
-					<ExperimentalSettings
-						setCachedStateField={setCachedStateField}
-						setExperimentEnabled={setExperimentEnabled}
-						experiments={experiments}
-					/>
+					<ExperimentalSettings setExperimentEnabled={setExperimentEnabled} experiments={experiments} />
 				</div>
 
 				<div ref={languageRef}>