feat: add prompt caching support for LiteLLM (#5791) (#6074)

* feat: add prompt caching support for LiteLLM (#5791)

- Add litellmUsePromptCache configuration option to provider settings
- Implement cache control headers in LiteLLM handler when enabled
- Add UI checkbox for enabling prompt caching (only shown for supported models)
- Track cache read/write tokens in usage data
- Add comprehensive test for prompt caching functionality
- Reuse existing translation keys for consistency across languages

This allows LiteLLM users to benefit from prompt caching with supported models
like Claude 3.7, reducing costs and improving response times.
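As a rough sketch of what this produces (the prompt text below is illustrative, not taken from the change itself): when litellmUsePromptCache is enabled and the selected model reports supportsPromptCache, the handler sends the system prompt as a structured content block carrying a cache_control marker:

    // Illustrative shape only; see the lite-llm.ts diff below for the actual construction.
    const systemMessage = {
        role: "system",
        content: [
            {
                type: "text",
                text: "You are a helpful assistant",
                cache_control: { type: "ephemeral" }, // marks this block as cacheable
            },
        ],
    }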

* fix: improve LiteLLM prompt caching to work for multi-turn conversations

- Convert system message to structured format with cache_control
- Handle both string and array content types for user messages
- Apply cache_control to content items, not just message level
- Update tests to match new message structure

This ensures prompt caching works correctly for all messages in a conversation,
not just the initial system prompt and first user message.
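A minimal sketch of the resulting shape (the conversation turn is made up for illustration): the last two user messages are rewritten so the cache marker sits on the content item itself rather than on the message:

    // Before: { role: "user", content: "How are you?" }
    // After: string content is converted to an array and the item is marked.
    const cachedUserMessage = {
        role: "user",
        content: [
            {
                type: "text",
                text: "How are you?",
                cache_control: { type: "ephemeral" },
            },
        ],
    }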

* fix: resolve TypeScript linter error for cache_control property

Use a type assertion to handle the cache_control property, which is not part of the OpenAI types
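A minimal sketch of the workaround (the prompt text is a placeholder): the content part is asserted because cache_control is not declared on the OpenAI SDK's content-part types:

    // cache_control is unknown to the OpenAI types, so bypass the check with an assertion.
    const systemPrompt = "You are a helpful assistant" // example value
    const cachedPart = {
        type: "text",
        text: systemPrompt,
        cache_control: { type: "ephemeral" },
    } as any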
Murilo Pires, 5 months ago
Parent
Commit
a0018c9d04

+ 1 - 0
packages/types/src/provider-settings.ts

@@ -238,6 +238,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
 	litellmBaseUrl: z.string().optional(),
 	litellmApiKey: z.string().optional(),
 	litellmModelId: z.string().optional(),
+	litellmUsePromptCache: z.boolean().optional(),
 })
 
 const defaultSchema = z.object({

+ 158 - 0
src/api/providers/__tests__/lite-llm.spec.ts

@@ -0,0 +1,158 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { LiteLLMHandler } from "../lite-llm"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
+
+// Mock vscode first to avoid import errors
+vi.mock("vscode", () => ({}))
+
+// Mock OpenAI
+vi.mock("openai", () => {
+	const mockStream = {
+		[Symbol.asyncIterator]: vi.fn(),
+	}
+
+	const mockCreate = vi.fn().mockReturnValue({
+		withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+	})
+
+	return {
+		default: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+// Mock model fetching
+vi.mock("../fetchers/modelCache", () => ({
+	getModels: vi.fn().mockImplementation(() => {
+		return Promise.resolve({
+			[litellmDefaultModelId]: litellmDefaultModelInfo,
+		})
+	}),
+}))
+
+describe("LiteLLMHandler", () => {
+	let handler: LiteLLMHandler
+	let mockOptions: ApiHandlerOptions
+	let mockOpenAIClient: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockOptions = {
+			litellmApiKey: "test-key",
+			litellmBaseUrl: "http://localhost:4000",
+			litellmModelId: litellmDefaultModelId,
+		}
+		handler = new LiteLLMHandler(mockOptions)
+		mockOpenAIClient = new OpenAI()
+	})
+
+	describe("prompt caching", () => {
+		it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
+			const optionsWithCache: ApiHandlerOptions = {
+				...mockOptions,
+				litellmUsePromptCache: true,
+			}
+			handler = new LiteLLMHandler(optionsWithCache)
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{ role: "user", content: "Hello" },
+				{ role: "assistant", content: "Hi there!" },
+				{ role: "user", content: "How are you?" },
+			]
+
+			// Mock the stream response
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						choices: [{ delta: { content: "I'm doing well!" } }],
+						usage: {
+							prompt_tokens: 100,
+							completion_tokens: 50,
+							cache_creation_input_tokens: 20,
+							cache_read_input_tokens: 30,
+						},
+					}
+				},
+			}
+
+			mockOpenAIClient.chat.completions.create.mockReturnValue({
+				withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+			})
+
+			const generator = handler.createMessage(systemPrompt, messages)
+			const results = []
+			for await (const chunk of generator) {
+				results.push(chunk)
+			}
+
+			// Verify that create was called with cache control headers
+			const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
+
+			// Check system message has cache control in the proper format
+			expect(createCall.messages[0]).toMatchObject({
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			})
+
+			// Check that the last two user messages have cache control
+			const userMessageIndices = createCall.messages
+				.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
+				.filter((idx: number) => idx !== -1)
+
+			const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]
+			const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2]
+
+			// Check last user message has proper structure with cache control
+			expect(createCall.messages[lastUserIdx]).toMatchObject({
+				role: "user",
+				content: [
+					{
+						type: "text",
+						text: "How are you?",
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			})
+
+			// Check second last user message (first user message in this case)
+			if (secondLastUserIdx !== -1) {
+				expect(createCall.messages[secondLastUserIdx]).toMatchObject({
+					role: "user",
+					content: [
+						{
+							type: "text",
+							text: "Hello",
+							cache_control: { type: "ephemeral" },
+						},
+					],
+				})
+			}
+
+			// Verify usage includes cache tokens
+			const usageChunk = results.find((chunk) => chunk.type === "usage")
+			expect(usageChunk).toMatchObject({
+				type: "usage",
+				inputTokens: 100,
+				outputTokens: 50,
+				cacheWriteTokens: 20,
+				cacheReadTokens: 30,
+			})
+		})
+	})
+})

+ 79 - 9
src/api/providers/lite-llm.ts

@@ -39,10 +39,70 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 	): ApiStream {
 		const { id: modelId, info } = await this.fetchModel()
 
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
+		const openAiMessages = convertToOpenAiMessages(messages)
+
+		// Prepare messages with cache control if enabled and supported
+		let systemMessage: OpenAI.Chat.ChatCompletionMessageParam
+		let enhancedMessages: OpenAI.Chat.ChatCompletionMessageParam[]
+
+		if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
+			// Create system message with cache control in the proper format
+			systemMessage = {
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						cache_control: { type: "ephemeral" },
+					} as any,
+				],
+			}
+
+			// Find the last two user messages to apply caching
+			const userMsgIndices = openAiMessages.reduce(
+				(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+				[] as number[],
+			)
+			const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+			const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+			// Apply cache_control to the last two user messages
+			enhancedMessages = openAiMessages.map((message, index) => {
+				if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && message.role === "user") {
+					// Handle both string and array content types
+					if (typeof message.content === "string") {
+						return {
+							...message,
+							content: [
+								{
+									type: "text",
+									text: message.content,
+									cache_control: { type: "ephemeral" },
+								} as any,
+							],
+						}
+					} else if (Array.isArray(message.content)) {
+						// Apply cache control to the last content item in the array
+						return {
+							...message,
+							content: message.content.map((content, contentIndex) =>
+								contentIndex === message.content.length - 1
+									? ({
+											...content,
+											cache_control: { type: "ephemeral" },
+										} as any)
+									: content,
+							),
+						}
+					}
+				}
+				return message
+			})
+		} else {
+			// No cache control - use simple format
+			systemMessage = { role: "system", content: systemPrompt }
+			enhancedMessages = openAiMessages
+		}
 
 		// Required by some providers; others default to max tokens allowed
 		let maxTokens: number | undefined = info.maxTokens ?? undefined
@@ -50,7 +110,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			max_tokens: maxTokens,
-			messages: openAiMessages,
+			messages: [systemMessage, ...enhancedMessages],
 			stream: true,
 			stream_options: {
 				include_usage: true,
@@ -80,20 +140,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 			}
 
 			if (lastUsage) {
+				// Extract cache-related information if available
+				// LiteLLM may use different field names for cache tokens
+				const cacheWriteTokens =
+					lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
+				const cacheReadTokens =
+					lastUsage.prompt_tokens_details?.cached_tokens ||
+					(lastUsage as any).cache_read_input_tokens ||
+					(lastUsage as any).prompt_cache_hit_tokens ||
+					0
+
 				const usageData: ApiStreamUsageChunk = {
 					type: "usage",
 					inputTokens: lastUsage.prompt_tokens || 0,
 					outputTokens: lastUsage.completion_tokens || 0,
-					cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
-					cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
+					cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+					cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
 				}
 
 				usageData.totalCost = calculateApiCostOpenAI(
 					info,
 					usageData.inputTokens,
 					usageData.outputTokens,
-					usageData.cacheWriteTokens,
-					usageData.cacheReadTokens,
+					usageData.cacheWriteTokens || 0,
+					usageData.cacheReadTokens || 0,
 				)
 
 				yield usageData

+ 24 - 1
webview-ui/src/components/settings/providers/LiteLLM.tsx

@@ -1,5 +1,5 @@
 import { useCallback, useState, useEffect, useRef } from "react"
-import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
 
 import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"
 
@@ -151,6 +151,29 @@ export const LiteLLM = ({
 				organizationAllowList={organizationAllowList}
 				errorMessage={modelValidationError}
 			/>
+
+			{/* Show prompt caching option if the selected model supports it */}
+			{(() => {
+				const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
+				const selectedModel = routerModels?.litellm?.[selectedModelId]
+				if (selectedModel?.supportsPromptCache) {
+					return (
+						<div className="mt-4">
+							<VSCodeCheckbox
+								checked={apiConfiguration.litellmUsePromptCache || false}
+								onChange={(e: any) => {
+									setApiConfigurationField("litellmUsePromptCache", e.target.checked)
+								}}>
+								<span className="font-medium">{t("settings:providers.enablePromptCaching")}</span>
+							</VSCodeCheckbox>
+							<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1">
+								{t("settings:providers.enablePromptCachingTitle")}
+							</div>
+						</div>
+					)
+				}
+				return null
+			})()}
 		</>
 	)
 }