Explorar o código

Fix usage tracking for SiliconFlow etc

Matt Rubens hai 9 meses
pai
achega
f306461276

+ 5 - 0
.changeset/tidy-queens-pay.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Fix usage tracking for SiliconFlow etc

+ 235 - 0
src/api/providers/__tests__/openai-usage-tracking.test.ts

@@ -0,0 +1,235 @@
+import { OpenAiHandler } from "../openai"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+// Mock OpenAI client with multiple chunks that contain usage data
+const mockCreate = jest.fn()
+jest.mock("openai", () => {
+	return {
+		__esModule: true,
+		default: jest.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate.mockImplementation(async (options) => {
+						if (!options.stream) {
+							return {
+								id: "test-completion",
+								choices: [
+									{
+										message: { role: "assistant", content: "Test response", refusal: null },
+										finish_reason: "stop",
+										index: 0,
+									},
+								],
+								usage: {
+									prompt_tokens: 10,
+									completion_tokens: 5,
+									total_tokens: 15,
+								},
+							}
+						}
+
+						// Return a stream with multiple chunks that include usage metrics
+						return {
+							[Symbol.asyncIterator]: async function* () {
+								// First chunk with partial usage
+								yield {
+									choices: [
+										{
+											delta: { content: "Test " },
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 2,
+										total_tokens: 12,
+									},
+								}
+
+								// Second chunk with updated usage
+								yield {
+									choices: [
+										{
+											delta: { content: "response" },
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 4,
+										total_tokens: 14,
+									},
+								}
+
+								// Final chunk with complete usage
+								yield {
+									choices: [
+										{
+											delta: {},
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 5,
+										total_tokens: 15,
+									},
+								}
+							},
+						}
+					}),
+				},
+			},
+		})),
+	}
+})
+
+describe("OpenAiHandler with usage tracking fix", () => {
+	let handler: OpenAiHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			openAiApiKey: "test-api-key",
+			openAiModelId: "gpt-4",
+			openAiBaseUrl: "https://api.openai.com/v1",
+		}
+		handler = new OpenAiHandler(mockOptions)
+		mockCreate.mockClear()
+	})
+
+	describe("usage metrics with streaming", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should only yield usage metrics once at the end of the stream", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check we have text chunks
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0].text).toBe("Test ")
+			expect(textChunks[1].text).toBe("response")
+
+			// Check we only have one usage chunk and it's the last one
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(1)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			// Check the usage chunk is the last one reported from the API
+			const lastChunk = chunks[chunks.length - 1]
+			expect(lastChunk.type).toBe("usage")
+			expect(lastChunk.inputTokens).toBe(10)
+			expect(lastChunk.outputTokens).toBe(5)
+		})
+
+		it("should handle case where usage is only in the final chunk", async () => {
+			// Override the mock for this specific test
+			mockCreate.mockImplementationOnce(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						choices: [{ message: { role: "assistant", content: "Test response" } }],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				}
+
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						// First chunk with no usage
+						yield {
+							choices: [{ delta: { content: "Test " }, index: 0 }],
+							usage: null,
+						}
+
+						// Second chunk with no usage
+						yield {
+							choices: [{ delta: { content: "response" }, index: 0 }],
+							usage: null,
+						}
+
+						// Final chunk with usage data
+						yield {
+							choices: [{ delta: {}, index: 0 }],
+							usage: {
+								prompt_tokens: 10,
+								completion_tokens: 5,
+								total_tokens: 15,
+							},
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check usage metrics
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(1)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+		})
+
+		it("should handle case where no usage is provided", async () => {
+			// Override the mock for this specific test
+			mockCreate.mockImplementationOnce(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						choices: [{ message: { role: "assistant", content: "Test response" } }],
+						usage: null,
+					}
+				}
+
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [{ delta: { content: "Test response" }, index: 0 }],
+							usage: null,
+						}
+						yield {
+							choices: [{ delta: {}, index: 0 }],
+							usage: null,
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check we don't have any usage chunks
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(0)
+		})
+	})
+})

+ 7 - 1
src/api/providers/openai.ts

@@ -99,6 +99,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 			const stream = await this.client.chat.completions.create(requestOptions)
 
+			let lastUsage
+
 			for await (const chunk of stream) {
 				const delta = chunk.choices[0]?.delta ?? {}
 
@@ -116,9 +118,13 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					}
 				}
 				if (chunk.usage) {
-					yield this.processUsageMetrics(chunk.usage, modelInfo)
+					lastUsage = chunk.usage
 				}
 			}
+
+			if (lastUsage) {
+				yield this.processUsageMetrics(lastUsage, modelInfo)
+			}
 		} else {
 			// o1 for instance doesnt support streaming, non-1 temp, or system prompt
 			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {