
feat: enable native tool calling for openai-native provider (#9348)

Co-authored-by: daniel-lxs <[email protected]>
Hannes Rudolph 1 month ago
commit f8d6e12aa7
2 changed files with 240 additions and 36 deletions
  1. packages/types/src/providers/openai.ts (+29 -0)
  2. src/api/providers/openai-native.ts (+211 -36)

packages/types/src/providers/openai.ts (+29 -0)

@@ -9,6 +9,7 @@ export const openAiNativeModels = {
 	"gpt-5.1": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		promptCacheRetention: "24h",
@@ -28,6 +29,7 @@ export const openAiNativeModels = {
 	"gpt-5.1-codex": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		promptCacheRetention: "24h",
@@ -43,6 +45,7 @@ export const openAiNativeModels = {
 	"gpt-5.1-codex-mini": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		promptCacheRetention: "24h",
@@ -57,6 +60,7 @@ export const openAiNativeModels = {
 	"gpt-5": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],
@@ -75,6 +79,7 @@ export const openAiNativeModels = {
 	"gpt-5-mini": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],
@@ -93,6 +98,7 @@ export const openAiNativeModels = {
 	"gpt-5-codex": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["low", "medium", "high"],
@@ -107,6 +113,7 @@ export const openAiNativeModels = {
 	"gpt-5-nano": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],
@@ -122,6 +129,7 @@ export const openAiNativeModels = {
 	"gpt-5-chat-latest": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 1.25,
@@ -132,6 +140,7 @@ export const openAiNativeModels = {
 	"gpt-4.1": {
 		maxTokens: 32_768,
 		contextWindow: 1_047_576,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 2,
@@ -145,6 +154,7 @@ export const openAiNativeModels = {
 	"gpt-4.1-mini": {
 		maxTokens: 32_768,
 		contextWindow: 1_047_576,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 0.4,
@@ -158,6 +168,7 @@ export const openAiNativeModels = {
 	"gpt-4.1-nano": {
 		maxTokens: 32_768,
 		contextWindow: 1_047_576,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 0.1,
@@ -171,6 +182,7 @@ export const openAiNativeModels = {
 	o3: {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 2.0,
@@ -187,6 +199,7 @@ export const openAiNativeModels = {
 	"o3-high": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 2.0,
@@ -198,6 +211,7 @@ export const openAiNativeModels = {
 	"o3-low": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 2.0,
@@ -209,6 +223,7 @@ export const openAiNativeModels = {
 	"o4-mini": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -225,6 +240,7 @@ export const openAiNativeModels = {
 	"o4-mini-high": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -236,6 +252,7 @@ export const openAiNativeModels = {
 	"o4-mini-low": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -247,6 +264,7 @@ export const openAiNativeModels = {
 	"o3-mini": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: false,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -259,6 +277,7 @@ export const openAiNativeModels = {
 	"o3-mini-high": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: false,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -270,6 +289,7 @@ export const openAiNativeModels = {
 	"o3-mini-low": {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: false,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -281,6 +301,7 @@ export const openAiNativeModels = {
 	o1: {
 		maxTokens: 100_000,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 15,
@@ -291,6 +312,7 @@ export const openAiNativeModels = {
 	"o1-preview": {
 		maxTokens: 32_768,
 		contextWindow: 128_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 15,
@@ -301,6 +323,7 @@ export const openAiNativeModels = {
 	"o1-mini": {
 		maxTokens: 65_536,
 		contextWindow: 128_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 1.1,
@@ -311,6 +334,7 @@ export const openAiNativeModels = {
 	"gpt-4o": {
 		maxTokens: 16_384,
 		contextWindow: 128_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 2.5,
@@ -324,6 +348,7 @@ export const openAiNativeModels = {
 	"gpt-4o-mini": {
 		maxTokens: 16_384,
 		contextWindow: 128_000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		inputPrice: 0.15,
@@ -337,6 +362,7 @@ export const openAiNativeModels = {
 	"codex-mini-latest": {
 		maxTokens: 16_384,
 		contextWindow: 200_000,
+		supportsNativeTools: true,
 		supportsImages: false,
 		supportsPromptCache: false,
 		inputPrice: 1.5,
@@ -350,6 +376,7 @@ export const openAiNativeModels = {
 	"gpt-5-2025-08-07": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],
@@ -368,6 +395,7 @@ export const openAiNativeModels = {
 	"gpt-5-mini-2025-08-07": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],
@@ -386,6 +414,7 @@ export const openAiNativeModels = {
 	"gpt-5-nano-2025-08-07": {
 		maxTokens: 128000,
 		contextWindow: 400000,
+		supportsNativeTools: true,
 		supportsImages: true,
 		supportsPromptCache: true,
 		supportsReasoningEffort: ["minimal", "low", "medium", "high"],

src/api/providers/openai-native.ts (+211 -36)

@@ -34,6 +34,8 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	private lastResponseOutput: any[] | undefined
 	// Last top-level response id from Responses API (for troubleshooting)
 	private lastResponseId: string | undefined
+	// Accumulate partial tool calls: call_id -> { name, arguments }
+	private currentToolCalls: Map<string, { name: string; arguments: string }> = new Map()
 	// Abort controller for cancelling ongoing requests
 	private abortController?: AbortController
 
@@ -49,6 +51,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		"response.output_item.added",
 		"response.done",
 		"response.completed",
+		"response.tool_call_arguments.delta",
+		"response.function_call_arguments.delta",
+		"response.tool_call_arguments.done",
+		"response.function_call_arguments.done",
 	])
 
 	constructor(options: ApiHandlerOptions) {
@@ -147,6 +153,8 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		this.lastResponseOutput = undefined
 		// Reset last response id for this request
 		this.lastResponseId = undefined
+		// Reset tool call accumulator
+		this.currentToolCalls.clear()
 
 		// Use Responses API for ALL models
 		const { verbosity, reasoning } = this.getModel()
@@ -179,6 +187,38 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		reasoningEffort: ReasoningEffortExtended | undefined,
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): any {
+		// Ensure all properties are in the required array for OpenAI's strict mode
+		// This recursively processes nested objects and array items
+		const ensureAllRequired = (schema: any): any => {
+			if (!schema || typeof schema !== "object" || schema.type !== "object") {
+				return schema
+			}
+
+			const result = { ...schema }
+
+			if (result.properties) {
+				const allKeys = Object.keys(result.properties)
+				result.required = allKeys
+
+				// Recursively process nested objects
+				const newProps = { ...result.properties }
+				for (const key of allKeys) {
+					const prop = newProps[key]
+					if (prop.type === "object") {
+						newProps[key] = ensureAllRequired(prop)
+					} else if (prop.type === "array" && prop.items?.type === "object") {
+						newProps[key] = {
+							...prop,
+							items: ensureAllRequired(prop.items),
+						}
+					}
+				}
+				result.properties = newProps
+			}
+
+			return result
+		}
+
 		// Build a request body for the OpenAI Responses API.
 		// Ensure we explicitly pass max_output_tokens based on Roo's reserved model response calculation
 		// so requests do not default to very large limits (e.g., 120k).
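
For reference, a worked example of what `ensureAllRequired` does. OpenAI's strict mode rejects a function schema unless every key in `properties` also appears in `required`, at every nesting level, so the helper fills `required` recursively, including inside array item schemas. The schema below is made up for illustration:

	// Input schema (illustrative):
	const schema = {
		type: "object",
		properties: {
			path: { type: "string" },
			edits: {
				type: "array",
				items: {
					type: "object",
					properties: {
						search: { type: "string" },
						replace: { type: "string" },
					},
				},
			},
		},
	}

	// ensureAllRequired(schema) returns the same shape plus:
	//   required: ["path", "edits"]            at the top level
	//   items.required: ["search", "replace"]  inside the array item schema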
@@ -196,6 +236,14 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			include?: string[]
 			/** Prompt cache retention policy: "in_memory" (default) or "24h" for extended caching */
 			prompt_cache_retention?: "in_memory" | "24h"
+			tools?: Array<{
+				type: "function"
+				name: string
+				description?: string
+				parameters?: any
+				strict?: boolean
+			}>
+			tool_choice?: any
 		}
 
 		// Validate requested tier against model support; if not supported, omit.
@@ -240,6 +288,18 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			// Enable extended prompt cache retention for models that support it.
 			// This uses the OpenAI Responses API `prompt_cache_retention` parameter.
 			...(promptCacheRetention ? { prompt_cache_retention: promptCacheRetention } : {}),
+			...(metadata?.tools && {
+				tools: metadata.tools
+					.filter((tool) => tool.type === "function")
+					.map((tool) => ({
+						type: "function",
+						name: tool.function.name,
+						description: tool.function.description,
+						parameters: ensureAllRequired(tool.function.parameters),
+						strict: true,
+					})),
+			}),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		// Include text.verbosity only when the model explicitly supports it
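
The `tools` mapping above flattens Chat-Completions-style definitions into the Responses API shape, which hoists `name`, `description`, and `parameters` to the top level instead of nesting them under `function`. A sketch with a made-up tool:

	// As it might arrive in metadata.tools (tool name is illustrative):
	const incoming = {
		type: "function" as const,
		function: {
			name: "read_file",
			description: "Read a file from the workspace",
			parameters: { type: "object", properties: { path: { type: "string" } } },
		},
	}

	// What the spread above contributes to the request body:
	const mapped = {
		type: "function" as const,
		name: incoming.function.name,
		description: incoming.function.description,
		parameters: ensureAllRequired(incoming.function.parameters), // adds required: ["path"]
		strict: true,
	}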
@@ -292,9 +352,8 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 	private formatFullConversation(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): any {
 		// Format the entire conversation history for the Responses API using structured format
-		// This supports both text and images
-		// Messages already include reasoning items from API history, so we just need to format them
-		const formattedMessages: any[] = []
+		// The Responses API (like Realtime API) accepts a list of items, which can be messages, function calls, or function call outputs.
+		const formattedInput: any[] = []
 
 		// Do NOT embed the system prompt as a developer message in the Responses API input.
 		// The Responses API treats roles as free-form; use the top-level `instructions` field instead.
@@ -304,45 +363,83 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			// Check if this is a reasoning item (already formatted in API history)
 			if ((message as any).type === "reasoning") {
 				// Pass through reasoning items as-is
-				formattedMessages.push(message)
+				formattedInput.push(message)
 				continue
 			}
 
-			const role = message.role === "user" ? "user" : "assistant"
-			const content: any[] = []
+			if (message.role === "user") {
+				const content: any[] = []
+				const toolResults: any[] = []
 
-			if (typeof message.content === "string") {
-				// For user messages, use input_text; for assistant messages, use output_text
-				if (role === "user") {
+				if (typeof message.content === "string") {
 					content.push({ type: "input_text", text: message.content })
-				} else {
-					content.push({ type: "output_text", text: message.content })
+				} else if (Array.isArray(message.content)) {
+					for (const block of message.content) {
+						if (block.type === "text") {
+							content.push({ type: "input_text", text: block.text })
+						} else if (block.type === "image") {
+							const image = block as Anthropic.Messages.ImageBlockParam
+							const imageUrl = `data:${image.source.media_type};base64,${image.source.data}`
+							content.push({ type: "input_image", image_url: imageUrl })
+						} else if (block.type === "tool_result") {
+							// Map Anthropic tool_result to Responses API function_call_output item
+							const result =
+								typeof block.content === "string"
+									? block.content
+									: block.content?.map((c) => (c.type === "text" ? c.text : "")).join("") || ""
+							toolResults.push({
+								type: "function_call_output",
+								call_id: block.tool_use_id,
+								output: result,
+							})
+						}
+					}
 				}
-			} else if (Array.isArray(message.content)) {
-				// For array content with potential images, format properly
-				for (const block of message.content) {
-					if (block.type === "text") {
-						// For user messages, use input_text; for assistant messages, use output_text
-						if (role === "user") {
-							content.push({ type: "input_text", text: (block as any).text })
-						} else {
-							content.push({ type: "output_text", text: (block as any).text })
+
+				// Add user message first
+				if (content.length > 0) {
+					formattedInput.push({ role: "user", content })
+				}
+
+				// Add tool results as separate items
+				if (toolResults.length > 0) {
+					formattedInput.push(...toolResults)
+				}
+			} else if (message.role === "assistant") {
+				const content: any[] = []
+				const toolCalls: any[] = []
+
+				if (typeof message.content === "string") {
+					content.push({ type: "output_text", text: message.content })
+				} else if (Array.isArray(message.content)) {
+					for (const block of message.content) {
+						if (block.type === "text") {
+							content.push({ type: "output_text", text: block.text })
+						} else if (block.type === "tool_use") {
+							// Map Anthropic tool_use to Responses API function_call item
+							toolCalls.push({
+								type: "function_call",
+								call_id: block.id,
+								name: block.name,
+								arguments: JSON.stringify(block.input),
+							})
 						}
-					} else if (block.type === "image") {
-						const image = block as Anthropic.Messages.ImageBlockParam
-						// Format image with proper data URL - images are always input_image
-						const imageUrl = `data:${image.source.media_type};base64,${image.source.data}`
-						content.push({ type: "input_image", image_url: imageUrl })
 					}
 				}
-			}
 
-			if (content.length > 0) {
-				formattedMessages.push({ role, content })
+				// Add assistant message if it has content
+				if (content.length > 0) {
+					formattedInput.push({ role: "assistant", content })
+				}
+
+				// Add tool calls as separate items
+				if (toolCalls.length > 0) {
+					formattedInput.push(...toolCalls)
+				}
 			}
 		}
 
-		return formattedMessages
+		return formattedInput
 	}
 
 	private async *makeResponsesApiRequest(
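
Unlike the old message-only output, the list returned by `formatFullConversation` can now interleave plain messages with standalone `function_call` and `function_call_output` items. A sketch of one tool round-trip (ids, names, and contents are illustrative):

	const formattedInput = [
		{ role: "user", content: [{ type: "input_text", text: "What does src/index.ts export?" }] },
		{ role: "assistant", content: [{ type: "output_text", text: "Let me read the file." }] },
		// From an Anthropic tool_use block on the assistant message:
		{ type: "function_call", call_id: "call_abc123", name: "read_file", arguments: '{"path":"src/index.ts"}' },
		// From the tool_result block on the following user message:
		{ type: "function_call_output", call_id: "call_abc123", output: "export const answer = 42" },
	]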
@@ -676,11 +773,16 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 								// Output item completed
 							}
 							// Handle function/tool call events
-							else if (parsed.type === "response.function_call_arguments.delta") {
-								// Function call arguments streaming
-								// We could yield this as a special type if needed for tool usage
-							} else if (parsed.type === "response.function_call_arguments.done") {
-								// Function call completed
+							else if (
+								parsed.type === "response.function_call_arguments.delta" ||
+								parsed.type === "response.tool_call_arguments.delta" ||
+								parsed.type === "response.function_call_arguments.done" ||
+								parsed.type === "response.tool_call_arguments.done"
+							) {
+								// Delegated to processEvent (handles accumulation and completion)
+								for await (const outChunk of this.processEvent(parsed, model)) {
+									yield outChunk
+								}
 							}
 							// Handle MCP (Model Context Protocol) tool events
 							else if (parsed.type === "response.mcp_call_arguments.delta") {
@@ -961,8 +1063,53 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			return
 		}
 
-		// Handle output item additions (SDK or Responses API alternative format)
-		if (event?.type === "response.output_item.added") {
+		// Handle tool/function call deltas and completion
+		if (
+			event?.type === "response.tool_call_arguments.delta" ||
+			event?.type === "response.function_call_arguments.delta"
+		) {
+			const callId = event.call_id || event.tool_call_id || event.id
+			if (callId) {
+				if (!this.currentToolCalls.has(callId)) {
+					this.currentToolCalls.set(callId, { name: "", arguments: "" })
+				}
+				const toolCall = this.currentToolCalls.get(callId)!
+
+				// Update name if present (usually in the first delta)
+				if (event.name || event.function_name) {
+					toolCall.name = event.name || event.function_name
+				}
+
+				// Append arguments delta
+				if (event.delta || event.arguments) {
+					toolCall.arguments += event.delta || event.arguments
+				}
+			}
+			return
+		}
+
+		if (
+			event?.type === "response.tool_call_arguments.done" ||
+			event?.type === "response.function_call_arguments.done"
+		) {
+			const callId = event.call_id || event.tool_call_id || event.id
+			if (callId && this.currentToolCalls.has(callId)) {
+				const toolCall = this.currentToolCalls.get(callId)!
+				// Yield the complete tool call
+				yield {
+					type: "tool_call",
+					id: callId,
+					name: toolCall.name,
+					arguments: toolCall.arguments,
+				}
+				// Remove from accumulator
+				this.currentToolCalls.delete(callId)
+			}
+			return
+		}
+
+		// Handle output item additions/completions (SDK or Responses API alternative format)
+		if (event?.type === "response.output_item.added" || event?.type === "response.output_item.done") {
 			const item = event?.item
 			if (item) {
 				if (item.type === "text" && item.text) {
@@ -976,6 +1123,21 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 							yield { type: "text", text: content.text }
 						}
 					}
+				} else if (
+					(item.type === "function_call" || item.type === "tool_call") &&
+					event.type === "response.output_item.done" // Only handle done events for tool calls to ensure arguments are complete
+				) {
+					// Handle complete tool/function call item
+					const callId = item.call_id || item.tool_call_id || item.id
+					if (callId && !this.currentToolCalls.has(callId)) {
+						const args = item.arguments || item.function?.arguments || item.function_arguments
+						yield {
+							type: "tool_call",
+							id: callId,
+							name: item.name || item.function?.name || item.function_name || "",
+							arguments: typeof args === "string" ? args : "{}",
+						}
+					}
 				}
 			}
 			return
@@ -983,6 +1145,19 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 		// Completion events that may carry usage
 		if (event?.type === "response.done" || event?.type === "response.completed") {
+			// Yield any pending tool calls that didn't get a 'done' event (fallback)
+			if (this.currentToolCalls.size > 0) {
+				for (const [callId, toolCall] of this.currentToolCalls) {
+					yield {
+						type: "tool_call",
+						id: callId,
+						name: toolCall.name,
+						arguments: toolCall.arguments || "{}",
+					}
+				}
+				this.currentToolCalls.clear()
+			}
+
 			const usage = event?.response?.usage || event?.usage || undefined
 			const usageData = this.normalizeUsage(usage, model)
 			if (usageData) {
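
Taken together, the three handlers above give the following flow for a single call (event shapes simplified; the field fallbacks mirror the code):

	// 1) Argument deltas accumulate in currentToolCalls, keyed by call_id:
	//      { type: "response.function_call_arguments.delta", call_id: "call_1", name: "read_file", delta: '{"pa' }
	//      { type: "response.function_call_arguments.delta", call_id: "call_1", delta: 'th":"a.ts"}' }
	// 2) The done event flushes the entry as one chunk:
	//      { type: "response.function_call_arguments.done", call_id: "call_1" }
	//      => yield { type: "tool_call", id: "call_1", name: "read_file", arguments: '{"path":"a.ts"}' }
	// 3) If no done event arrives, response.done / response.completed flushes
	//    anything still pending, defaulting empty arguments to "{}".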