
fix: emit tool_call_end events in BaseOpenAiCompatibleProvider (#10293)

Hannes Rudolph 6 days ago
Parent
Commit
a8ac2ced02

+ 162 - 0
src/api/providers/__tests__/base-openai-compatible-provider.spec.ts

@@ -383,4 +383,166 @@ describe("BaseOpenAiCompatibleProvider", () => {
 			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
 		})
 	})
+
+	describe("Tool call handling", () => {
+		it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: { name: "test_tool", arguments: '{"arg":' },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: { arguments: '"value"}' },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {},
+											finish_reason: "tool_calls",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should have tool_call_partial and tool_call_end
+			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+
+			expect(partialChunks).toHaveLength(2)
+			expect(endChunks).toHaveLength(1)
+			expect(endChunks[0]).toEqual({ type: "tool_call_end", id: "call_123" })
+		})
+
+		it("should yield multiple tool_call_end events for parallel tool calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_001",
+														function: { name: "tool_a", arguments: "{}" },
+													},
+													{
+														index: 1,
+														id: "call_002",
+														function: { name: "tool_b", arguments: "{}" },
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {},
+											finish_reason: "tool_calls",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(endChunks).toHaveLength(2)
+			expect(endChunks.map((c: any) => c.id).sort()).toEqual(["call_001", "call_002"])
+		})
+
+		it("should not yield tool_call_end when finish_reason is not tool_calls", async () => {
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: { content: "Some text response" },
+											finish_reason: "stop",
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handler.createMessage("system prompt", [])
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(endChunks).toHaveLength(0)
+		})
+	})
 })
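
Note on the mocking pattern: each test above hand-rolls an async iterator by chaining `mockResolvedValueOnce` calls on `next`. A minimal sketch of an equivalent helper (hypothetical, not part of this test suite) that wraps an array of streamed chunks in the async-iterable shape `mockCreate` is expected to return:

	// Hypothetical helper: turns an array of chunks into an object that
	// satisfies `for await`, mirroring what the mocks above build by hand.
	function chunksToAsyncIterable<T>(chunks: T[]): AsyncIterable<T> {
		return {
			[Symbol.asyncIterator]: () => {
				let i = 0
				return {
					// Each call resolves the next chunk, then signals completion.
					next: async (): Promise<IteratorResult<T>> =>
						i < chunks.length ? { done: false, value: chunks[i++] } : { done: true, value: undefined },
				}
			},
		}
	}

	// Usage sketch: the parallel-tool-call test could then read as
	// mockCreate.mockImplementationOnce(() => chunksToAsyncIterable([chunk1, chunk2]))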

+ 14 - 0
src/api/providers/base-openai-compatible-provider.ts

@@ -129,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 		)
 
 		let lastUsage: OpenAI.CompletionUsage | undefined
+		const activeToolCallIds = new Set<string>()
 
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
@@ -140,6 +141,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			}
 
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta?.content) {
 				for (const processedChunk of matcher.update(delta.content)) {
@@ -162,6 +164,9 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			// Emit raw tool call chunks - NativeToolCallParser handles state management
 			if (delta?.tool_calls) {
 				for (const toolCall of delta.tool_calls) {
+					if (toolCall.id) {
+						activeToolCallIds.add(toolCall.id)
+					}
 					yield {
 						type: "tool_call_partial",
 						index: toolCall.index,
@@ -172,6 +177,15 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 
+			// Emit tool_call_end events when finish_reason is "tool_calls"
+			// This ensures tool calls are finalized even if the stream doesn't properly close
+			if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
+				for (const id of activeToolCallIds) {
+					yield { type: "tool_call_end", id }
+				}
+				activeToolCallIds.clear()
+			}
+
 			if (chunk.usage) {
 				lastUsage = chunk.usage
 			}
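
For context, a hypothetical consumer sketch of the emitted chunk sequence (the real accumulation lives in NativeToolCallParser, per the comment above; the fields on `tool_call_partial` are assumed from the delta shapes used in the tests):

	// Hypothetical consumer: folds tool_call_partial chunks into complete
	// tool calls and treats tool_call_end as the finalization signal.
	type ToolCallChunk =
		| { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string }
		| { type: "tool_call_end"; id: string }

	async function collectToolCalls(stream: AsyncIterable<ToolCallChunk>) {
		// Keyed by index: continuation chunks (like the second delta in the
		// first test) omit the id and are matched to their call by index.
		const byIndex = new Map<number, { id?: string; name?: string; args: string }>()
		const completed: Array<{ id: string; name?: string; args: string }> = []

		for await (const chunk of stream) {
			if (chunk.type === "tool_call_partial") {
				const entry = byIndex.get(chunk.index) ?? { args: "" }
				entry.id ??= chunk.id
				entry.name ??= chunk.name
				entry.args += chunk.arguments ?? ""
				byIndex.set(chunk.index, entry)
			} else if (chunk.type === "tool_call_end") {
				// One tool_call_end arrives per active call once
				// finish_reason === "tool_calls" is seen.
				for (const [index, entry] of byIndex) {
					if (entry.id === chunk.id) {
						completed.push({ id: chunk.id, name: entry.name, args: entry.args })
						byIndex.delete(index)
					}
				}
			}
		}
		return completed
	}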