
Add native tool calling support to OpenAI-compatible (#9369)

* Add native tool calling support to OpenAI-compatible

* Fix OpenAI strict mode schema validation by adding converter methods to BaseProvider

- Add convertToolsForOpenAI() and convertToolSchemaForOpenAI() methods to BaseProvider
- These methods ensure all properties appear in the required array and convert nullable types to non-nullable ones
- Remove line_ranges from the required array in the read_file tool (the converter now handles it)
- Update OpenAiHandler and BaseOpenAiCompatibleProvider to use the helper methods
- Eliminate code duplication across multiple tool usage sites
- Fixes: OpenAI completion error: 400 Invalid schema for function 'read_file'
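
To illustrate the converter's effect, here is a minimal before/after sketch; the schema below is a hypothetical stand-in, not the actual read_file schema:

```typescript
// Hypothetical input schema (stand-in, not the real read_file definition)
const before = {
	type: "object",
	properties: {
		path: { type: "string" },
		line_ranges: { type: ["array", "null"], items: { type: "string" } },
	},
	required: ["path"], // strict mode rejects this: line_ranges is not listed
}

// Expected output of convertToolSchemaForOpenAI(before)
const after = {
	type: "object",
	properties: {
		path: { type: "string" },
		line_ranges: { type: "array", items: { type: "string" } }, // "null" stripped
	},
	required: ["path", "line_ranges"], // every property listed, as strict mode requires
}
```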

---------

Co-authored-by: daniel-lxs <[email protected]>
Matt Rubens 1 month ago
parent
commit
bc6fad1f9d

+ 1 - 0
packages/types/src/providers/openai.ts

@@ -436,6 +436,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 	supportsPromptCache: false,
 	inputPrice: 0,
 	outputPrice: 0,
+	supportsNativeTools: true,
 }
 
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation

+ 210 - 0
src/api/providers/__tests__/openai.spec.ts

@@ -157,6 +157,55 @@ describe("OpenAiHandler", () => {
 			expect(usageChunk?.outputTokens).toBe(5)
 		})
 
+		it("should handle tool calls in non-streaming mode", async () => {
+			mockCreate.mockResolvedValueOnce({
+				choices: [
+					{
+						message: {
+							role: "assistant",
+							content: null,
+							tool_calls: [
+								{
+									id: "call_1",
+									type: "function",
+									function: {
+										name: "test_tool",
+										arguments: '{"arg":"value"}',
+									},
+								},
+							],
+						},
+						finish_reason: "tool_calls",
+					},
+				],
+				usage: {
+					prompt_tokens: 10,
+					completion_tokens: 5,
+					total_tokens: 15,
+				},
+			})
+
+			const handler = new OpenAiHandler({
+				...mockOptions,
+				openAiStreamingEnabled: false,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_1",
+				name: "test_tool",
+				arguments: '{"arg":"value"}',
+			})
+		})
+
 		it("should handle streaming responses", async () => {
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
@@ -170,6 +219,66 @@ describe("OpenAiHandler", () => {
 			expect(textChunks[0].text).toBe("Test response")
 		})
 
+		it("should handle tool calls in streaming responses", async () => {
+			mockCreate.mockImplementation(async (options) => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												id: "call_1",
+												function: { name: "test_tool", arguments: "" },
+											},
+										],
+									},
+									finish_reason: null,
+								},
+							],
+						}
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }],
+									},
+									finish_reason: null,
+								},
+							],
+						}
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [{ index: 0, function: { arguments: '"value"}' } }],
+									},
+									finish_reason: "tool_calls",
+								},
+							],
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_1",
+				name: "test_tool",
+				arguments: '{"arg":"value"}',
+			})
+		})
+
 		it("should include reasoning_effort when reasoning effort is enabled", async () => {
 			const reasoningOptions: ApiHandlerOptions = {
 				...mockOptions,
@@ -618,6 +727,58 @@ describe("OpenAiHandler", () => {
 			)
 		})
 
+		it("should handle tool calls with O3 model in streaming mode", async () => {
+			const o3Handler = new OpenAiHandler(o3Options)
+
+			mockCreate.mockImplementation(async (options) => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												id: "call_1",
+												function: { name: "test_tool", arguments: "" },
+											},
+										],
+									},
+									finish_reason: null,
+								},
+							],
+						}
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [{ index: 0, function: { arguments: "{}" } }],
+									},
+									finish_reason: "tool_calls",
+								},
+							],
+						}
+					},
+				}
+			})
+
+			const stream = o3Handler.createMessage("system", [])
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_1",
+				name: "test_tool",
+				arguments: "{}",
+			})
+		})
+
 		it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
 			const o3Handler = new OpenAiHandler({
 				...o3Options,
@@ -705,6 +866,55 @@ describe("OpenAiHandler", () => {
 			expect(callArgs).not.toHaveProperty("stream")
 		})
 
+		it("should handle tool calls with O3 model in non-streaming mode", async () => {
+			const o3Handler = new OpenAiHandler({
+				...o3Options,
+				openAiStreamingEnabled: false,
+			})
+
+			mockCreate.mockResolvedValueOnce({
+				choices: [
+					{
+						message: {
+							role: "assistant",
+							content: null,
+							tool_calls: [
+								{
+									id: "call_1",
+									type: "function",
+									function: {
+										name: "test_tool",
+										arguments: "{}",
+									},
+								},
+							],
+						},
+						finish_reason: "tool_calls",
+					},
+				],
+				usage: {
+					prompt_tokens: 10,
+					completion_tokens: 5,
+					total_tokens: 15,
+				},
+			})
+
+			const stream = o3Handler.createMessage("system", [])
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+			expect(toolCallChunks).toHaveLength(1)
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call",
+				id: "call_1",
+				name: "test_tool",
+				arguments: "{}",
+			})
+		})
+
 		it("should use default temperature of 0 when not specified for O3 models", async () => {
 			const o3Handler = new OpenAiHandler({
 				...o3Options,

+ 36 - 0
src/api/providers/base-openai-compatible-provider.ts

@@ -90,6 +90,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		try {
@@ -115,6 +117,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}) as const,
 		)
 
+		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
+
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const chunkAny = chunk as any
@@ -125,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			}
 
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta?.content) {
 				for (const processedChunk of matcher.update(delta.content)) {
@@ -139,6 +144,37 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					const index = toolCall.index
+					const existing = toolCallAccumulator.get(index)
+
+					if (existing) {
+						if (toolCall.function?.arguments) {
+							existing.arguments += toolCall.function.arguments
+						}
+					} else {
+						toolCallAccumulator.set(index, {
+							id: toolCall.id || "",
+							name: toolCall.function?.name || "",
+							arguments: toolCall.function?.arguments || "",
+						})
+					}
+				}
+			}
+
+			if (finishReason === "tool_calls") {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
+				}
+				toolCallAccumulator.clear()
+			}
+
 			if (chunk.usage) {
 				yield {
 					type: "usage",
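
The Map-based accumulator above exists because OpenAI-compatible servers stream a tool call's arguments in fragments, keyed by index, and only the first fragment carries the id and name. A self-contained sketch of the same reassembly logic, with types abbreviated from the OpenAI chunk shape:

```typescript
// Minimal sketch of the delta-accumulation pattern; names are illustrative.
type ToolCallDelta = {
	index: number
	id?: string
	function?: { name?: string; arguments?: string }
}

function accumulateToolCalls(chunks: ToolCallDelta[][]): { id: string; name: string; arguments: string }[] {
	const acc = new Map<number, { id: string; name: string; arguments: string }>()
	for (const deltas of chunks) {
		for (const tc of deltas) {
			const existing = acc.get(tc.index)
			if (existing) {
				// Later deltas only carry argument fragments; append them.
				if (tc.function?.arguments) existing.arguments += tc.function.arguments
			} else {
				// The first delta for this index carries the id and name.
				acc.set(tc.index, {
					id: tc.id || "",
					name: tc.function?.name || "",
					arguments: tc.function?.arguments || "",
				})
			}
		}
	}
	return [...acc.values()]
}

// Mirrors the streamed fragments exercised in the tests above:
console.log(
	accumulateToolCalls([
		[{ index: 0, id: "call_1", function: { name: "test_tool", arguments: "" } }],
		[{ index: 0, function: { arguments: '{"arg":' } }],
		[{ index: 0, function: { arguments: '"value"}' } }],
	]),
)
// => [{ id: "call_1", name: "test_tool", arguments: '{"arg":"value"}' }]
```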

+ 69 - 0
src/api/providers/base-provider.ts

@@ -18,6 +18,75 @@ export abstract class BaseProvider implements ApiHandler {
 
 	abstract getModel(): { id: string; info: ModelInfo }
 
+	/**
+	 * Converts an array of tools to be compatible with OpenAI's strict mode.
+	 * Filters for function tools and applies schema conversion to their parameters.
+	 */
+	protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
+		if (!tools) {
+			return undefined
+		}
+
+		return tools.map((tool) =>
+			tool.type === "function"
+				? {
+						...tool,
+						function: {
+							...tool.function,
+							parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
+						},
+					}
+				: tool,
+		)
+	}
+
+	/**
+	 * Converts tool schemas to be compatible with OpenAI's strict mode by:
+	 * - Ensuring all properties are in the required array (strict mode requirement)
+	 * - Converting nullable types (["type", "null"]) to non-nullable ("type")
+	 * - Recursively processing nested objects and arrays
+	 *
+	 * This matches the behavior of ensureAllRequired in openai-native.ts
+	 */
+	protected convertToolSchemaForOpenAI(schema: any): any {
+		if (!schema || typeof schema !== "object" || schema.type !== "object") {
+			return schema
+		}
+
+		const result = { ...schema }
+
+		if (result.properties) {
+			const allKeys = Object.keys(result.properties)
+			// OpenAI strict mode requires ALL properties to be in required array
+			result.required = allKeys
+
+			// Recursively process nested objects and convert nullable types
+			const newProps = { ...result.properties }
+			for (const key of allKeys) {
+				const prop = newProps[key]
+
+				// Handle nullable types by removing null
+				if (prop && Array.isArray(prop.type) && prop.type.includes("null")) {
+					const nonNullTypes = prop.type.filter((t: string) => t !== "null")
+					prop.type = nonNullTypes.length === 1 ? nonNullTypes[0] : nonNullTypes
+				}
+
+				// Recursively process nested objects
+				if (prop && prop.type === "object") {
+					newProps[key] = this.convertToolSchemaForOpenAI(prop)
+				} else if (prop && prop.type === "array" && prop.items?.type === "object") {
+					newProps[key] = {
+						...prop,
+						items: this.convertToolSchemaForOpenAI(prop.items),
+					}
+				}
+			}
+			result.properties = newProps
+		}
+
+		return result
+	}
+
 	/**
 	 * Default token counting implementation using tiktoken.
 	 * Providers can override this to use their native token counting endpoints.
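
A brief usage sketch of the recursive conversion; the nested schema is invented for illustration and would be called from within a BaseProvider subclass:

```typescript
// Hypothetical schema, showing recursion into a nested object property.
const converted = this.convertToolSchemaForOpenAI({
	type: "object",
	properties: {
		filters: {
			type: "object",
			properties: { tag: { type: ["string", "null"] } },
			required: [],
		},
	},
	required: [],
})
// converted.required                          => ["filters"]
// converted.properties.filters.required       => ["tag"]
// converted.properties.filters.properties.tag => { type: "string" }  (null removed)
```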

+ 116 - 7
src/api/providers/openai.ts

@@ -95,7 +95,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const ark = modelUrl.includes(".volces.com")
 
 		if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
 			return
 		}
 
@@ -164,6 +164,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				stream: true as const,
 				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 				...(reasoning && reasoning),
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// Add max_tokens if needed
@@ -189,9 +191,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			)
 
 			let lastUsage
+			const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
 
 			for await (const chunk of stream) {
 				const delta = chunk.choices?.[0]?.delta ?? {}
+				const finishReason = chunk.choices?.[0]?.finish_reason
 
 				if (delta.content) {
 					for (const chunk of matcher.update(delta.content)) {
@@ -205,6 +209,38 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 						text: (delta.reasoning_content as string | undefined) || "",
 					}
 				}
+
+				if (delta.tool_calls) {
+					for (const toolCall of delta.tool_calls) {
+						const index = toolCall.index
+						const existing = toolCallAccumulator.get(index)
+
+						if (existing) {
+							if (toolCall.function?.arguments) {
+								existing.arguments += toolCall.function.arguments
+							}
+						} else {
+							toolCallAccumulator.set(index, {
+								id: toolCall.id || "",
+								name: toolCall.function?.name || "",
+								arguments: toolCall.function?.arguments || "",
+							})
+						}
+					}
+				}
+
+				if (finishReason === "tool_calls") {
+					for (const toolCall of toolCallAccumulator.values()) {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.name,
+							arguments: toolCall.arguments,
+						}
+					}
+					toolCallAccumulator.clear()
+				}
+
 				if (chunk.usage) {
 					lastUsage = chunk.usage
 				}
@@ -225,6 +261,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					: enabledLegacyFormat
 						? [systemMessage, ...convertToSimpleMessages(messages)]
 						: [systemMessage, ...convertToOpenAiMessages(messages)],
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// Add max_tokens if needed
@@ -240,9 +278,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				throw handleOpenAIError(error, this.providerName)
 			}
 
+			const message = response.choices?.[0]?.message
+
+			if (message?.tool_calls) {
+				for (const toolCall of message.tool_calls) {
+					if (toolCall.type === "function") {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.function.name,
+							arguments: toolCall.function.arguments,
+						}
+					}
+				}
+			}
+
 			yield {
 				type: "text",
-				text: response.choices?.[0]?.message.content || "",
+				text: message?.content || "",
 			}
 
 			yield this.processUsageMetrics(response.usage, modelInfo)
@@ -304,6 +357,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		modelId: string,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
 		const modelInfo = this.getModel().info
 		const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
@@ -324,6 +378,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 				reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
 				temperature: undefined,
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// O3 family models do not support the deprecated max_tokens parameter
@@ -354,6 +410,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 				reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
 				temperature: undefined,
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// O3 family models do not support the deprecated max_tokens parameter
@@ -371,22 +429,73 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				throw handleOpenAIError(error, this.providerName)
 			}
 
+			const message = response.choices?.[0]?.message
+			if (message?.tool_calls) {
+				for (const toolCall of message.tool_calls) {
+					if (toolCall.type === "function") {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.function.name,
+							arguments: toolCall.function.arguments,
+						}
+					}
+				}
+			}
+
 			yield {
 				type: "text",
-				text: response.choices?.[0]?.message.content || "",
+				text: message?.content || "",
 			}
 			yield this.processUsageMetrics(response.usage)
 		}
 	}
 
 	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
+
 		for await (const chunk of stream) {
 			const delta = chunk.choices?.[0]?.delta
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+			const finishReason = chunk.choices?.[0]?.finish_reason
+
+			if (delta) {
+				if (delta.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+
+				if (delta.tool_calls) {
+					for (const toolCall of delta.tool_calls) {
+						const index = toolCall.index
+						const existing = toolCallAccumulator.get(index)
+
+						if (existing) {
+							if (toolCall.function?.arguments) {
+								existing.arguments += toolCall.function.arguments
+							}
+						} else {
+							toolCallAccumulator.set(index, {
+								id: toolCall.id || "",
+								name: toolCall.function?.name || "",
+								arguments: toolCall.function?.arguments || "",
+							})
+						}
+					}
+				}
+			}
+
+			if (finishReason === "tool_calls") {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
 				}
+				toolCallAccumulator.clear()
 			}
 
 			if (chunk.usage) {
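
Taken together, a caller-side sketch of the new flow. This assumes createMessage accepts the metadata object as a third argument, as the handleO3FamilyMessage change suggests, and uses an invented tool definition:

```typescript
// Hypothetical call site; handler, systemPrompt, and messages come from elsewhere.
const stream = handler.createMessage(systemPrompt, messages, {
	tools: [
		{
			type: "function",
			function: {
				name: "test_tool", // invented tool, mirroring the tests
				description: "Example tool for illustration",
				parameters: { type: "object", properties: { arg: { type: "string" } }, required: [] },
			},
		},
	],
	tool_choice: "auto",
})

for await (const chunk of stream) {
	if (chunk.type === "tool_call") {
		// `arguments` is the fully accumulated JSON string, so it parses in one go.
		const args = JSON.parse(chunk.arguments)
		console.log(chunk.id, chunk.name, args)
	}
}
```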

+ 2 - 0
src/api/transform/model-params.ts

@@ -42,6 +42,7 @@ type BaseModelParams = {
 	reasoningEffort: ReasoningEffortExtended | undefined
 	reasoningBudget: number | undefined
 	verbosity: VerbosityLevel | undefined
+	tools?: boolean
 }
 
 type AnthropicModelParams = {
@@ -160,6 +161,7 @@ export function getModelParams({
 			format,
 			...params,
 			reasoning: getOpenAiReasoning({ model, reasoningBudget, reasoningEffort, settings }),
+			tools: model.supportsNativeTools,
 		}
 	} else if (format === "gemini") {
 		return {
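
Finally, a minimal sketch of what the model-params change surfaces; the invocation is abbreviated and hypothetical, since the real getModelParams takes more options than shown:

```typescript
// Abbreviated, hypothetical invocation; only diff-relevant fields are shown.
const model = { ...openAiModelInfoSaneDefaults } // now carries supportsNativeTools: true
const params = getModelParams({ format: "openai", model /* ...other options */ } as any)
// params.tools === true, signaling that native tool calling is available for this model.
```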