
Merge pull request #642 from websentry-ai/vs/fix-anthropic-cache-tokens

Fixes the cached token count for unbound provider models
Matt Rubens · 11 months ago
commit 449b9ef610

2 changed files with 60 additions and 15 deletions:
  1. src/api/providers/__tests__/unbound.test.ts (+40 −10)
  2. src/api/providers/unbound.ts (+20 −5)
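
In effect, the change maps the Anthropic-style cache fields on the streamed usage payload onto the handler's usage chunk, attaching them only when present. A minimal sketch of that mapping, using the field names visible in the diff below (the standalone `toUsageChunk` helper is illustrative only; the PR inlines this logic in `UnboundHandler.createMessage`):

	interface UnboundUsage {
		prompt_tokens?: number
		completion_tokens?: number
		cache_creation_input_tokens?: number // Anthropic: tokens written to the prompt cache
		cache_read_input_tokens?: number // Anthropic: tokens read back from the prompt cache
	}

	function toUsageChunk(usage: UnboundUsage) {
		const chunk: {
			type: "usage"
			inputTokens: number
			outputTokens: number
			cacheWriteTokens?: number
			cacheReadTokens?: number
		} = {
			type: "usage",
			inputTokens: usage.prompt_tokens || 0,
			outputTokens: usage.completion_tokens || 0,
		}
		// Cache counts are attached only when the provider actually reports them.
		if (usage.cache_creation_input_tokens) {
			chunk.cacheWriteTokens = usage.cache_creation_input_tokens
		}
		if (usage.cache_read_input_tokens) {
			chunk.cacheReadTokens = usage.cache_read_input_tokens
		}
		return chunk
	}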

src/api/providers/__tests__/unbound.test.ts (+40 −10)

@@ -1,6 +1,5 @@
 import { UnboundHandler } from "../unbound"
 import { ApiHandlerOptions } from "../../../shared/api"
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
 
 // Mock OpenAI client
@@ -16,6 +15,7 @@ jest.mock("openai", () => {
 					create: (...args: any[]) => {
 						const stream = {
 							[Symbol.asyncIterator]: async function* () {
+								// First chunk with content
 								yield {
 									choices: [
 										{
@@ -24,13 +24,25 @@ jest.mock("openai", () => {
 										},
 									],
 								}
+								// Second chunk with usage data
 								yield {
-									choices: [
-										{
-											delta: {},
-											index: 0,
-										},
-									],
+									choices: [{ delta: {}, index: 0 }],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 5,
+										total_tokens: 15,
+									},
+								}
+								// Third chunk with cache usage data
+								yield {
+									choices: [{ delta: {}, index: 0 }],
+									usage: {
+										prompt_tokens: 8,
+										completion_tokens: 4,
+										total_tokens: 12,
+										cache_creation_input_tokens: 3,
+										cache_read_input_tokens: 2,
+									},
 								}
 							},
 						}
@@ -95,19 +107,37 @@ describe("UnboundHandler", () => {
 			},
 		]
 
-		it("should handle streaming responses", async () => {
+		it("should handle streaming responses with text and usage data", async () => {
 			const stream = handler.createMessage(systemPrompt, messages)
-			const chunks: any[] = []
+			const chunks: Array<{ type: string } & Record<string, any>> = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			expect(chunks.length).toBe(1)
+			expect(chunks.length).toBe(3)
+
+			// Verify text chunk
 			expect(chunks[0]).toEqual({
 				type: "text",
 				text: "Test response",
 			})
 
+			// Verify regular usage data
+			expect(chunks[1]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			// Verify usage data with cache information
+			expect(chunks[2]).toEqual({
+				type: "usage",
+				inputTokens: 8,
+				outputTokens: 4,
+				cacheWriteTokens: 3,
+				cacheReadTokens: 2,
+			})
+
 			expect(mockCreate).toHaveBeenCalledWith(
 				expect.objectContaining({
 					model: "claude-3-5-sonnet-20241022",

src/api/providers/unbound.ts (+20 −5)

@@ -3,7 +3,12 @@ import OpenAI from "openai"
 import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+
+interface UnboundUsage extends OpenAI.CompletionUsage {
+	cache_creation_input_tokens?: number
+	cache_read_input_tokens?: number
+}
 
 export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -96,7 +101,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 
 		for await (const chunk of completion) {
 			const delta = chunk.choices[0]?.delta
-			const usage = chunk.usage
+			const usage = chunk.usage as UnboundUsage
 
 			if (delta?.content) {
 				yield {
@@ -106,11 +111,21 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			if (usage) {
-				yield {
+				const usageData: ApiStreamUsageChunk = {
 					type: "usage",
-					inputTokens: usage?.prompt_tokens || 0,
-					outputTokens: usage?.completion_tokens || 0,
+					inputTokens: usage.prompt_tokens || 0,
+					outputTokens: usage.completion_tokens || 0,
 				}
+
+				// Only add cache tokens if they exist
+				if (usage.cache_creation_input_tokens) {
+					usageData.cacheWriteTokens = usage.cache_creation_input_tokens
+				}
+				if (usage.cache_read_input_tokens) {
+					usageData.cacheReadTokens = usage.cache_read_input_tokens
+				}
+
+				yield usageData
 			}
 		}
 	}
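
Downstream consumers can now read the cache counts off the usage chunks when they are present. A minimal, hypothetical consumption sketch (the handler, `systemPrompt`, and `messages` are assumed to be set up as in the test above):

	const stream = handler.createMessage(systemPrompt, messages)
	for await (const chunk of stream) {
		if (chunk.type === "usage") {
			// cacheWriteTokens / cacheReadTokens are optional and only set when reported.
			const cacheWrites = chunk.cacheWriteTokens ?? 0
			const cacheReads = chunk.cacheReadTokens ?? 0
			console.log(`in=${chunk.inputTokens} out=${chunk.outputTokens} cacheWrite=${cacheWrites} cacheRead=${cacheReads}`)
		}
	}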