Просмотр исходного кода

Handle <think> tags in the base OpenAI-compatible provider (#8989)

Matt Rubens 1 месяц назад
Родитель
Сommit
613255c09a

+ 286 - 0
src/api/providers/__tests__/base-openai-compatible-provider.spec.ts

@@ -0,0 +1,286 @@
+// npx vitest run api/providers/__tests__/base-openai-compatible-provider.spec.ts
+
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider"
+
+// Create mock functions
+const mockCreate = vi.fn()
+
+// Mock OpenAI module
+vi.mock("openai", () => ({
+	default: vi.fn(() => ({
+		chat: {
+			completions: {
+				create: mockCreate,
+			},
+		},
+	})),
+}))
+
+// Create a concrete test implementation of the abstract base class
+class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> {
+	constructor(apiKey: string) {
+		const testModels: Record<"test-model", ModelInfo> = {
+			"test-model": {
+				maxTokens: 4096,
+				contextWindow: 128000,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0.5,
+				outputPrice: 1.5,
+			},
+		}
+
+		super({
+			providerName: "TestProvider",
+			baseURL: "https://test.example.com/v1",
+			defaultProviderModelId: "test-model",
+			providerModels: testModels,
+			apiKey,
+		})
+	}
+}
+
+describe("BaseOpenAiCompatibleProvider", () => {
+	let handler: TestOpenAiCompatibleProvider
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		handler = new TestOpenAiCompatibleProvider("test-api-key")
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
+	})
+
+	describe("XmlMatcher reasoning tags", () => {
+		it("should handle reasoning tags (<think>) from stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Let me think" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " about this</think>" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "The answer is 42" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// XmlMatcher yields chunks as they're processed
+			expect(chunks).toEqual([
+				{ type: "reasoning", text: "Let me think" },
+				{ type: "reasoning", text: " about this" },
+				{ type: "text", text: "The answer is 42" },
+			])
+		})
+
+		it("should handle complete <think> tag in a single chunk", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "Regular text before " } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Complete thought</think>" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " regular text after" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// When a complete tag arrives in one chunk, XmlMatcher may not parse it
+			// This test documents the actual behavior
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(chunks[0]).toEqual({ type: "text", text: "Regular text before " })
+		})
+
+		it("should handle incomplete <think> tag at end of stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>Incomplete thought" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// XmlMatcher should handle incomplete tags and flush remaining content
+			expect(chunks.length).toBeGreaterThan(0)
+			expect(
+				chunks.some(
+					(c) => (c.type === "text" || c.type === "reasoning") && c.text.includes("Incomplete thought"),
+				),
+			).toBe(true)
+		})
+
+		it("should handle text without any <think> tags", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "Just regular text" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " without reasoning" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toEqual([
+				{ type: "text", text: "Just regular text" },
+				{ type: "text", text: " without reasoning" },
+			])
+		})
+
+		it("should handle <think> tags that start at beginning of stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: "<think>reasoning" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " content</think>" } }] },
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: { choices: [{ delta: { content: " normal text" } }] },
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toEqual([
+				{ type: "reasoning", text: "reasoning" },
+				{ type: "reasoning", text: " content" },
+				{ type: "text", text: " normal text" },
+			])
+		})
+	})
+
+	describe("Basic functionality", () => {
+		it("should create stream with correct parameters", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const systemPrompt = "Test system prompt"
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
+
+			const messageGenerator = handler.createMessage(systemPrompt, messages)
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "test-model",
+					temperature: 0,
+					messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+					stream: true,
+					stream_options: { include_usage: true },
+				}),
+				undefined,
+			)
+		})
+
+		it("should yield usage data from stream", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [{ delta: {} }],
+									usage: { prompt_tokens: 100, completion_tokens: 50 },
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const firstChunk = await stream.next()
+
+			expect(firstChunk.done).toBe(false)
+			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
+		})
+	})
+})

+ 0 - 37
src/api/providers/__tests__/minimax.spec.ts

@@ -178,43 +178,6 @@ describe("MiniMaxHandler", () => {
 			expect(firstChunk.value).toEqual({ type: "text", text: testContent })
 		})
 
-		it("should handle reasoning tags (<think>) from stream", async () => {
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vitest
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "<think>Let me think" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: " about this</think>" } }] },
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: { choices: [{ delta: { content: "The answer is 42" } }] },
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
-			})
-
-			const stream = handler.createMessage("system prompt", [])
-			const chunks = []
-			for await (const chunk of stream) {
-				chunks.push(chunk)
-			}
-
-			// XmlMatcher yields chunks as they're processed
-			expect(chunks).toEqual([
-				{ type: "reasoning", text: "Let me think" },
-				{ type: "reasoning", text: " about this" },
-				{ type: "text", text: "The answer is 42" },
-			])
-		})
-
 		it("createMessage should yield usage data from stream", async () => {
 			mockCreate.mockImplementationOnce(() => {
 				return {

+ 17 - 3
src/api/providers/base-openai-compatible-provider.ts

@@ -4,6 +4,7 @@ import OpenAI from "openai"
 import type { ModelInfo } from "@roo-code/types"
 
 import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 
@@ -105,13 +106,21 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 	): ApiStream {
 		const stream = await this.createStream(systemPrompt, messages, metadata)
 
+		const matcher = new XmlMatcher(
+			"think",
+			(chunk) =>
+				({
+					type: chunk.matched ? "reasoning" : "text",
+					text: chunk.data,
+				}) as const,
+		)
+
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 
 			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+				for (const processedChunk of matcher.update(delta.content)) {
+					yield processedChunk
 				}
 			}
 
@@ -127,6 +136,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 		}
+
+		// Process any remaining content
+		for (const processedChunk of matcher.final()) {
+			yield processedChunk
+		}
 	}
 
 	async completePrompt(prompt: string): Promise<string> {

+ 0 - 43
src/api/providers/minimax.ts

@@ -1,10 +1,6 @@
-import { Anthropic } from "@anthropic-ai/sdk"
 import { type MinimaxModelId, minimaxDefaultModelId, minimaxModels } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-import { XmlMatcher } from "../../utils/xml-matcher"
-import { ApiStream } from "../transform/stream"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
@@ -20,43 +16,4 @@ export class MiniMaxHandler extends BaseOpenAiCompatibleProvider<MinimaxModelId>
 			defaultTemperature: 1.0,
 		})
 	}
-
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
-
-		const matcher = new XmlMatcher(
-			"think",
-			(chunk) =>
-				({
-					type: chunk.matched ? "reasoning" : "text",
-					text: chunk.data,
-				}) as const,
-		)
-
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-
-			if (delta?.content) {
-				for (const matcherChunk of matcher.update(delta.content)) {
-					yield matcherChunk
-				}
-			}
-
-			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
-			}
-		}
-
-		for (const chunk of matcher.final()) {
-			yield chunk
-		}
-	}
 }