
Add native tool calling for deepinfra (#9691)

Matt Rubens 1 month ago
parent
commit
4591e960ed

+ 1 - 0
packages/types/src/providers/deepinfra.ts

@@ -8,6 +8,7 @@ export const deepInfraDefaultModelInfo: ModelInfo = {
 	contextWindow: 262144,
 	supportsImages: false,
 	supportsPromptCache: false,
+	supportsNativeTools: true,
 	inputPrice: 0.3,
 	outputPrice: 1.2,
 	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",

+ 386 - 0
src/api/providers/__tests__/deepinfra.spec.ts

@@ -0,0 +1,386 @@
+// npx vitest api/providers/__tests__/deepinfra.spec.ts
+
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+
+const mockCreate = vitest.fn()
+const mockWithResponse = vitest.fn()
+
+vitest.mock("openai", () => {
+	const mockConstructor = vitest.fn()
+
+	return {
+		__esModule: true,
+		default: mockConstructor.mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate.mockImplementation(() => ({
+						withResponse: mockWithResponse,
+					})),
+				},
+			},
+		})),
+	}
+})
+
+vitest.mock("../fetchers/modelCache", () => ({
+	getModels: vitest.fn().mockResolvedValue({
+		[deepInfraDefaultModelId]: deepInfraDefaultModelInfo,
+	}),
+}))
+
+import OpenAI from "openai"
+import { DeepInfraHandler } from "../deepinfra"
+
+describe("DeepInfraHandler", () => {
+	let handler: DeepInfraHandler
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockCreate.mockClear()
+		mockWithResponse.mockClear()
+
+		handler = new DeepInfraHandler({})
+	})
+
+	it("should use the correct DeepInfra base URL", () => {
+		expect(OpenAI).toHaveBeenCalledWith(
+			expect.objectContaining({
+				baseURL: "https://api.deepinfra.com/v1/openai",
+			}),
+		)
+	})
+
+	it("should use the provided API key", () => {
+		vi.clearAllMocks()
+
+		const deepInfraApiKey = "test-api-key"
+		new DeepInfraHandler({ deepInfraApiKey })
+
+		expect(OpenAI).toHaveBeenCalledWith(
+			expect.objectContaining({
+				apiKey: deepInfraApiKey,
+			}),
+		)
+	})
+
+	it("should return default model when no model is specified", () => {
+		const model = handler.getModel()
+		expect(model.id).toBe(deepInfraDefaultModelId)
+		expect(model.info).toEqual(deepInfraDefaultModelInfo)
+	})
+
+	it("createMessage should yield text content from stream", async () => {
+		const testContent = "This is test content"
+
+		mockWithResponse.mockResolvedValueOnce({
+			data: {
+				[Symbol.asyncIterator]: () => ({
+					next: vi
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [{ delta: { content: testContent } }],
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			},
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({
+			type: "text",
+			text: testContent,
+		})
+	})
+
+	it("createMessage should yield reasoning content from stream", async () => {
+		const testReasoning = "Test reasoning content"
+
+		mockWithResponse.mockResolvedValueOnce({
+			data: {
+				[Symbol.asyncIterator]: () => ({
+					next: vi
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [{ delta: { reasoning_content: testReasoning } }],
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			},
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({
+			type: "reasoning",
+			text: testReasoning,
+		})
+	})
+
+	it("createMessage should yield usage data from stream", async () => {
+		mockWithResponse.mockResolvedValueOnce({
+			data: {
+				[Symbol.asyncIterator]: () => ({
+					next: vi
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [{ delta: {} }],
+								usage: {
+									prompt_tokens: 10,
+									completion_tokens: 20,
+									prompt_tokens_details: {
+										cache_write_tokens: 15,
+										cached_tokens: 5,
+									},
+								},
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			},
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 20,
+			cacheWriteTokens: 15,
+			cacheReadTokens: 5,
+			totalCost: expect.any(Number),
+		})
+	})
+
+	describe("Native Tool Calling", () => {
+		const testTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "test_tool",
+					description: "A test tool",
+					parameters: {
+						type: "object",
+						properties: {
+							arg1: { type: "string", description: "First argument" },
+						},
+						required: ["arg1"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when model supports native tools and tools are provided", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "test_tool",
+							}),
+						}),
+					]),
+					parallel_tool_calls: false,
+				}),
+			)
+		})
+
+		it("should include tool_choice when provided", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: "auto",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "xml",
+			})
+			await messageGenerator.next()
+
+			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
+			expect(callArgs).not.toHaveProperty("tools")
+			expect(callArgs).not.toHaveProperty("tool_choice")
+		})
+
+		it("should yield tool_call_partial chunks during streaming", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: {
+															name: "test_tool",
+															arguments: '{"arg1":',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: {
+															arguments: '"value"}',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				},
+			})
+
+			const stream = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "test_tool",
+				arguments: '{"arg1":',
+			})
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"value"}',
+			})
+		})
+
+		it("should set parallel_tool_calls based on metadata", async () => {
+			mockWithResponse.mockResolvedValueOnce({
+				data: {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				},
+			})
+
+			const messageGenerator = handler.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				parallelToolCalls: true,
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					parallel_tool_calls: true,
+				}),
+			)
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should return text from API", async () => {
+			const expectedResponse = "This is a test response"
+			mockCreate.mockResolvedValueOnce({
+				choices: [{ message: { content: expectedResponse } }],
+			})
+
+			const result = await handler.completePrompt("test prompt")
+			expect(result).toBe(expectedResponse)
+		})
+	})
+})
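
Aside: the `mockCreate`/`mockWithResponse` pair above mirrors the OpenAI SDK's `.withResponse()` pattern, where `create()` returns an APIPromise and `withResponse()` resolves to `{ data, response }` with `data` being the async-iterable stream. A minimal sketch of the real shape being mocked, assuming the handler consumes the stream this way (the handler's read loop is not shown in this diff):

import OpenAI from "openai"

// Sketch only: the model id is a placeholder and the read loop is illustrative.
async function streamExample(client: OpenAI) {
	const { data } = await client.chat.completions
		.create({ model: "placeholder-model", messages: [], stream: true })
		.withResponse()
	for await (const chunk of data) {
		console.log(chunk.choices[0]?.delta)
	}
}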

+ 21 - 0
src/api/providers/deepinfra.ts

@@ -65,6 +65,11 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 			prompt_cache_key = _metadata.taskId
 		}
 
+		// Check if model supports native tools and tools are provided with native protocol
+		const supportsNativeTools = info.supportsNativeTools ?? false
+		const useNativeTools =
+			supportsNativeTools && _metadata?.tools && _metadata.tools.length > 0 && _metadata?.toolProtocol !== "xml"
+
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
@@ -72,6 +77,9 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 			stream_options: { include_usage: true },
 			reasoning_effort,
 			prompt_cache_key,
+			...(useNativeTools && { tools: this.convertToolsForOpenAI(_metadata.tools) }),
+			...(useNativeTools && _metadata.tool_choice && { tool_choice: _metadata.tool_choice }),
+			...(useNativeTools && { parallel_tool_calls: _metadata?.parallelToolCalls ?? false }),
 		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
 
 		if (this.supportsTemperature(modelId)) {
@@ -96,6 +104,19 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
 			}
 
+			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (chunk.usage) {
 				lastUsage = chunk.usage
 			}

+ 1 - 0
src/api/providers/fetchers/deepinfra.ts

@@ -58,6 +58,7 @@ export async function getDeepInfraModels(
 			contextWindow,
 			supportsImages: tags.includes("vision"),
 			supportsPromptCache: tags.includes("prompt_cache"),
+			supportsNativeTools: true,
 			inputPrice: meta.pricing?.input_tokens,
 			outputPrice: meta.pricing?.output_tokens,
 			cacheReadsPrice: meta.pricing?.cache_read_tokens,
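
With this change, every model returned by the DeepInfra fetcher advertises native tool support. For reference, a sketch of the call-site shape the new metadata fields expect, mirroring the tests above (handler construction and the tool definition are illustrative placeholders):

// Illustrative usage; the API key, task id, and tool definition are placeholders.
const handler = new DeepInfraHandler({ deepInfraApiKey: "key" })
const stream = handler.createMessage("system prompt", [], {
	taskId: "task-1",
	tools: [
		{
			type: "function" as const,
			function: {
				name: "test_tool",
				description: "A test tool",
				parameters: { type: "object", properties: {}, required: [] },
			},
		},
	],
	toolProtocol: "native", // any value other than "xml" allows tools through
	parallelToolCalls: false, // mapped to parallel_tool_calls in the request
})
for await (const chunk of stream) {
	if (chunk.type === "tool_call_partial") {
		// feed into an accumulator like the sketch above
	}
}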