
Add native tool support for Vercel AI Gateway (#9697)

Co-authored-by: Roo Code <[email protected]>
Matt Rubens, 1 month ago · commit b5acebccf5

src/api/providers/__tests__/vercel-ai-gateway.spec.ts (+211 -0)

@@ -274,6 +274,217 @@ describe("VercelAiGatewayHandler", () => {
 				totalCost: 0.005,
 			})
 		})
+
+		describe("native tool calling", () => {
+			const testTools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "test_tool",
+						description: "A test tool",
+						parameters: {
+							type: "object",
+							properties: {
+								arg1: { type: "string" },
+							},
+							required: ["arg1"],
+						},
+					},
+				},
+			]
+
+			beforeEach(() => {
+				mockCreate.mockImplementation(async () => ({
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {},
+									index: 0,
+								},
+							],
+						}
+					},
+				}))
+			})
+
+			it("should include tools when provided", async () => {
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const messageGenerator = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+					tools: testTools,
+					toolProtocol: "native",
+				})
+				await messageGenerator.next()
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						tools: expect.arrayContaining([
+							expect.objectContaining({
+								type: "function",
+								function: expect.objectContaining({
+									name: "test_tool",
+								}),
+							}),
+						]),
+					}),
+				)
+			})
+
+			it("should include tool_choice when provided", async () => {
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const messageGenerator = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+					tools: testTools,
+					toolProtocol: "native",
+					tool_choice: "auto",
+				})
+				await messageGenerator.next()
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						tool_choice: "auto",
+					}),
+				)
+			})
+
+			it("should set parallel_tool_calls when toolProtocol is native", async () => {
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const messageGenerator = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+					tools: testTools,
+					toolProtocol: "native",
+					parallelToolCalls: true,
+				})
+				await messageGenerator.next()
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						parallel_tool_calls: true,
+					}),
+				)
+			})
+
+			it("should default parallel_tool_calls to false", async () => {
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const messageGenerator = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+					tools: testTools,
+					toolProtocol: "native",
+				})
+				await messageGenerator.next()
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						parallel_tool_calls: false,
+					}),
+				)
+			})
+
+			it("should yield tool_call_partial chunks when streaming tool calls", async () => {
+				mockCreate.mockImplementation(async () => ({
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												id: "call_123",
+												function: {
+													name: "test_tool",
+													arguments: '{"arg1":',
+												},
+											},
+										],
+									},
+									index: 0,
+								},
+							],
+						}
+						yield {
+							choices: [
+								{
+									delta: {
+										tool_calls: [
+											{
+												index: 0,
+												function: {
+													arguments: '"value"}',
+												},
+											},
+										],
+									},
+									index: 0,
+								},
+							],
+						}
+						yield {
+							choices: [
+								{
+									delta: {},
+									index: 0,
+								},
+							],
+							usage: {
+								prompt_tokens: 10,
+								completion_tokens: 5,
+							},
+						}
+					},
+				}))
+
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const stream = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+					tools: testTools,
+					toolProtocol: "native",
+				})
+
+				const chunks = []
+				for await (const chunk of stream) {
+					chunks.push(chunk)
+				}
+
+				const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+				expect(toolCallChunks).toHaveLength(2)
+				expect(toolCallChunks[0]).toEqual({
+					type: "tool_call_partial",
+					index: 0,
+					id: "call_123",
+					name: "test_tool",
+					arguments: '{"arg1":',
+				})
+				expect(toolCallChunks[1]).toEqual({
+					type: "tool_call_partial",
+					index: 0,
+					id: undefined,
+					name: undefined,
+					arguments: '"value"}',
+				})
+			})
+
+			it("should include stream_options with include_usage", async () => {
+				const handler = new VercelAiGatewayHandler(mockOptions)
+
+				const messageGenerator = handler.createMessage("test prompt", [], {
+					taskId: "test-task-id",
+				})
+				await messageGenerator.next()
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						stream_options: { include_usage: true },
+					}),
+				)
+			})
+		})
 	})
 
 	describe("completePrompt", () => {

src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts (+1 -0)

@@ -176,6 +176,7 @@ describe("Vercel AI Gateway Fetchers", () => {
 				maxTokens: 8000,
 				contextWindow: 100000,
 				supportsImages: false,
+				supportsNativeTools: true,
 				supportsPromptCache: false,
 				inputPrice: 2500000,
 				outputPrice: 10000000,

src/api/providers/fetchers/vercel-ai-gateway.ts (+1 -0)

@@ -108,6 +108,7 @@ export const parseVercelAiGatewayModel = ({ id, model }: { id: string; model: Ve
 		contextWindow: model.context_window,
 		supportsImages,
 		supportsPromptCache,
+		supportsNativeTools: true,
 		inputPrice: parseApiPrice(model.pricing?.input),
 		outputPrice: parseApiPrice(model.pricing?.output),
 		cacheWritesPrice,
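
With supportsNativeTools now set on every parsed gateway model, callers can gate the protocol choice on the model info. A minimal sketch, assuming an info object with the fields shown in this hunk; the "xml" fallback name is illustrative, not taken from this diff.

// Sketch only: the info shape is inferred from the fields above, and
// the "xml" fallback protocol name is illustrative.
function pickToolProtocol(info: { supportsNativeTools?: boolean }): "native" | "xml" {
	return info.supportsNativeTools ? "native" : "xml"
}

// Every Vercel AI Gateway model now resolves to "native" here.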

src/api/providers/vercel-ai-gateway.ts (+19 -0)

@@ -60,6 +60,12 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp
 				: undefined,
 			max_completion_tokens: info.maxTokens,
 			stream: true,
+			stream_options: { include_usage: true },
+			...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
+			...(metadata?.toolProtocol === "native" && {
+				parallel_tool_calls: metadata.parallelToolCalls ?? false,
+			}),
 		}
 
 		const completion = await this.client.chat.completions.create(body)
@@ -73,6 +79,19 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp
 				}
 			}
 
+			// Emit raw tool call chunks - NativeToolCallParser handles state management
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (chunk.usage) {
 				const usage = chunk.usage as VercelAiGatewayUsage
 				yield {
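
The handler deliberately forwards the deltas raw; per the comment above, NativeToolCallParser owns the state. A rough illustration of that accumulation (not the actual parser), assuming the chunk shape the handler yields:

// Illustrative accumulator for tool_call_partial chunks; not the real
// NativeToolCallParser. The chunk shape matches what the handler yields.
type ToolCallPartial = {
	type: "tool_call_partial"
	index: number
	id?: string
	name?: string
	arguments?: string
}

const calls = new Map<number, { id?: string; name?: string; args: string }>()

function accept(chunk: ToolCallPartial) {
	const entry = calls.get(chunk.index) ?? { args: "" }
	// id and name only appear on the first delta for a given index.
	if (chunk.id) entry.id = chunk.id
	if (chunk.name) entry.name = chunk.name
	// Argument JSON arrives in fragments and is concatenated in order.
	if (chunk.arguments) entry.args += chunk.arguments
	calls.set(chunk.index, entry)
}

// After the streamed example in the spec above, calls.get(0) holds
// { id: "call_123", name: "test_tool", args: '{"arg1":"value"}' }.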