@@ -219,7 +219,7 @@ describe("GeminiHandler", () => {
 				mockInfo.cacheWritesPrice! * (cacheWriteTokens / 1_000_000) * (CACHE_TTL / 60)
 			const expectedCost = expectedInputCost + expectedOutputCost + expectedCacheWriteCost
 
-			const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens, cacheWriteTokens })
+			const cost = handler.calculateCost({ info: mockInfo, inputTokens, outputTokens })
 			expect(cost).toBeCloseTo(expectedCost)
 		})
 
@@ -247,384 +247,3 @@ describe("GeminiHandler", () => {
 		})
 	})
 })
-
-describe("Caching Logic", () => {
-	const systemPrompt = "System prompt"
-	const longContent = "a".repeat(5 * 4096) // Ensure content is long enough for caching
-	const mockMessagesLong: Anthropic.Messages.MessageParam[] = [
-		{ role: "user", content: longContent },
-		{ role: "assistant", content: "OK" },
-	]
-	const cacheKey = "test-cache-key"
-	const mockCacheName = "generated/caches/mock-cache-name"
-	const mockCacheTokens = 5000
-
-	let handlerWithCache: GeminiHandler
-	let mockGenerateContentStream: jest.Mock
-	let mockCreateCache: jest.Mock
-	let mockDeleteCache: jest.Mock
-	let mockCacheGet: jest.Mock
-	let mockCacheSet: jest.Mock
-
-	beforeEach(() => {
-		mockGenerateContentStream = jest.fn().mockResolvedValue({
-			[Symbol.asyncIterator]: async function* () {
-				yield { text: "Response" }
-				yield {
-					usageMetadata: {
-						promptTokenCount: 100, // Uncached input
-						candidatesTokenCount: 50, // Output
-						cachedContentTokenCount: 0, // Default, override in tests
-					},
-				}
-			},
-		})
-		mockCreateCache = jest.fn().mockResolvedValue({
-			name: mockCacheName,
-			usageMetadata: { totalTokenCount: mockCacheTokens },
-		})
-		mockDeleteCache = jest.fn().mockResolvedValue({})
-		mockCacheGet = jest.fn().mockReturnValue(undefined) // Default: cache miss
-		mockCacheSet = jest.fn()
-
-		handlerWithCache = new GeminiHandler({
-			apiKey: "test-key",
-			apiModelId: "gemini-1.5-flash-latest", // Use a model that supports caching
-			geminiApiKey: "test-key",
-			promptCachingDisabled: false,
-		})
-
-		handlerWithCache["client"] = {
-			models: {
-				generateContentStream: mockGenerateContentStream,
-			},
-			caches: {
-				create: mockCreateCache,
-				delete: mockDeleteCache,
-			},
-		} as any
-		handlerWithCache["contentCaches"] = {
-			get: mockCacheGet,
-			set: mockCacheSet,
-		} as any
-	})
-
-	it("should not use cache if promptCachingDisabled is true", async () => {
-		handlerWithCache["options"].promptCachingDisabled = true
-		const stream = handlerWithCache.createMessage(systemPrompt, mockMessagesLong, cacheKey)
-
-		for await (const _ of stream) {
-		}
-
-		expect(mockCacheGet).not.toHaveBeenCalled()
-		expect(mockGenerateContentStream).toHaveBeenCalledWith(
-			expect.objectContaining({
-				config: expect.objectContaining({
-					cachedContent: undefined,
-					systemInstruction: systemPrompt,
-				}),
-			}),
-		)
-		expect(mockCreateCache).not.toHaveBeenCalled()
-	})
-
-	it("should not use cache if content length is below threshold", async () => {
-		const shortMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "short" }]
-		const stream = handlerWithCache.createMessage(systemPrompt, shortMessages, cacheKey)
-		for await (const _ of stream) {
-			/* consume stream */
-		}
-
-		expect(mockCacheGet).not.toHaveBeenCalled() // Doesn't even check cache if too short
-		expect(mockGenerateContentStream).toHaveBeenCalledWith(
-			expect.objectContaining({
-				config: expect.objectContaining({
-					cachedContent: undefined,
-					systemInstruction: systemPrompt,
-				}),
-			}),
-		)
-		expect(mockCreateCache).not.toHaveBeenCalled()
-	})
-
-	it("should perform cache write on miss when conditions met", async () => {
-		const stream = handlerWithCache.createMessage(systemPrompt, mockMessagesLong, cacheKey)
-		const chunks = []
-
-		for await (const chunk of stream) {
-			chunks.push(chunk)
-		}
-
-		expect(mockCacheGet).toHaveBeenCalledWith(cacheKey)
-		expect(mockGenerateContentStream).toHaveBeenCalledWith(
-			expect.objectContaining({
-				config: expect.objectContaining({
-					cachedContent: undefined,
-					systemInstruction: systemPrompt,
-				}),
-			}),
-		)
-
-		await new Promise(process.nextTick) // Allow microtasks (like the async writeCache) to run
-
-		expect(mockCreateCache).toHaveBeenCalledTimes(1)
-		expect(mockCreateCache).toHaveBeenCalledWith(
-			expect.objectContaining({
-				model: expect.stringContaining("gemini-2.0-flash-001"), // Expectation matches the model ID observed in test runs
-				config: expect.objectContaining({
-					systemInstruction: systemPrompt,
-					contents: expect.any(Array), // Verify contents structure if needed
-					ttl: expect.stringContaining("300s"),
-				}),
-			}),
-		)
-		expect(mockCacheSet).toHaveBeenCalledWith(
-			cacheKey,
-			expect.objectContaining({
-				key: mockCacheName,
-				count: mockMessagesLong.length,
-				tokens: mockCacheTokens,
-			}),
-		)
-		expect(mockDeleteCache).not.toHaveBeenCalled() // No previous cache to delete
-
-		const usageChunk = chunks.find((c) => c.type === "usage")
-
-		expect(usageChunk).toEqual(
-			expect.objectContaining({
-				cacheWriteTokens: 100, // Should match promptTokenCount when write is queued
-				cacheReadTokens: 0,
-			}),
-		)
-	})
-
-	it("should use cache on hit and not send system prompt", async () => {
-		const cachedMessagesCount = 1
-		const cacheReadTokensCount = 4000
-		mockCacheGet.mockReturnValue({ key: mockCacheName, count: cachedMessagesCount, tokens: cacheReadTokensCount })
-
-		mockGenerateContentStream.mockResolvedValue({
-			[Symbol.asyncIterator]: async function* () {
-				yield { text: "Response" }
-				yield {
-					usageMetadata: {
-						promptTokenCount: 10, // Uncached input tokens
-						candidatesTokenCount: 50,
-						cachedContentTokenCount: cacheReadTokensCount, // Simulate cache hit reporting
-					},
-				}
-			},
-		})
-
-		// Only send the second message (index 1) as uncached
-		const stream = handlerWithCache.createMessage(systemPrompt, mockMessagesLong, cacheKey)
-		const chunks = []
-
-		for await (const chunk of stream) {
-			chunks.push(chunk)
-		}
-
-		expect(mockCacheGet).toHaveBeenCalledWith(cacheKey)
-		expect(mockGenerateContentStream).toHaveBeenCalledWith(
-			expect.objectContaining({
-				contents: expect.any(Array), // Should contain only the *uncached* messages
-				config: expect.objectContaining({
-					cachedContent: mockCacheName, // Cache name provided
-					systemInstruction: undefined, // System prompt NOT sent on hit
-				}),
-			}),
-		)
-
-		// Check that the contents sent are only the *new* messages
-		const calledContents = mockGenerateContentStream.mock.calls[0][0].contents
-		expect(calledContents.length).toBe(mockMessagesLong.length - cachedMessagesCount) // Only new messages sent
-
-		// Wait for potential async cache write (shouldn't happen here)
-		await new Promise(process.nextTick)
-		expect(mockCreateCache).not.toHaveBeenCalled()
-		expect(mockCacheSet).not.toHaveBeenCalled() // No write occurred
-
-		// Check usage data for cache read tokens
-		const usageChunk = chunks.find((c) => c.type === "usage")
-		expect(usageChunk).toEqual(
-			expect.objectContaining({
-				inputTokens: 10, // Uncached tokens
-				outputTokens: 50,
-				cacheWriteTokens: undefined, // No write queued
-				cacheReadTokens: cacheReadTokensCount, // Read tokens reported
-			}),
-		)
-	})
-
-	it("should trigger cache write and delete old cache on hit with enough new messages", async () => {
-		const previousCacheName = "generated/caches/old-cache-name"
-		const previousCacheTokens = 3000
-		const previousMessageCount = 1
-
-		mockCacheGet.mockReturnValue({
-			key: previousCacheName,
-			count: previousMessageCount,
-			tokens: previousCacheTokens,
-		})
-
-		// Simulate enough new messages to trigger write (>= CACHE_WRITE_FREQUENCY)
-		const newMessagesCount = 10
-
-		const messagesForCacheWrite = [
-			mockMessagesLong[0], // Will be considered cached
-			...Array(newMessagesCount).fill({ role: "user", content: "new message" }),
-		] as Anthropic.Messages.MessageParam[]
-
-		// Mock generateContentStream to report some uncached tokens
-		mockGenerateContentStream.mockResolvedValue({
-			[Symbol.asyncIterator]: async function* () {
-				yield { text: "Response" }
-				yield {
-					usageMetadata: {
-						promptTokenCount: 500, // Uncached input tokens for the 10 new messages
-						candidatesTokenCount: 50,
-						cachedContentTokenCount: previousCacheTokens,
-					},
-				}
-			},
-		})
-
-		const stream = handlerWithCache.createMessage(systemPrompt, messagesForCacheWrite, cacheKey)
-		const chunks = []
-
-		for await (const chunk of stream) {
-			chunks.push(chunk)
-		}
-
-		expect(mockCacheGet).toHaveBeenCalledWith(cacheKey)
-
-		expect(mockGenerateContentStream).toHaveBeenCalledWith(
-			expect.objectContaining({
-				contents: expect.any(Array), // Should contain only the *new* messages
-				config: expect.objectContaining({
-					cachedContent: previousCacheName, // Old cache name used for reading
-					systemInstruction: undefined, // System prompt NOT sent
-				}),
-			}),
-		)
-		const calledContents = mockGenerateContentStream.mock.calls[0][0].contents
-		expect(calledContents.length).toBe(newMessagesCount) // Only new messages sent
-
-		// Wait for async cache write and delete
-		await new Promise(process.nextTick)
-		await new Promise(process.nextTick) // Extra tick so the delete promise chain can settle
-
-		expect(mockCreateCache).toHaveBeenCalledTimes(1)
-		expect(mockCreateCache).toHaveBeenCalledWith(
-			expect.objectContaining({
-				// New cache uses *all* messages
-				config: expect.objectContaining({
-					contents: expect.any(Array), // Should contain *all* messagesForCacheWrite
-					systemInstruction: systemPrompt, // System prompt included in *new* cache
-				}),
-			}),
-		)
-		const createCallContents = mockCreateCache.mock.calls[0][0].config.contents
-		expect(createCallContents.length).toBe(messagesForCacheWrite.length) // All messages in new cache
-
-		expect(mockCacheSet).toHaveBeenCalledWith(
-			cacheKey,
-			expect.objectContaining({
-				key: mockCacheName, // New cache name
-				count: messagesForCacheWrite.length, // New count
-				tokens: mockCacheTokens,
-			}),
-		)
-
-		expect(mockDeleteCache).toHaveBeenCalledTimes(1)
-		expect(mockDeleteCache).toHaveBeenCalledWith({ name: previousCacheName }) // Old cache deleted
-
-		const usageChunk = chunks.find((c) => c.type === "usage")
-
-		expect(usageChunk).toEqual(
-			expect.objectContaining({
-				inputTokens: 500, // Uncached tokens
-				outputTokens: 50,
-				cacheWriteTokens: 500, // Mirrors the uncached promptTokenCount when a write is queued
-				cacheReadTokens: previousCacheTokens,
-			}),
-		)
-
-		// Note: the caches.create response reports the *total* token count of the
-		// new cache (system prompt plus all messages), and the cost calculation
-		// uses that figure. For illustration, re-mock the create response with
-		// such a total:
-		const expectedNewCacheTotalTokens = 6000 // Example total tokens for the new cache
-
-		mockCreateCache.mockResolvedValue({
-			name: mockCacheName,
-			usageMetadata: { totalTokenCount: expectedNewCacheTotalTokens },
-		})
-
-		// The usage chunk, however, reports cacheWriteTokens as the *uncached* input
-		// tokens that queued the write, so the expectation stays at 500.
-		expect(usageChunk?.cacheWriteTokens).toBe(500) // Matches the uncached promptTokenCount
-	})
-
-	it("should handle cache create error gracefully", async () => {
-		const consoleErrorSpy = jest.spyOn(console, "error").mockImplementation(() => {})
-		const createError = new Error("Failed to create cache")
-		mockCreateCache.mockRejectedValue(createError)
-
-		const stream = handlerWithCache.createMessage(systemPrompt, mockMessagesLong, cacheKey)
-
-		for await (const _ of stream) {
-		}
-
-		// Wait for async cache write attempt
-		await new Promise(process.nextTick)
-
-		expect(mockCreateCache).toHaveBeenCalledTimes(1)
-		expect(mockCacheSet).not.toHaveBeenCalled() // Set should not be called on error
-		expect(consoleErrorSpy).toHaveBeenCalledWith(
-			expect.stringContaining("[GeminiHandler] caches.create error"),
-			createError,
-		)
-		consoleErrorSpy.mockRestore()
-	})
-
-	it("should handle cache delete error gracefully", async () => {
-		const consoleErrorSpy = jest.spyOn(console, "error").mockImplementation(() => {})
-		const deleteError = new Error("Failed to delete cache")
-		mockDeleteCache.mockRejectedValue(deleteError)
-
-		// Setup for cache hit + write scenario to trigger delete
-		const previousCacheName = "generated/caches/old-cache-name"
-		mockCacheGet.mockReturnValue({ key: previousCacheName, count: 1, tokens: 3000 })
-
-		const newMessagesCount = 10
-
-		const messagesForCacheWrite = [
-			mockMessagesLong[0],
-			...Array(newMessagesCount).fill({ role: "user", content: "new message" }),
-		] as Anthropic.Messages.MessageParam[]
-
-		const stream = handlerWithCache.createMessage(systemPrompt, messagesForCacheWrite, cacheKey)
-
-		for await (const _ of stream) {
-		}
-
-		// Wait for async cache write and delete attempt
-		await new Promise(process.nextTick)
-		await new Promise(process.nextTick)
-
-		expect(mockCreateCache).toHaveBeenCalledTimes(1) // Create still happens
-		expect(mockCacheSet).toHaveBeenCalledTimes(1) // Set still happens
-		expect(mockDeleteCache).toHaveBeenCalledTimes(1) // Delete was attempted
-
-		// Expect a single string argument containing both parts
-		expect(consoleErrorSpy).toHaveBeenCalledWith(
-			expect.stringContaining(
-				`[GeminiHandler] failed to delete stale cache entry ${previousCacheName} -> ${deleteError.message}`,
-			),
-		)
-
-		consoleErrorSpy.mockRestore()
-	})
-})