Просмотр исходного кода

Improvements to base openai compatible (#9462)

Co-authored-by: Roo Code <[email protected]>
Matt Rubens 1 месяц назад
Родитель
Сommit
8472bbb43b

+ 1 - 1
src/api/providers/__tests__/base-openai-compatible-provider.spec.ts

@@ -380,7 +380,7 @@ describe("BaseOpenAiCompatibleProvider", () => {
 			const firstChunk = await stream.next()
 
 			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
+			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
 		})
 	})
 })

+ 6 - 10
src/api/providers/__tests__/featherless.spec.ts

@@ -123,11 +123,9 @@ describe("FeatherlessHandler", () => {
 			chunks.push(chunk)
 		}
 
-		expect(chunks).toEqual([
-			{ type: "reasoning", text: "Thinking..." },
-			{ type: "text", text: "Hello" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		expect(chunks[0]).toEqual({ type: "reasoning", text: "Thinking..." })
+		expect(chunks[1]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
 	})
 
 	it("should fall back to base provider for non-DeepSeek models", async () => {
@@ -145,10 +143,8 @@ describe("FeatherlessHandler", () => {
 			chunks.push(chunk)
 		}
 
-		expect(chunks).toEqual([
-			{ type: "text", text: "Test response" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		expect(chunks[0]).toEqual({ type: "text", text: "Test response" })
+		expect(chunks[1]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
 	})
 
 	it("should return default model when no model is specified", () => {
@@ -226,7 +222,7 @@ describe("FeatherlessHandler", () => {
 		const firstChunk = await stream.next()
 
 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
 	it("createMessage should pass correct parameters to Featherless client for DeepSeek R1", async () => {

+ 4 - 6
src/api/providers/__tests__/fireworks.spec.ts

@@ -384,7 +384,7 @@ describe("FireworksHandler", () => {
 		const firstChunk = await stream.next()
 
 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
 	it("createMessage should pass correct parameters to Fireworks client", async () => {
@@ -494,10 +494,8 @@ describe("FireworksHandler", () => {
 			chunks.push(chunk)
 		}
 
-		expect(chunks).toEqual([
-			{ type: "text", text: "Hello" },
-			{ type: "text", text: " world" },
-			{ type: "usage", inputTokens: 5, outputTokens: 10 },
-		])
+		expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[1]).toEqual({ type: "text", text: " world" })
+		expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
 	})
 })

+ 5 - 3
src/api/providers/__tests__/groq.spec.ts

@@ -112,9 +112,10 @@ describe("GroqHandler", () => {
 			type: "usage",
 			inputTokens: 10,
 			outputTokens: 20,
-			cacheWriteTokens: 0,
-			cacheReadTokens: 0,
 		})
+		// cacheWriteTokens and cacheReadTokens will be undefined when 0
+		expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
+		expect(firstChunk.value.cacheReadTokens).toBeUndefined()
 		// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
 		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})
@@ -151,9 +152,10 @@ describe("GroqHandler", () => {
 			type: "usage",
 			inputTokens: 100,
 			outputTokens: 50,
-			cacheWriteTokens: 0,
 			cacheReadTokens: 30,
 		})
+		// cacheWriteTokens will be undefined when 0
+		expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
 		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})
 

+ 2 - 2
src/api/providers/__tests__/io-intelligence.spec.ts

@@ -178,7 +178,7 @@ describe("IOIntelligenceHandler", () => {
 		expect(results).toHaveLength(3)
 		expect(results[0]).toEqual({ type: "text", text: "Hello" })
 		expect(results[1]).toEqual({ type: "text", text: " world" })
-		expect(results[2]).toEqual({
+		expect(results[2]).toMatchObject({
 			type: "usage",
 			inputTokens: 10,
 			outputTokens: 5,
@@ -243,7 +243,7 @@ describe("IOIntelligenceHandler", () => {
 		const firstChunk = await stream.next()
 
 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
 	it("should return model info from cache when available", () => {

+ 1 - 1
src/api/providers/__tests__/sambanova.spec.ts

@@ -113,7 +113,7 @@ describe("SambaNovaHandler", () => {
 		const firstChunk = await stream.next()
 
 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
 	it("createMessage should pass correct parameters to SambaNova client", async () => {

+ 1 - 1
src/api/providers/__tests__/zai.spec.ts

@@ -252,7 +252,7 @@ describe("ZAiHandler", () => {
 			const firstChunk = await stream.next()
 
 			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 		})
 
 		it("createMessage should pass correct parameters to Z AI client", async () => {

+ 55 - 15
src/api/providers/base-openai-compatible-provider.ts

@@ -5,13 +5,14 @@ import type { ModelInfo } from "@roo-code/types"
 
 import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
 import { XmlMatcher } from "../../utils/xml-matcher"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import { handleOpenAIError } from "./utils/openai-error-handler"
+import { calculateApiCostOpenAI } from "../../shared/cost"
 
 type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
 	providerName: string
@@ -94,6 +95,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
+		// Add thinking parameter if reasoning is enabled and model supports it
+		if (this.options.enableReasoningEffort && info.supportsReasoningBinary) {
+			;(params as any).thinking = { type: "enabled" }
+		}
+
 		try {
 			return this.client.chat.completions.create(params, requestOptions)
 		} catch (error) {
@@ -119,6 +125,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 
 		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
 
+		let lastUsage: OpenAI.CompletionUsage | undefined
+
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const chunkAny = chunk as any
@@ -137,10 +145,15 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 
-			if (delta && "reasoning_content" in delta) {
-				const reasoning_content = (delta.reasoning_content as string | undefined) || ""
-				if (reasoning_content?.trim()) {
-					yield { type: "reasoning", text: reasoning_content }
+			if (delta) {
+				for (const key of ["reasoning_content", "reasoning"] as const) {
+					if (key in delta) {
+						const reasoning_content = ((delta as any)[key] as string | undefined) || ""
+						if (reasoning_content?.trim()) {
+							yield { type: "reasoning", text: reasoning_content }
+						}
+						break
+					}
 				}
 			}
 
@@ -176,11 +189,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			}
 
 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				lastUsage = chunk.usage
 			}
 		}
 
@@ -198,20 +207,51 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			toolCallAccumulator.clear()
 		}
 
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, this.getModel().info)
+		}
+
 		// Process any remaining content
 		for (const processedChunk of matcher.final()) {
 			yield processedChunk
 		}
 	}
 
+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const { totalCost } = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: { totalCost: 0 }
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+
 	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId } = this.getModel()
+		const { id: modelId, info: modelInfo } = this.getModel()
+
+		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+
+		// Add thinking parameter if reasoning is enabled and model supports it
+		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
+			;(params as any).thinking = { type: "enabled" }
+		}
 
 		try {
-			const response = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [{ role: "user", content: prompt }],
-			})
+			const response = await this.client.chat.completions.create(params)
 
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const responseAny = response as any

+ 0 - 59
src/api/providers/groq.ts

@@ -1,22 +1,9 @@
 import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import { calculateApiCostOpenAI } from "../../shared/cost"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
-// Enhanced usage interface to support Groq's cached token fields
-interface GroqUsage extends OpenAI.CompletionUsage {
-	prompt_tokens_details?: {
-		cached_tokens?: number
-	}
-}
-
 export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
 		super({
@@ -29,50 +16,4 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
-
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
-
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-
-			if (chunk.usage) {
-				yield* this.yieldUsage(chunk.usage as GroqUsage)
-			}
-		}
-	}
-
-	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
-		const { info } = this.getModel()
-		const inputTokens = usage?.prompt_tokens || 0
-		const outputTokens = usage?.completion_tokens || 0
-
-		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
-
-		// Groq does not track cache writes
-		const cacheWriteTokens = 0
-
-		// Calculate cost using OpenAI-compatible cost calculation
-		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-
-		yield {
-			type: "usage",
-			inputTokens,
-			outputTokens,
-			cacheWriteTokens,
-			cacheReadTokens,
-			totalCost,
-		}
-	}
 }

+ 0 - 72
src/api/providers/zai.ts

@@ -3,21 +3,12 @@ import {
 	mainlandZAiModels,
 	internationalZAiDefaultModelId,
 	mainlandZAiDefaultModelId,
-	type InternationalZAiModelId,
-	type MainlandZAiModelId,
 	type ModelInfo,
 	ZAI_DEFAULT_TEMPERATURE,
 	zaiApiLineConfigs,
 } from "@roo-code/types"
 
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-
 import type { ApiHandlerOptions } from "../../shared/api"
-import { getModelMaxOutputTokens } from "../../shared/api"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { handleOpenAIError } from "./utils/openai-error-handler"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
@@ -37,67 +28,4 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider<string> {
 			defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
 		})
 	}
-
-	protected override createStream(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-		requestOptions?: OpenAI.RequestOptions,
-	) {
-		const { id: model, info } = this.getModel()
-
-		// Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply)
-		const max_tokens =
-			getModelMaxOutputTokens({
-				modelId: model,
-				model: info,
-				settings: this.options,
-				format: "openai",
-			}) ?? undefined
-
-		const temperature = this.options.modelTemperature ?? this.defaultTemperature
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-			model,
-			max_tokens,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		const { id: modelId, info: modelInfo } = this.getModel()
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			return this.client.chat.completions.create(params, requestOptions)
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
-
-	override async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId } = this.getModel()
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
-			model: modelId,
-			messages: [{ role: "user", content: prompt }],
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		const { info: modelInfo } = this.getModel()
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			const response = await this.client.chat.completions.create(params)
-			return response.choices[0]?.message.content || ""
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
 }