Przeglądaj źródła

Gemini caching improvements (#2925)

Chris Estreich 8 miesięcy temu
rodzic
commit
a3f1a3f3ad

+ 5 - 0
.changeset/bright-singers-drop.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Enable prompt caching for Gemini (with some improvements)

+ 85 - 28
src/api/providers/gemini.ts

@@ -4,27 +4,42 @@ import {
 	type GenerateContentResponseUsageMetadata,
 	type GenerateContentParameters,
 	type Content,
+	CreateCachedContentConfig,
 } from "@google/genai"
+import NodeCache from "node-cache"
 
 import { SingleCompletionHandler } from "../"
 import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
 import { geminiDefaultModelId, geminiModels } from "../../shared/api"
-import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import {
+	convertAnthropicContentToGemini,
+	convertAnthropicMessageToGemini,
+	getMessagesLength,
+} from "../transform/gemini-format"
 import type { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 
 const CACHE_TTL = 5
 
+const CONTEXT_CACHE_TOKEN_MINIMUM = 4096
+
+type CacheEntry = {
+	key: string
+	count: number
+}
+
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
+
 	private client: GoogleGenAI
-	private contentCaches: Map<string, { key: string; count: number }>
+	private contentCaches: NodeCache
+	private isCacheBusy = false
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
 		this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
-		this.contentCaches = new Map()
+		this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 	}
 
 	async *createMessage(
@@ -35,36 +50,76 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 		const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()
 
 		const contents = messages.map(convertAnthropicMessageToGemini)
+		const contentsLength = systemInstruction.length + getMessagesLength(contents)
+
 		let uncachedContent: Content[] | undefined = undefined
 		let cachedContent: string | undefined = undefined
 		let cacheWriteTokens: number | undefined = undefined
 
+		// The minimum input token count for context caching is 4,096.
+		// For a basic approximation we assume 4 characters per token.
+		// We can use tiktoken eventually to get a more accurat token count.
 		// https://ai.google.dev/gemini-api/docs/caching?lang=node
-		// if (info.supportsPromptCache && cacheKey) {
-		// 	const cacheEntry = this.contentCaches.get(cacheKey)
-
-		// 	if (cacheEntry) {
-		// 		uncachedContent = contents.slice(cacheEntry.count, contents.length)
-		// 		cachedContent = cacheEntry.key
-		// 	}
+		// https://ai.google.dev/gemini-api/docs/tokens?lang=node
+		const isCacheAvailable =
+			info.supportsPromptCache &&
+			this.options.promptCachingEnabled &&
+			cacheKey &&
+			contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM
+
+		if (isCacheAvailable) {
+			const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)
+
+			if (cacheEntry) {
+				uncachedContent = contents.slice(cacheEntry.count, contents.length)
+				cachedContent = cacheEntry.key
+				console.log(
+					`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
+				)
+			}
 
-		// 	const newCacheEntry = await this.client.caches.create({
-		// 		model,
-		// 		config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
-		// 	})
+			if (!this.isCacheBusy) {
+				this.isCacheBusy = true
+				const timestamp = Date.now()
+
+				this.client.caches
+					.create({
+						model,
+						config: {
+							contents,
+							systemInstruction,
+							ttl: `${CACHE_TTL * 60}s`,
+							httpOptions: { timeout: 120_000 },
+						},
+					})
+					.then((result) => {
+						const { name, usageMetadata } = result
+
+						if (name) {
+							this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
+							cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
+							console.log(
+								`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
+							)
+						}
+					})
+					.catch((error) => {
+						console.error(`[GeminiHandler] caches.create error`, error)
+					})
+					.finally(() => {
+						this.isCacheBusy = false
+					})
+			}
+		}
 
-		// 	if (newCacheEntry.name) {
-		// 		this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
-		// 		cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
-		// 	}
-		// }
+		const isCacheUsed = !!cachedContent
 
 		const params: GenerateContentParameters = {
 			model,
 			contents: uncachedContent ?? contents,
 			config: {
 				cachedContent,
-				systemInstruction: cachedContent ? undefined : systemInstruction,
+				systemInstruction: isCacheUsed ? undefined : systemInstruction,
 				httpOptions: this.options.googleGeminiBaseUrl
 					? { baseUrl: this.options.googleGeminiBaseUrl }
 					: undefined,
@@ -94,13 +149,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
 			const reasoningTokens = lastUsageMetadata.thoughtsTokenCount
 
-			// const totalCost = this.calculateCost({
-			// 	info,
-			// 	inputTokens,
-			// 	outputTokens,
-			// 	cacheWriteTokens,
-			// 	cacheReadTokens,
-			// })
+			const totalCost = isCacheUsed
+				? this.calculateCost({
+						info,
+						inputTokens,
+						outputTokens,
+						cacheWriteTokens,
+						cacheReadTokens,
+					})
+				: undefined
 
 			yield {
 				type: "usage",
@@ -109,7 +166,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 				cacheWriteTokens,
 				cacheReadTokens,
 				reasoningTokens,
-				// totalCost,
+				totalCost,
 			}
 		}
 	}

+ 6 - 0
src/api/transform/gemini-format.ts

@@ -76,3 +76,9 @@ export function convertAnthropicMessageToGemini(message: Anthropic.Messages.Mess
 		parts: convertAnthropicContentToGemini(message.content),
 	}
 }
+
+const getContentLength = ({ parts }: Content): number =>
+	parts?.reduce((length, { text }) => length + (text?.length ?? 0), 0) ?? 0
+
+export const getMessagesLength = (contents: Content[]): number =>
+	contents.reduce((length, content) => length + getContentLength(content), 0)

+ 3 - 3
src/shared/api.ts

@@ -679,7 +679,7 @@ export const geminiModels = {
 		maxTokens: 65_535,
 		contextWindow: 1_048_576,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		isPromptCacheOptional: true,
 		inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
 		outputPrice: 15,
@@ -704,7 +704,7 @@ export const geminiModels = {
 		maxTokens: 8192,
 		contextWindow: 1_048_576,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		isPromptCacheOptional: true,
 		inputPrice: 0.1,
 		outputPrice: 0.4,
@@ -755,7 +755,7 @@ export const geminiModels = {
 		maxTokens: 8192,
 		contextWindow: 1_048_576,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		isPromptCacheOptional: true,
 		inputPrice: 0.15, // This is the pricing for prompts above 128k tokens.
 		outputPrice: 0.6,