fix: exclude cache tokens from context window calculation (#5603)

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Daniel, 5 months ago
commit 0ecae9d4f1

+ 1 - 0
packages/types/src/message.ts

@@ -155,6 +155,7 @@ export const clineMessageSchema = z.object({
 	progressStatus: toolProgressStatusSchema.optional(),
 	contextCondense: contextCondenseSchema.optional(),
 	isProtected: z.boolean().optional(),
+	apiProtocol: z.union([z.literal("openai"), z.literal("anthropic")]).optional(),
 })
 
 export type ClineMessage = z.infer<typeof clineMessageSchema>
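
A minimal sketch, isolating the new field's validator, of how the addition behaves under zod (the surrounding message fields are elided):

import { z } from "zod"

const apiProtocolSchema = z.union([z.literal("openai"), z.literal("anthropic")]).optional()

apiProtocolSchema.parse("anthropic") // ok
apiProtocolSchema.parse(undefined) // ok: the field is optional, so previously persisted messages still validate
// apiProtocolSchema.parse("gemini") // throws ZodError: only the two literals are accepted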

+ 8 - 0
packages/types/src/provider-settings.ts

@@ -292,3 +292,11 @@ export const getModelId = (settings: ProviderSettings): string | undefined => {
 	const modelIdKey = MODEL_ID_KEYS.find((key) => settings[key])
 	return modelIdKey ? (settings[modelIdKey] as string) : undefined
 }
+
+// Providers that use the Anthropic-style API protocol
+export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"]
+
+// Helper function to determine API protocol for a provider
+export const getApiProtocol = (provider: ProviderName | undefined): "anthropic" | "openai" => {
+	return provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider) ? "anthropic" : "openai"
+}
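
A quick sketch of the helper's behavior ("openrouter" is used here only as a stand-in for any provider outside ANTHROPIC_STYLE_PROVIDERS):

import { getApiProtocol } from "@roo-code/types"

getApiProtocol("anthropic") // "anthropic"
getApiProtocol("claude-code") // "anthropic"
getApiProtocol("openrouter") // "openai": any provider not listed falls through to the default
getApiProtocol(undefined) // "openai": the default when no provider is configured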

+ 9 - 1
src/core/task/Task.ts

@@ -21,6 +21,7 @@ import {
 	type HistoryItem,
 	TelemetryEventName,
 	TodoItem,
+	getApiProtocol,
 } from "@roo-code/types"
 import { TelemetryService } from "@roo-code/telemetry"
 import { CloudService } from "@roo-code/cloud"
@@ -1207,11 +1208,16 @@ export class Task extends EventEmitter<ClineEvents> {
 		// top-down build file structure of project which for large projects can
 		// take a few seconds. For the best UX we show a placeholder api_req_started
 		// message with a loading spinner as this happens.
+
+		// Determine API protocol based on provider
+		const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider)
+
 		await this.say(
 			"api_req_started",
 			JSON.stringify({
 				request:
 					userContent.map((block) => formatContentBlockToMarkdown(block)).join("\n\n") + "\n\nLoading...",
+				apiProtocol,
 			}),
 		)
 
@@ -1243,6 +1249,7 @@ export class Task extends EventEmitter<ClineEvents> {
 
 		this.clineMessages[lastApiReqIndex].text = JSON.stringify({
 			request: finalUserContent.map((block) => formatContentBlockToMarkdown(block)).join("\n\n"),
+			apiProtocol,
 		} satisfies ClineApiReqInfo)
 
 		await this.saveClineMessages()
@@ -1263,8 +1270,9 @@ export class Task extends EventEmitter<ClineEvents> {
 			// of prices in tasks from history (it's worth removing a few months
 			// from now).
 			const updateApiReqMsg = (cancelReason?: ClineApiReqCancelReason, streamingFailedMessage?: string) => {
+				const existingData = JSON.parse(this.clineMessages[lastApiReqIndex].text || "{}")
 				this.clineMessages[lastApiReqIndex].text = JSON.stringify({
-					...JSON.parse(this.clineMessages[lastApiReqIndex].text || "{}"),
+					...existingData,
 					tokensIn: inputTokens,
 					tokensOut: outputTokens,
 					cacheWrites: cacheWriteTokens,
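
Because updateApiReqMsg now spreads the previously stored JSON into the updated payload, the apiProtocol recorded at request start survives later token-count updates. A small sketch of that merge, with hypothetical values:

const existingData = { request: "...", apiProtocol: "anthropic" }
const updated = { ...existingData, tokensIn: 100, tokensOut: 40 }
// updated.apiProtocol is still "anthropic", so getApiMetrics can pick the right formula later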

+ 1 - 0
src/shared/ExtensionMessage.ts

@@ -379,6 +379,7 @@ export interface ClineApiReqInfo {
 	cost?: number
 	cancelReason?: ClineApiReqCancelReason
 	streamingFailedMessage?: string
+	apiProtocol?: "anthropic" | "openai"
 }
 
 export type ClineApiReqCancelReason = "streaming_failed" | "user_cancelled"
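
For illustration, a hypothetical value that type-checks against the extended interface (assuming the request and token fields elided above this hunk are optional members, as their use elsewhere in the diff suggests):

const info: ClineApiReqInfo = {
	request: "...",
	tokensIn: 50,
	tokensOut: 150,
	apiProtocol: "anthropic",
}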

+ 11 - 11
src/shared/__tests__/getApiMetrics.spec.ts

@@ -61,7 +61,7 @@ describe("getApiMetrics", () => {
 			expect(result.totalCacheWrites).toBe(5)
 			expect(result.totalCacheReads).toBe(10)
 			expect(result.totalCost).toBe(0.005)
-			expect(result.contextTokens).toBe(315) // 100 + 200 + 5 + 10
+			expect(result.contextTokens).toBe(300) // 100 + 200 (OpenAI default, no cache tokens)
 		})
 
 		it("should calculate metrics from multiple api_req_started messages", () => {
@@ -83,7 +83,7 @@ describe("getApiMetrics", () => {
 			expect(result.totalCacheWrites).toBe(8) // 5 + 3
 			expect(result.totalCacheReads).toBe(17) // 10 + 7
 			expect(result.totalCost).toBe(0.008) // 0.005 + 0.003
-			expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 (from the last message)
+			expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens)
 		})
 
 		it("should calculate metrics from condense_context messages", () => {
@@ -123,7 +123,7 @@ describe("getApiMetrics", () => {
 			expect(result.totalCacheWrites).toBe(8) // 5 + 3
 			expect(result.totalCacheReads).toBe(17) // 10 + 7
 			expect(result.totalCost).toBe(0.01) // 0.005 + 0.002 + 0.003
-			expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 (from the last api_req_started message)
+			expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens)
 		})
 	})
 
@@ -242,9 +242,9 @@ describe("getApiMetrics", () => {
 			expect(result.totalCacheReads).toBe(10)
 			expect(result.totalCost).toBe(0.005)
 
-			// The implementation will use the last message with tokens for contextTokens
-			// In this case, it's the cacheReads message
-			expect(result.contextTokens).toBe(10)
+			// The implementation will use the last message that has any tokens
+			// In this case, it's the message with tokensOut:200 (since the last few messages have no tokensIn/Out)
+			expect(result.contextTokens).toBe(200) // 0 + 200 (from the tokensOut message)
 		})
 
 		it("should handle non-number values in api_req_started message", () => {
@@ -264,8 +264,8 @@ describe("getApiMetrics", () => {
 			expect(result.totalCacheReads).toBeUndefined()
 			expect(result.totalCost).toBe(0)
 
-			// The implementation concatenates string values for contextTokens
-			expect(result.contextTokens).toBe("not-a-numbernot-a-numbernot-a-numbernot-a-number")
+			// The implementation concatenates the tokensIn and tokensOut string values (cache tokens are excluded)
+			expect(result.contextTokens).toBe("not-a-numbernot-a-number") // tokensIn + tokensOut (OpenAI default)
 		})
 	})
 
@@ -279,7 +279,7 @@ describe("getApiMetrics", () => {
 			const result = getApiMetrics(messages)
 
 			// Should use the values from the last api_req_started message
-			expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7
+			expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens)
 		})
 
 		it("should calculate contextTokens from the last condense_context message", () => {
@@ -305,7 +305,7 @@ describe("getApiMetrics", () => {
 			const result = getApiMetrics(messages)
 
 			// Should use the values from the last api_req_started message
-			expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7
+			expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens)
 		})
 
 		it("should handle missing values when calculating contextTokens", () => {
@@ -320,7 +320,7 @@ describe("getApiMetrics", () => {
 			const result = getApiMetrics(messages)
 
 			// Should handle missing or invalid values
-			expect(result.contextTokens).toBe(15) // 0 + 0 + 5 + 10
+			expect(result.contextTokens).toBe(0) // 0 + 0 (OpenAI default, no cache tokens)
 
 			// Restore console.error
 			console.error = originalConsoleError
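
A hypothetical companion assertion, not part of this diff, showing that tagging a message with apiProtocol: "anthropic" restores the cache-inclusive total:

const anthropicMessages: ClineMessage[] = [
	{
		ts: Date.now(),
		type: "say",
		say: "api_req_started",
		text: JSON.stringify({ tokensIn: 50, tokensOut: 150, cacheWrites: 3, cacheReads: 7, apiProtocol: "anthropic" }),
	},
]
expect(getApiMetrics(anthropicMessages).contextTokens).toBe(210) // 50 + 150 + 3 + 7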

+ 10 - 2
src/shared/getApiMetrics.ts

@@ -6,6 +6,7 @@ export type ParsedApiReqStartedTextType = {
 	cacheWrites: number
 	cacheReads: number
 	cost?: number // Only present if combineApiRequests has been called
+	apiProtocol?: "anthropic" | "openai"
 }
 
 /**
@@ -72,8 +73,15 @@ export function getApiMetrics(messages: ClineMessage[]) {
 		if (message.type === "say" && message.say === "api_req_started" && message.text) {
 			try {
 				const parsedText: ParsedApiReqStartedTextType = JSON.parse(message.text)
-				const { tokensIn, tokensOut, cacheWrites, cacheReads } = parsedText
-				result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
+				const { tokensIn, tokensOut, cacheWrites, cacheReads, apiProtocol } = parsedText
+
+				// Calculate context tokens based on API protocol
+				if (apiProtocol === "anthropic") {
+					result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
+				} else {
+					// For OpenAI (or when protocol is not specified)
+					result.contextTokens = (tokensIn || 0) + (tokensOut || 0)
+				}
 			} catch (error) {
 				console.error("Error parsing JSON:", error)
 				continue
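
The net effect for identical usage numbers: Anthropic-style responses report cache reads and writes separately from input tokens, so they are added back in; OpenAI-style responses are presumed to already count cached tokens within tokensIn, so adding them again would inflate the context window estimate. A worked sketch:

const { tokensIn, tokensOut, cacheWrites, cacheReads } = { tokensIn: 50, tokensOut: 150, cacheWrites: 3, cacheReads: 7 }

const anthropicContext = tokensIn + tokensOut + cacheWrites + cacheReads // 210
const openaiContext = tokensIn + tokensOut // 200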