Просмотр исходного кода

fix: add fallback to yield tool calls regardless of finish_reason (#9476)

Daniel 1 месяц назад
Родитель
Сommit
a65e7807b1

+ 102 - 0
src/api/providers/__tests__/openai.spec.ts

@@ -279,6 +279,56 @@ describe("OpenAiHandler", () => {
 			})
 		})
 
+		it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => {
+			mockCreate.mockImplementation(async (options) => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												id: "call_fallback",
+												function: { name: "fallback_tool", arguments: '{"test":"fallback"}' },
+											},
+										],
+									},
+									finish_reason: null,
+								},
+							],
+						}
+						// Stream ends without finish_reason being set to "tool_calls"
+						yield {
+							choices: [
+								{
+									delta: {},
+									finish_reason: "stop", // Different finish reason
+								},
+							],
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Tool calls should still be yielded via the fallback mechanism
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_fallback",
+				name: "fallback_tool",
+				arguments: '{"test":"fallback"}',
+			})
+		})
+
 		it("should include reasoning_effort when reasoning effort is enabled", async () => {
 			const reasoningOptions: ApiHandlerOptions = {
 				...mockOptions,
@@ -779,6 +829,58 @@ describe("OpenAiHandler", () => {
 			})
 		})
 
+		it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => {
+			const o3Handler = new OpenAiHandler(o3Options)
+
+			mockCreate.mockImplementation(async (options) => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												id: "call_o3_fallback",
+												function: { name: "o3_fallback_tool", arguments: '{"o3":"test"}' },
+											},
+										],
+									},
+									finish_reason: null,
+								},
+							],
+						}
+						// Stream ends with different finish reason
+						yield {
+							choices: [
+								{
+									delta: {},
+									finish_reason: "length", // Different finish reason
+								},
+							],
+						}
+					},
+				}
+			})
+
+			const stream = o3Handler.createMessage("system", [])
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Tool calls should still be yielded via the fallback mechanism
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_o3_fallback",
+				name: "o3_fallback_tool",
+				arguments: '{"o3":"test"}',
+			})
+		})
+
 		it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
 			const o3Handler = new OpenAiHandler({
 				...o3Options,

+ 276 - 0
src/api/providers/__tests__/roo.spec.ts

@@ -630,4 +630,280 @@ describe("RooHandler", () => {
 			)
 		})
 	})
+
+	describe("tool calls handling", () => {
+		beforeEach(() => {
+			handler = new RooHandler(mockOptions)
+		})
+
+		it("should yield tool calls when finish_reason is tool_calls", async () => {
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_123",
+											function: { name: "read_file", arguments: '{"path":"' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											function: { arguments: 'test.ts"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {},
+								finish_reason: "tool_calls",
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0].id).toBe("call_123")
+			expect(toolCallChunks[0].name).toBe("read_file")
+			expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts"}')
+		})
+
+		it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => {
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_456",
+											function: {
+												name: "write_to_file",
+												arguments: '{"path":"test.ts","content":"hello"}',
+											},
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					// Stream ends without finish_reason being set to "tool_calls"
+					yield {
+						choices: [
+							{
+								delta: {},
+								finish_reason: "stop", // Different finish reason
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Tool calls should still be yielded via the fallback mechanism
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0].id).toBe("call_456")
+			expect(toolCallChunks[0].name).toBe("write_to_file")
+			expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts","content":"hello"}')
+		})
+
+		it("should handle multiple tool calls", async () => {
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_1",
+											function: { name: "read_file", arguments: '{"path":"file1.ts"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 1,
+											id: "call_2",
+											function: { name: "read_file", arguments: '{"path":"file2.ts"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {},
+								finish_reason: "tool_calls",
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(2)
+			expect(toolCallChunks[0].id).toBe("call_1")
+			expect(toolCallChunks[0].name).toBe("read_file")
+			expect(toolCallChunks[1].id).toBe("call_2")
+			expect(toolCallChunks[1].name).toBe("read_file")
+		})
+
+		it("should accumulate tool call arguments across multiple chunks", async () => {
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_789",
+											function: { name: "execute_command", arguments: '{"command":"' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											function: { arguments: "npm install" },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											function: { arguments: '"}' },
+										},
+									],
+								},
+								index: 0,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {},
+								finish_reason: "tool_calls",
+								index: 0,
+							},
+						],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0].id).toBe("call_789")
+			expect(toolCallChunks[0].name).toBe("execute_command")
+			expect(toolCallChunks[0].arguments).toBe('{"command":"npm install"}')
+		})
+
+		it("should not yield empty tool calls when no tool calls present", async () => {
+			mockCreate.mockResolvedValueOnce({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [{ delta: { content: "Regular text response" }, index: 0 }],
+					}
+					yield {
+						choices: [{ delta: {}, finish_reason: "stop", index: 0 }],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				},
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(0)
+		})
+	})
 })

+ 14 - 0
src/api/providers/base-openai-compatible-provider.ts

@@ -184,6 +184,20 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			}
 		}
 
+		// Fallback: If stream ends with accumulated tool calls that weren't yielded
+		// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
+		if (toolCallAccumulator.size > 0) {
+			for (const toolCall of toolCallAccumulator.values()) {
+				yield {
+					type: "tool_call",
+					id: toolCall.id,
+					name: toolCall.name,
+					arguments: toolCall.arguments,
+				}
+			}
+			toolCallAccumulator.clear()
+		}
+
 		// Process any remaining content
 		for (const processedChunk of matcher.final()) {
 			yield processedChunk

+ 28 - 0
src/api/providers/openai.ts

@@ -246,6 +246,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				}
 			}
 
+			// Fallback: If stream ends with accumulated tool calls that weren't yielded
+			// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
+			if (toolCallAccumulator.size > 0) {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
+				}
+				toolCallAccumulator.clear()
+			}
+
 			for (const chunk of matcher.final()) {
 				yield chunk
 			}
@@ -506,6 +520,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				}
 			}
 		}
+
+		// Fallback: If stream ends with accumulated tool calls that weren't yielded
+		// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
+		if (toolCallAccumulator.size > 0) {
+			for (const toolCall of toolCallAccumulator.values()) {
+				yield {
+					type: "tool_call",
+					id: toolCall.id,
+					name: toolCall.name,
+					arguments: toolCall.arguments,
+				}
+			}
+			toolCallAccumulator.clear()
+		}
 	}
 
 	private _getUrlHost(baseUrl?: string): string {

+ 14 - 0
src/api/providers/openrouter.ts

@@ -265,6 +265,20 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 			}
 		}
 
+		// Fallback: If stream ends with accumulated tool calls that weren't yielded
+		// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
+		if (toolCallAccumulator.size > 0) {
+			for (const toolCall of toolCallAccumulator.values()) {
+				yield {
+					type: "tool_call",
+					id: toolCall.id,
+					name: toolCall.name,
+					arguments: toolCall.arguments,
+				}
+			}
+			toolCallAccumulator.clear()
+		}
+
 		if (lastUsage) {
 			yield {
 				type: "usage",

+ 14 - 0
src/api/providers/roo.ts

@@ -199,6 +199,20 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 				}
 			}
 
+			// Fallback: If stream ends with accumulated tool calls that weren't yielded
+			// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
+			if (toolCallAccumulator.size > 0) {
+				for (const [index, toolCall] of toolCallAccumulator.entries()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
+				}
+				toolCallAccumulator.clear()
+			}
+
 			if (lastUsage) {
 				// Check if the current model is marked as free
 				const model = this.getModel()