Fix openai cache tracking and cost estimates (#2616)

* fix(api): update cacheReadsPrice for OpenAI GPT-4.1 models (#2887)

Set the correct cacheReadsPrice (cached input price) for gpt-4.1, gpt-4.1-mini, and gpt-4.1-nano based on official OpenAI pricing. cacheWritesPrice is left unchanged, per current OpenAI documentation. This ensures prompt-caching costs are reflected accurately for these models in cost calculations (a worked pricing sketch follows the changed-file summary below).

* Update more OpenAI cache prices

* Track cache tokens and cost correctly for OpenAI

* Update tests

---------

Co-authored-by: monotykamary <[email protected]>
Matt Rubens, 8 months ago
Parent commit: a64cab92dc
3 changed files with 96 additions and 44 deletions
  1. src/api/providers/__tests__/openai-native.test.ts (+28 -11)
  2. src/api/providers/openai-native.ts (+56 -33)
  3. src/shared/api.ts (+12 -0)
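
The cost path introduced by this change runs through calculateApiCostOpenAI, which the provider diff below imports from ../../utils/cost but which is not itself part of this commit. As a rough picture of what the new cacheReadsPrice fields feed into, here is a minimal TypeScript sketch of an OpenAI-style cost split, assuming cached prompt tokens are billed at cacheReadsPrice and only the remaining prompt tokens at inputPrice; the PricingInfo and estimateOpenAiCost names are illustrative, not the project's actual API.

// Hedged sketch only: calculateApiCostOpenAI lives in src/utils/cost and is not
// shown in this diff, so the formula below is inferred from how yieldUsage calls it.
interface PricingInfo {
	inputPrice: number // USD per 1M non-cached input tokens
	outputPrice: number // USD per 1M output tokens
	cacheReadsPrice?: number // USD per 1M cached (prompt-cache hit) input tokens
	cacheWritesPrice?: number // unused for OpenAI: caching is automatic, no write surcharge
}

function estimateOpenAiCost(
	info: PricingInfo,
	inputTokens: number, // total prompt tokens, cache hits + misses
	outputTokens: number,
	cacheWriteTokens: number,
	cacheReadTokens: number,
): number {
	// Bill cache hits at the discounted rate and only the remainder at the full input price.
	const nonCachedInput = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
	return (
		(nonCachedInput * info.inputPrice +
			cacheReadTokens * (info.cacheReadsPrice ?? info.inputPrice) +
			cacheWriteTokens * (info.cacheWritesPrice ?? 0) +
			outputTokens * info.outputPrice) /
		1_000_000
	)
}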

+ 28 - 11
src/api/providers/__tests__/openai-native.test.ts

@@ -153,7 +153,12 @@ describe("OpenAiNativeHandler", () => {
 				results.push(result)
 			}
 
-			expect(results).toEqual([{ type: "usage", inputTokens: 0, outputTokens: 0 }])
+			// Verify essential fields directly
+			expect(results.length).toBe(1)
+			expect(results[0].type).toBe("usage")
+			// Use type assertion to avoid TypeScript errors
+			expect((results[0] as any).inputTokens).toBe(0)
+			expect((results[0] as any).outputTokens).toBe(0)
 
 			// Verify developer role is used for system prompt with o1 model
 			expect(mockCreate).toHaveBeenCalledWith({
@@ -221,12 +226,18 @@ describe("OpenAiNativeHandler", () => {
 				results.push(result)
 			}
 
-			expect(results).toEqual([
-				{ type: "text", text: "Hello" },
-				{ type: "text", text: " there" },
-				{ type: "text", text: "!" },
-				{ type: "usage", inputTokens: 10, outputTokens: 5 },
-			])
+			// Verify text responses individually
+			expect(results.length).toBe(4)
+			expect(results[0]).toMatchObject({ type: "text", text: "Hello" })
+			expect(results[1]).toMatchObject({ type: "text", text: " there" })
+			expect(results[2]).toMatchObject({ type: "text", text: "!" })
+
+			// Check usage data fields but use toBeCloseTo for floating point comparison
+			expect(results[3].type).toBe("usage")
+			// Use type assertion to avoid TypeScript errors
+			expect((results[3] as any).inputTokens).toBe(10)
+			expect((results[3] as any).outputTokens).toBe(5)
+			expect((results[3] as any).totalCost).toBeCloseTo(0.00006, 6)
 
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: "gpt-4.1",
@@ -261,10 +272,16 @@ describe("OpenAiNativeHandler", () => {
 				results.push(result)
 			}
 
-			expect(results).toEqual([
-				{ type: "text", text: "Hello" },
-				{ type: "usage", inputTokens: 10, outputTokens: 5 },
-			])
+			// Verify responses individually
+			expect(results.length).toBe(2)
+			expect(results[0]).toMatchObject({ type: "text", text: "Hello" })
+
+			// Check usage data fields but use toBeCloseTo for floating point comparison
+			expect(results[1].type).toBe("usage")
+			// Use type assertion to avoid TypeScript errors
+			expect((results[1] as any).inputTokens).toBe(10)
+			expect((results[1] as any).outputTokens).toBe(5)
+			expect((results[1] as any).totalCost).toBeCloseTo(0.00006, 6)
 		})
 	})
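
For reference, the totalCost the updated tests expect follows directly from the gpt-4.1 prices added in src/shared/api.ts ($2 per 1M input tokens, $8 per 1M output tokens). The mocked usage (not shown in this diff) presumably reports 10 prompt tokens, 5 completion tokens, and no cached tokens, so:

// 10 * $2/1M + 5 * $8/1M = $0.00002 + $0.00004 = $0.00006
const expectedCost = (10 * 2 + 5 * 8) / 1_000_000 // 0.00006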
 

+ 56 - 33
src/api/providers/openai-native.ts

@@ -11,9 +11,16 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { calculateApiCostOpenAI } from "../../utils/cost"
 
 const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0
 
+// Define a type for the model object returned by getModel
+export type OpenAiNativeModel = {
+	id: OpenAiNativeModelId
+	info: ModelInfo
+}
+
 export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelId = this.getModel().id
+		const model = this.getModel()
 
-		if (modelId.startsWith("o1")) {
-			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		if (modelId.startsWith("o3-mini")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+		yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
 	}
 
 	private async *handleO1FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		// o1 supports developer prompt with formatting
 		// o1-preview and o1-mini only support user messages
-		const isOriginalO1 = modelId === "o1"
+		const isOriginalO1 = model.id === "o1"
 		const response = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			messages: [
 				{
 					role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(response)
+		yield* this.handleStreamResponse(response, model)
 	}
 
 	private async *handleO3FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			reasoning_effort: this.getModel().info.reasoningEffort,
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *handleDefaultModelMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		const stream = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		}
 	}
 
-	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
+		model: OpenAiNativeModel,
+	): ApiStream {
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			}
 
 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				yield* this.yieldUsage(model.info, chunk.usage)
 			}
 		}
 	}
 
-	override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
+	private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
+		const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+		const cacheWriteTokens = 0
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens: outputTokens,
+			cacheWriteTokens: cacheWriteTokens,
+			cacheReadTokens: cacheReadTokens,
+			totalCost: totalCost,
+		}
+	}
+
+	override getModel(): OpenAiNativeModel {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in openAiNativeModels) {
 			const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const modelId = this.getModel().id
+			const model = this.getModel()
 			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-			if (modelId.startsWith("o1")) {
-				requestOptions = this.getO1CompletionOptions(modelId, prompt)
-			} else if (modelId.startsWith("o3-mini")) {
-				requestOptions = this.getO3CompletionOptions(modelId, prompt)
+			if (model.id.startsWith("o1")) {
+				requestOptions = this.getO1CompletionOptions(model, prompt)
+			} else if (model.id.startsWith("o3-mini")) {
+				requestOptions = this.getO3CompletionOptions(model, prompt)
 			} else {
-				requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
+				requestOptions = this.getDefaultCompletionOptions(model, prompt)
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getO1CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 		}
 	}
 
 	private getO3CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getDefaultCompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 		}

+ 12 - 0
src/shared/api.ts

@@ -754,6 +754,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 2,
 		outputPrice: 8,
+		cacheReadsPrice: 0.5,
 	},
 	"gpt-4.1-mini": {
 		maxTokens: 32_768,
@@ -762,6 +763,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 0.4,
 		outputPrice: 1.6,
+		cacheReadsPrice: 0.1,
 	},
 	"gpt-4.1-nano": {
 		maxTokens: 32_768,
@@ -770,6 +772,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 0.1,
 		outputPrice: 0.4,
+		cacheReadsPrice: 0.025,
 	},
 	"o3-mini": {
 		maxTokens: 100_000,
@@ -778,6 +781,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "medium",
 	},
 	"o3-mini-high": {
@@ -787,6 +791,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "high",
 	},
 	"o3-mini-low": {
@@ -796,6 +801,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "low",
 	},
 	o1: {
@@ -805,6 +811,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 15,
 		outputPrice: 60,
+		cacheReadsPrice: 7.5,
 	},
 	"o1-preview": {
 		maxTokens: 32_768,
@@ -813,6 +820,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 15,
 		outputPrice: 60,
+		cacheReadsPrice: 7.5,
 	},
 	"o1-mini": {
 		maxTokens: 65_536,
@@ -821,6 +829,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 	},
 	"gpt-4.5-preview": {
 		maxTokens: 16_384,
@@ -829,6 +838,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 75,
 		outputPrice: 150,
+		cacheReadsPrice: 37.5,
 	},
 	"gpt-4o": {
 		maxTokens: 16_384,
@@ -837,6 +847,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 2.5,
 		outputPrice: 10,
+		cacheReadsPrice: 1.25,
 	},
 	"gpt-4o-mini": {
 		maxTokens: 16_384,
@@ -845,6 +856,7 @@ export const openAiNativeModels = {
 		supportsPromptCache: true,
 		inputPrice: 0.15,
 		outputPrice: 0.6,
+		cacheReadsPrice: 0.075,
 	},
 } as const satisfies Record<string, ModelInfo>