Browse Source

fix: process finish_reason to emit tool_call_end events (#9927)

Daniel 1 month ago
parent
commit
88a0bed27f

+ 73 - 0
src/api/providers/__tests__/openrouter.spec.ts

@@ -283,6 +283,79 @@ describe("OpenRouterHandler", () => {
 			const generator = handler.createMessage("test", [])
 			await expect(generator.next()).rejects.toThrow("OpenRouter API Error 500: API Error")
 		})
+
+		it("yields tool_call_end events when finish_reason is tool_calls", async () => {
+			// Import NativeToolCallParser to set up state
+			const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
+
+			// Clear any previous state
+			NativeToolCallParser.clearRawChunkState()
+
+			const handler = new OpenRouterHandler(mockOptions)
+
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						id: "test-id",
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_openrouter_test",
+											function: { name: "read_file", arguments: '{"path":"test.ts"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						id: "test-id",
+						choices: [
+							{
+								delta: {},
+								finish_reason: "tool_calls",
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			}
+
+			const mockCreate = vitest.fn().mockResolvedValue(mockStream)
+			;(OpenAI as any).prototype.chat = {
+				completions: { create: mockCreate },
+			} as any
+
+			const generator = handler.createMessage("test", [])
+			const chunks = []
+
+			for await (const chunk of generator) {
+				// Simulate what Task.ts does: when we receive tool_call_partial,
+				// process it through NativeToolCallParser to populate rawChunkTracker
+				if (chunk.type === "tool_call_partial") {
+					NativeToolCallParser.processRawChunk({
+						index: chunk.index,
+						id: chunk.id,
+						name: chunk.name,
+						arguments: chunk.arguments,
+					})
+				}
+				chunks.push(chunk)
+			}
+
+			// Should have tool_call_partial and tool_call_end
+			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+
+			expect(partialChunks).toHaveLength(1)
+			expect(endChunks).toHaveLength(1)
+			expect(endChunks[0].id).toBe("call_openrouter_test")
+		})
 	})
 
 	describe("completePrompt", () => {

+ 63 - 0
src/api/providers/__tests__/roo.spec.ts

@@ -1002,5 +1002,68 @@ describe("RooHandler", () => {
 			const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
 			expect(rawChunks).toHaveLength(0)
 		})
+
+		it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
+			// Import NativeToolCallParser to set up state
+			const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
+
+			// Clear any previous state
+			NativeToolCallParser.clearRawChunkState()
+
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_finish_test",
+											function: { name: "read_file", arguments: '{"path":"test.ts"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {},
+								finish_reason: "tool_calls",
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				// Simulate what Task.ts does: when we receive tool_call_partial,
+				// process it through NativeToolCallParser to populate rawChunkTracker
+				if (chunk.type === "tool_call_partial") {
+					NativeToolCallParser.processRawChunk({
+						index: chunk.index,
+						id: chunk.id,
+						name: chunk.name,
+						arguments: chunk.arguments,
+					})
+				}
+				chunks.push(chunk)
+			}
+
+			// Should have tool_call_partial and tool_call_end
+			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+
+			expect(partialChunks).toHaveLength(1)
+			expect(endChunks).toHaveLength(1)
+			expect(endChunks[0].id).toBe("call_finish_test")
+		})
 	})
 })

+ 11 - 0
src/api/providers/openrouter.ts

@@ -9,6 +9,8 @@ import {
 	DEEP_SEEK_DEFAULT_TEMPERATURE,
 } from "@roo-code/types"
 
+import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
+
 import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
 
 import { convertToOpenAiMessages } from "../transform/openai-format"
@@ -341,6 +343,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 				}
 			}
 
+			// Process finish_reason to emit tool_call_end events
+			// This ensures tool calls are finalized even if the stream doesn't properly close
+			if (finishReason) {
+				const endEvents = NativeToolCallParser.processFinishReason(finishReason)
+				for (const event of endEvents) {
+					yield event
+				}
+			}
+
 			if (chunk.usage) {
 				lastUsage = chunk.usage
 			}

+ 9 - 0
src/api/providers/roo.ts

@@ -2,6 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
 import { rooDefaultModelId, getApiProtocol, type ImageGenerationApiMethod } from "@roo-code/types"
+import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
 import { CloudService } from "@roo-code/cloud"
 
 import { Package } from "../../shared/package"
@@ -157,6 +158,7 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 
 			for await (const chunk of stream) {
 				const delta = chunk.choices[0]?.delta
+				const finishReason = chunk.choices[0]?.finish_reason
 
 				if (delta) {
 					// Handle reasoning_details array format (used by Gemini 3, Claude, OpenAI o-series, etc.)
@@ -258,6 +260,13 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 					}
 				}
 
+				if (finishReason) {
+					const endEvents = NativeToolCallParser.processFinishReason(finishReason)
+					for (const event of endEvents) {
+						yield event
+					}
+				}
+
 				if (chunk.usage) {
 					lastUsage = chunk.usage as RooUsage
 				}