Browse Source

Add native tools support to Unbound (#9699)

Co-authored-by: Roo Code <[email protected]>
Matt Rubens 1 month ago
parent
commit
eb120543e1

+ 1 - 0
packages/types/src/providers/unbound.ts

@@ -7,6 +7,7 @@ export const unboundDefaultModelInfo: ModelInfo = {
 	contextWindow: 200_000,
 	supportsImages: true,
 	supportsPromptCache: true,
+	supportsNativeTools: true,
 	inputPrice: 3.0,
 	outputPrice: 15.0,
 	cacheWritesPrice: 3.75,

+ 226 - 0
src/api/providers/__tests__/unbound.spec.ts

@@ -15,6 +15,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 				contextWindow: 200000,
 				supportsImages: true,
 				supportsPromptCache: true,
+				supportsNativeTools: true,
 				inputPrice: 3,
 				outputPrice: 15,
 				cacheWritesPrice: 3.75,
@@ -27,6 +28,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 				contextWindow: 200000,
 				supportsImages: true,
 				supportsPromptCache: true,
+				supportsNativeTools: true,
 				inputPrice: 3,
 				outputPrice: 15,
 				cacheWritesPrice: 3.75,
@@ -39,6 +41,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 				contextWindow: 200000,
 				supportsImages: true,
 				supportsPromptCache: true,
+				supportsNativeTools: true,
 				inputPrice: 3,
 				outputPrice: 15,
 				cacheWritesPrice: 3.75,
@@ -51,6 +54,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 				contextWindow: 128000,
 				supportsImages: true,
 				supportsPromptCache: false,
+				supportsNativeTools: true,
 				inputPrice: 5,
 				outputPrice: 15,
 				description: "GPT-4o",
@@ -60,6 +64,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 				contextWindow: 128000,
 				supportsImages: true,
 				supportsPromptCache: false,
+				supportsNativeTools: true,
 				inputPrice: 1,
 				outputPrice: 3,
 				description: "O3 Mini",
@@ -328,4 +333,225 @@ describe("UnboundHandler", () => {
 			expect(modelInfo.info).toBeDefined()
 		})
 	})
+
+	describe("Native Tool Calling", () => {
+		const testTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "test_tool",
+					description: "A test tool",
+					parameters: {
+						type: "object",
+						properties: {
+							arg1: { type: "string", description: "First argument" },
+						},
+						required: ["arg1"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when model supports native tools and tools are provided", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "test_tool",
+							}),
+						}),
+					]),
+					parallel_tool_calls: false,
+				}),
+				expect.objectContaining({
+					headers: {
+						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
+					},
+				}),
+			)
+		})
+
+		it("should include tool_choice when provided", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: "auto",
+				}),
+				expect.objectContaining({
+					headers: {
+						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
+					},
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "xml",
+			})
+			await messageGenerator.next()
+
+			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
+			expect(callArgs).not.toHaveProperty("tools")
+			expect(callArgs).not.toHaveProperty("tool_choice")
+		})
+
+		it("should yield tool_call_partial chunks during streaming", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: {
+															name: "test_tool",
+															arguments: '{"arg1":',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: {
+															arguments: '"value"}',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				},
+			})
+
+			const stream = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "test_tool",
+				arguments: '{"arg1":',
+			})
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"value"}',
+			})
+		})
+
+		it("should set parallel_tool_calls based on metadata", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				parallelToolCalls: true,
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					parallel_tool_calls: true,
+				}),
+				expect.objectContaining({
+					headers: {
+						"X-Unbound-Metadata": expect.stringContaining("roo-code"),
+					},
+				}),
+			)
+		})
+	})
 })

+ 1 - 0
src/api/providers/fetchers/unbound.ts

@@ -23,6 +23,7 @@ export async function getUnboundModels(apiKey?: string | null): Promise<Record<s
 					contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
 					supportsImages: model?.supportsImages ?? false,
 					supportsPromptCache: model?.supportsPromptCaching ?? false,
+					supportsNativeTools: true,
 					inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined,
 					outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined,
 					cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,

+ 48 - 1
src/api/providers/unbound.ts

@@ -13,6 +13,8 @@ import { addCacheBreakpoints as addVertexCacheBreakpoints } from "../transform/c
 
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { RouterProvider } from "./router-provider"
+import { getModelParams } from "../transform/model-params"
+import { getModels } from "./fetchers/modelCache"
 
 const ORIGIN_APP = "roo-code"
 
@@ -52,12 +54,35 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
 		})
 	}
 
+	public override async fetchModel() {
+		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+		return this.getModel()
+	}
+
+	override getModel() {
+		const requestedId = this.options.unboundModelId ?? unboundDefaultModelId
+		const modelExists = this.models[requestedId]
+		const id = modelExists ? requestedId : unboundDefaultModelId
+		const info = modelExists ? this.models[requestedId] : unboundDefaultModelInfo
+
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
+	}
+
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: modelId, info } = await this.fetchModel()
+		// Ensure we have up-to-date model metadata
+		await this.fetchModel()
+		const { id: modelId, info } = this.getModel()
 
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
@@ -83,16 +108,25 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
 			maxTokens = info.maxTokens ?? undefined
 		}
 
+		// Check if model supports native tools and tools are provided with native protocol
+		const supportsNativeTools = info.supportsNativeTools ?? false
+		const useNativeTools =
+			supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
+
 		const requestOptions: UnboundChatCompletionCreateParamsStreaming = {
 			model: modelId.split("/")[1],
 			max_tokens: maxTokens,
 			messages: openAiMessages,
 			stream: true,
+			stream_options: { include_usage: true },
 			unbound_metadata: {
 				originApp: ORIGIN_APP,
 				taskId: metadata?.taskId,
 				mode: metadata?.mode,
 			},
+			...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
+			...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
 		}
 
 		if (this.supportsTemperature(modelId)) {
@@ -111,6 +145,19 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
 				yield { type: "text", text: delta.content }
 			}
 
+			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (usage) {
 				const usageData: ApiStreamUsageChunk = {
 					type: "usage",