@@ -4,27 +4,42 @@ import {
 	type GenerateContentResponseUsageMetadata,
 	type GenerateContentParameters,
 	type Content,
+	CreateCachedContentConfig,
 } from "@google/genai"
+import NodeCache from "node-cache"
 
 import { SingleCompletionHandler } from "../"
 import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
 import { geminiDefaultModelId, geminiModels } from "../../shared/api"
-import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import {
+	convertAnthropicContentToGemini,
+	convertAnthropicMessageToGemini,
+	getMessagesLength,
+} from "../transform/gemini-format"
 import type { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 
 const CACHE_TTL = 5
 
+const CONTEXT_CACHE_TOKEN_MINIMUM = 4096
+
+type CacheEntry = {
+	key: string
+	count: number
+}
+
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
+
 	private client: GoogleGenAI
-	private contentCaches: Map<string, { key: string; count: number }>
+	private contentCaches: NodeCache
+	private isCacheBusy = false
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
 		this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
-		this.contentCaches = new Map()
+		this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
	}
 
 	async *createMessage(
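The hunk above swaps the plain Map for node-cache so stored Gemini cache names expire on their own after five minutes. A minimal sketch of that store in isolation, assuming a hypothetical conversation key "task-123" and a made-up cache resource name (neither value appears in the patch):

```ts
import NodeCache from "node-cache"

type CacheEntry = {
	key: string // Gemini cache resource name (made-up example below)
	count: number // how many messages were included in the cached prefix
}

// Entries silently expire after stdTTL seconds, so a stale cache name
// whose server-side cache has already lapsed ages out of the store.
const contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

contentCaches.set<CacheEntry>("task-123", { key: "cachedContents/abc123", count: 12 })

const entry = contentCaches.get<CacheEntry>("task-123") // undefined once expired

if (entry) {
	console.log(`reuse ${entry.key}; skip the first ${entry.count} messages`)
}
```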
@@ -35,36 +50,76 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 		const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()
 
 		const contents = messages.map(convertAnthropicMessageToGemini)
+		const contentsLength = systemInstruction.length + getMessagesLength(contents)
+
 		let uncachedContent: Content[] | undefined = undefined
 		let cachedContent: string | undefined = undefined
 		let cacheWriteTokens: number | undefined = undefined
 
+		// The minimum input token count for context caching is 4,096.
+		// For a basic approximation we assume 4 characters per token.
+		// We can use tiktoken eventually to get a more accurate token count.
 		// https://ai.google.dev/gemini-api/docs/caching?lang=node
-		// if (info.supportsPromptCache && cacheKey) {
-		// 	const cacheEntry = this.contentCaches.get(cacheKey)
-
-		// 	if (cacheEntry) {
-		// 		uncachedContent = contents.slice(cacheEntry.count, contents.length)
-		// 		cachedContent = cacheEntry.key
-		// 	}
+		// https://ai.google.dev/gemini-api/docs/tokens?lang=node
+		const isCacheAvailable =
+			info.supportsPromptCache &&
+			this.options.promptCachingEnabled &&
+			cacheKey &&
+			contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM
+
+		if (isCacheAvailable) {
+			const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)
+
+			if (cacheEntry) {
+				uncachedContent = contents.slice(cacheEntry.count, contents.length)
+				cachedContent = cacheEntry.key
+				console.log(
+					`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
+				)
+			}
 
-		// 	const newCacheEntry = await this.client.caches.create({
-		// 		model,
-		// 		config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
-		// 	})
+			if (!this.isCacheBusy) {
+				this.isCacheBusy = true
+				const timestamp = Date.now()
+
+				this.client.caches
+					.create({
+						model,
+						config: {
+							contents,
+							systemInstruction,
+							ttl: `${CACHE_TTL * 60}s`,
+							httpOptions: { timeout: 120_000 },
+						},
+					})
+					.then((result) => {
+						const { name, usageMetadata } = result
+
+						if (name) {
+							this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
+							cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
+							console.log(
+								`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
+							)
+						}
+					})
+					.catch((error) => {
+						console.error(`[GeminiHandler] caches.create error`, error)
+					})
+					.finally(() => {
+						this.isCacheBusy = false
+					})
+			}
+		}
 
-		// 	if (newCacheEntry.name) {
-		// 		this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
-		// 		cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
-		// 	}
-		// }
+		const isCacheUsed = !!cachedContent
 
 		const params: GenerateContentParameters = {
 			model,
 			contents: uncachedContent ?? contents,
 			config: {
 				cachedContent,
-				systemInstruction: cachedContent ? undefined : systemInstruction,
+				systemInstruction: isCacheUsed ? undefined : systemInstruction,
 				httpOptions: this.options.googleGeminiBaseUrl
 					? { baseUrl: this.options.googleGeminiBaseUrl }
 					: undefined,
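A small self-contained sketch of the length check behind isCacheAvailable above: Gemini context caching needs at least 4,096 input tokens, and the handler approximates tokens as characters divided by four, so it only attempts caching once the combined character count exceeds 4 * 4096. The sample numbers below are illustrative only:

```ts
const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

// Cache only when the prompt is roughly past the 4,096-token minimum,
// approximating tokens as characters / 4.
const isLongEnoughToCache = (systemInstructionChars: number, messageChars: number): boolean =>
	systemInstructionChars + messageChars > 4 * CONTEXT_CACHE_TOKEN_MINIMUM

console.log(isLongEnoughToCache(500, 10_000)) // false (~2,625 estimated tokens)
console.log(isLongEnoughToCache(500, 20_000)) // true  (~5,125 estimated tokens)
```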
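The cache write above is deliberately not awaited: the request proceeds while caches.create() runs in the background, and isCacheBusy keeps at most one write in flight. A hedged sketch of that pattern, with createCache() standing in for the real SDK call (a placeholder, not part of the patch):

```ts
let isCacheBusy = false

// Placeholder for the slow, real caches.create() request.
const createCache = (): Promise<{ name?: string }> =>
	new Promise((resolve) => setTimeout(() => resolve({ name: "cachedContents/example" }), 50))

function maybeWriteCache(onCached: (name: string) => void): void {
	if (isCacheBusy) {
		return // a previous write is still running; skip this turn
	}

	isCacheBusy = true

	createCache()
		.then(({ name }) => {
			// Only record the entry if the server returned a cache name.
			if (name) {
				onCached(name)
			}
		})
		.catch((error) => console.error("cache write failed", error))
		.finally(() => {
			isCacheBusy = false // allow the next turn to write again
		})
}

maybeWriteCache((name) => console.log(`cached as ${name}`))
```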
@@ -94,13 +149,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
 			const reasoningTokens = lastUsageMetadata.thoughtsTokenCount
 
-			// const totalCost = this.calculateCost({
-			// 	info,
-			// 	inputTokens,
-			// 	outputTokens,
-			// 	cacheWriteTokens,
-			// 	cacheReadTokens,
-			// })
+			const totalCost = isCacheUsed
+				? this.calculateCost({
+						info,
+						inputTokens,
+						outputTokens,
+						cacheWriteTokens,
+						cacheReadTokens,
+					})
+				: undefined
 
 			yield {
 				type: "usage",
@@ -109,7 +166,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 				cacheWriteTokens,
 				cacheReadTokens,
 				reasoningTokens,
-				// totalCost,
+				totalCost,
 			}
 		}
 	}
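The patch only reports totalCost when a cache was actually used; calculateCost itself is not shown here. For orientation, a purely illustrative sketch of how a cost estimate with cache read/write buckets is commonly computed: the PricingInfo shape, the estimateCost name, and all rates are hypothetical rather than Gemini's pricing or the handler's implementation, and it assumes the reported input token count already includes the cached tokens.

```ts
// Hypothetical pricing shape and rates; for illustration only.
type PricingInfo = {
	inputPrice: number // $ per 1M uncached input tokens
	outputPrice: number // $ per 1M output tokens
	cacheWritesPrice: number // $ per 1M tokens written to the cache
	cacheReadsPrice: number // $ per 1M tokens served from the cache
}

function estimateCost(
	info: PricingInfo,
	inputTokens: number,
	outputTokens: number,
	cacheWriteTokens = 0,
	cacheReadTokens = 0,
): number {
	const perMillion = (tokens: number, price: number) => (tokens / 1_000_000) * price

	return (
		perMillion(inputTokens - cacheReadTokens, info.inputPrice) + // uncached part of the prompt
		perMillion(cacheReadTokens, info.cacheReadsPrice) + // cached part, billed at the read rate
		perMillion(cacheWriteTokens, info.cacheWritesPrice) + // one-time cache write
		perMillion(outputTokens, info.outputPrice)
	)
}

// Example with made-up rates: 20k prompt tokens, 16k of them served from the cache.
const cost = estimateCost(
	{ inputPrice: 1, outputPrice: 4, cacheWritesPrice: 1, cacheReadsPrice: 0.25 },
	20_000,
	1_000,
	0,
	16_000,
)
console.log(cost.toFixed(6)) // 0.012000
```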