Просмотр исходного кода

fix: emit tool_call_end events in OpenAI handler when streaming ends (#10280)

Daniel 1 неделя назад
Родитель
Сommit
9b06a98b85
2 измененных файлов с 55 добавлено и 23 удалено
  1. 8 0
      src/api/providers/__tests__/openai.spec.ts
  2. 47 23
      src/api/providers/openai.ts

+ 8 - 0
src/api/providers/__tests__/openai.spec.ts

@@ -295,6 +295,10 @@ describe("OpenAiHandler", () => {
 				name: undefined,
 				arguments: '"value"}',
 			})
+
+			// Verify tool_call_end event is emitted when finish_reason is "tool_calls"
+			const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(toolCallEndChunks).toHaveLength(1)
 		})
 
 		it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => {
@@ -855,6 +859,10 @@ describe("OpenAiHandler", () => {
 				name: undefined,
 				arguments: "{}",
 			})
+
+			// Verify tool_call_end event is emitted when finish_reason is "tool_calls"
+			const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
+			expect(toolCallEndChunks).toHaveLength(1)
 		})
 
 		it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => {

+ 47 - 23
src/api/providers/openai.ts

@@ -194,9 +194,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			)
 
 			let lastUsage
+			const activeToolCallIds = new Set<string>()
 
 			for await (const chunk of stream) {
 				const delta = chunk.choices?.[0]?.delta ?? {}
+				const finishReason = chunk.choices?.[0]?.finish_reason
 
 				if (delta.content) {
 					for (const chunk of matcher.update(delta.content)) {
@@ -211,17 +213,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					}
 				}
 
-				if (delta.tool_calls) {
-					for (const toolCall of delta.tool_calls) {
-						yield {
-							type: "tool_call_partial",
-							index: toolCall.index,
-							id: toolCall.id,
-							name: toolCall.function?.name,
-							arguments: toolCall.function?.arguments,
-						}
-					}
-				}
+				yield* this.processToolCalls(delta, finishReason, activeToolCallIds)
 
 				if (chunk.usage) {
 					lastUsage = chunk.usage
@@ -443,8 +435,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 	}
 
 	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+		const activeToolCallIds = new Set<string>()
+
 		for await (const chunk of stream) {
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta) {
 				if (delta.content) {
@@ -454,18 +449,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					}
 				}
 
-				// Emit raw tool call chunks - NativeToolCallParser handles state management
-				if (delta.tool_calls) {
-					for (const toolCall of delta.tool_calls) {
-						yield {
-							type: "tool_call_partial",
-							index: toolCall.index,
-							id: toolCall.id,
-							name: toolCall.function?.name,
-							arguments: toolCall.function?.arguments,
-						}
-					}
-				}
+				yield* this.processToolCalls(delta, finishReason, activeToolCallIds)
 			}
 
 			if (chunk.usage) {
@@ -478,6 +462,46 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}
 	}
 
+	/**
+	 * Helper generator to process tool calls from a stream chunk.
+	 * Tracks active tool call IDs and yields tool_call_partial and tool_call_end events.
+	 * @param delta - The delta object from the stream chunk
+	 * @param finishReason - The finish_reason from the stream chunk
+	 * @param activeToolCallIds - Set to track active tool call IDs (mutated in place)
+	 */
+	private *processToolCalls(
+		delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined,
+		finishReason: string | null | undefined,
+		activeToolCallIds: Set<string>,
+	): Generator<
+		| { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string }
+		| { type: "tool_call_end"; id: string }
+	> {
+		if (delta?.tool_calls) {
+			for (const toolCall of delta.tool_calls) {
+				if (toolCall.id) {
+					activeToolCallIds.add(toolCall.id)
+				}
+				yield {
+					type: "tool_call_partial",
+					index: toolCall.index,
+					id: toolCall.id,
+					name: toolCall.function?.name,
+					arguments: toolCall.function?.arguments,
+				}
+			}
+		}
+
+		// Emit tool_call_end events when finish_reason is "tool_calls"
+		// This ensures tool calls are finalized even if the stream doesn't properly close
+		if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
+			for (const id of activeToolCallIds) {
+				yield { type: "tool_call_end", id }
+			}
+			activeToolCallIds.clear()
+		}
+	}
+
 	protected _getUrlHost(baseUrl?: string): string {
 		try {
 			return new URL(baseUrl ?? "").host