Просмотр исходного кода

feat(chutes): detect native tool support from API supported_features (#9715)

Co-authored-by: Matt Rubens <[email protected]>
Daniel 1 месяц назад
Родитель
Commit
34c524f78e

+ 77 - 0
src/api/providers/__tests__/chutes.spec.ts

@@ -233,6 +233,83 @@ describe("ChutesHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
+	it("createMessage should yield tool_call_partial from stream", async () => {
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vi
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [
+									{
+										delta: {
+											tool_calls: [
+												{
+													index: 0,
+													id: "call_123",
+													function: { name: "test_tool", arguments: '{"arg":"value"}' },
+												},
+											],
+										},
+									},
+								],
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toEqual({
+			type: "tool_call_partial",
+			index: 0,
+			id: "call_123",
+			name: "test_tool",
+			arguments: '{"arg":"value"}',
+		})
+	})
+
+	it("createMessage should pass tools and tool_choice to API", async () => {
+		const tools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "test_tool",
+					description: "A test tool",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+		]
+		const tool_choice = "auto" as const
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vi.fn().mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [], { tools, tool_choice, taskId: "test-task-id" })
+		// Consume stream
+		for await (const _ of stream) {
+			// noop
+		}
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				tools,
+				tool_choice,
+			}),
+		)
+	})
+
 	it("should apply DeepSeek default temperature for R1 models", () => {
 		const testModelId = "deepseek-ai/DeepSeek-R1"
 		const handlerWithModel = new ChutesHandler({

+ 34 - 2
src/api/providers/chutes.ts

@@ -28,6 +28,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 	private getCompletionParams(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
 		const { id: model, info } = this.getModel()
 
@@ -46,6 +47,8 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			...(metadata?.tools && { tools: metadata.tools }),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		// Only add temperature if model supports it
@@ -65,7 +68,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 
 		if (model.id.includes("DeepSeek-R1")) {
 			const stream = await this.client.chat.completions.create({
-				...this.getCompletionParams(systemPrompt, messages),
+				...this.getCompletionParams(systemPrompt, messages, metadata),
 				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
 			})
 
@@ -87,6 +90,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 					}
 				}
 
+				// Emit raw tool call chunks - NativeToolCallParser handles state management
+				if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
+					for (const toolCall of delta.tool_calls) {
+						yield {
+							type: "tool_call_partial",
+							index: toolCall.index,
+							id: toolCall.id,
+							name: toolCall.function?.name,
+							arguments: toolCall.function?.arguments,
+						}
+					}
+				}
+
 				if (chunk.usage) {
 					yield {
 						type: "usage",
@@ -102,7 +118,9 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 			}
 		} else {
 			// For non-DeepSeek-R1 models, use standard OpenAI streaming
-			const stream = await this.client.chat.completions.create(this.getCompletionParams(systemPrompt, messages))
+			const stream = await this.client.chat.completions.create(
+				this.getCompletionParams(systemPrompt, messages, metadata),
+			)
 
 			for await (const chunk of stream) {
 				const delta = chunk.choices[0]?.delta
@@ -115,6 +133,19 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 					yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
 				}
 
+				// Emit raw tool call chunks - NativeToolCallParser handles state management
+				if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
+					for (const toolCall of delta.tool_calls) {
+						yield {
+							type: "tool_call_partial",
+							index: toolCall.index,
+							id: toolCall.id,
+							name: toolCall.function?.name,
+							arguments: toolCall.function?.arguments,
+						}
+					}
+				}
+
 				if (chunk.usage) {
 					yield {
 						type: "usage",
@@ -166,6 +197,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan
 	override getModel() {
 		const model = super.getModel()
 		const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
+
 		return {
 			...model,
 			info: {

+ 215 - 0
src/api/providers/fetchers/__tests__/chutes.spec.ts

@@ -0,0 +1,215 @@
+// Mocks must come first, before imports
+vi.mock("axios")
+
+import type { Mock } from "vitest"
+import type { ModelInfo } from "@roo-code/types"
+import axios from "axios"
+import { getChutesModels } from "../chutes"
+import { chutesModels } from "@roo-code/types"
+
+const mockedAxios = axios as typeof axios & {
+	get: Mock
+}
+
+describe("getChutesModels", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	it("should fetch and parse models successfully", async () => {
+		const mockResponse = {
+			data: {
+				data: [
+					{
+						id: "test/new-model",
+						object: "model",
+						owned_by: "test",
+						created: 1234567890,
+						context_length: 128000,
+						max_model_len: 8192,
+						input_modalities: ["text"],
+					},
+				],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		expect(mockedAxios.get).toHaveBeenCalledWith(
+			"https://llm.chutes.ai/v1/models",
+			expect.objectContaining({
+				headers: expect.objectContaining({
+					Authorization: "Bearer test-api-key",
+				}),
+			}),
+		)
+
+		expect(models["test/new-model"]).toEqual({
+			maxTokens: 8192,
+			contextWindow: 128000,
+			supportsImages: false,
+			supportsPromptCache: false,
+			supportsNativeTools: false,
+			inputPrice: 0,
+			outputPrice: 0,
+			description: "Chutes AI model: test/new-model",
+		})
+	})
+
+	it("should override hardcoded models with dynamic API data", async () => {
+		// Find any hardcoded model
+		const [modelId] = Object.entries(chutesModels)[0]
+
+		const mockResponse = {
+			data: {
+				data: [
+					{
+						id: modelId,
+						object: "model",
+						owned_by: "test",
+						created: 1234567890,
+						context_length: 200000, // Different from hardcoded
+						max_model_len: 10000, // Different from hardcoded
+						input_modalities: ["text", "image"],
+					},
+				],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		// Dynamic values should override hardcoded
+		expect(models[modelId]).toBeDefined()
+		expect(models[modelId].contextWindow).toBe(200000)
+		expect(models[modelId].maxTokens).toBe(10000)
+		expect(models[modelId].supportsImages).toBe(true)
+	})
+
+	it("should return hardcoded models when API returns empty", async () => {
+		const mockResponse = {
+			data: {
+				data: [],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		// Should still have hardcoded models
+		expect(Object.keys(models).length).toBeGreaterThan(0)
+		expect(models).toEqual(expect.objectContaining(chutesModels))
+	})
+
+	it("should return hardcoded models on API error", async () => {
+		mockedAxios.get.mockRejectedValue(new Error("Network error"))
+
+		const models = await getChutesModels("test-api-key")
+
+		// Should still have hardcoded models
+		expect(Object.keys(models).length).toBeGreaterThan(0)
+		expect(models).toEqual(chutesModels)
+	})
+
+	it("should work without API key", async () => {
+		const mockResponse = {
+			data: {
+				data: [],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels()
+
+		expect(mockedAxios.get).toHaveBeenCalledWith(
+			"https://llm.chutes.ai/v1/models",
+			expect.objectContaining({
+				headers: expect.not.objectContaining({
+					Authorization: expect.anything(),
+				}),
+			}),
+		)
+
+		expect(Object.keys(models).length).toBeGreaterThan(0)
+	})
+
+	it("should detect image support from input_modalities", async () => {
+		const mockResponse = {
+			data: {
+				data: [
+					{
+						id: "test/image-model",
+						object: "model",
+						owned_by: "test",
+						created: 1234567890,
+						context_length: 128000,
+						max_model_len: 8192,
+						input_modalities: ["text", "image"],
+					},
+				],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		expect(models["test/image-model"].supportsImages).toBe(true)
+	})
+
+	it("should detect native tool support from supported_features", async () => {
+		const mockResponse = {
+			data: {
+				data: [
+					{
+						id: "test/tools-model",
+						object: "model",
+						owned_by: "test",
+						created: 1234567890,
+						context_length: 128000,
+						max_model_len: 8192,
+						input_modalities: ["text"],
+						supported_features: ["json_mode", "tools", "reasoning"],
+					},
+				],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		expect(models["test/tools-model"].supportsNativeTools).toBe(true)
+	})
+
+	it("should not enable native tool support when tools is not in supported_features", async () => {
+		const mockResponse = {
+			data: {
+				data: [
+					{
+						id: "test/no-tools-model",
+						object: "model",
+						owned_by: "test",
+						created: 1234567890,
+						context_length: 128000,
+						max_model_len: 8192,
+						input_modalities: ["text"],
+						supported_features: ["json_mode", "reasoning"],
+					},
+				],
+			},
+		}
+
+		mockedAxios.get.mockResolvedValue(mockResponse)
+
+		const models = await getChutesModels("test-api-key")
+
+		expect(models["test/no-tools-model"].supportsNativeTools).toBe(false)
+		expect(models["test/no-tools-model"].defaultToolProtocol).toBeUndefined()
+	})
+})

+ 4 - 1
src/api/providers/fetchers/chutes.ts

@@ -1,7 +1,7 @@
 import axios from "axios"
 import { z } from "zod"
 
-import { type ModelInfo, chutesModels } from "@roo-code/types"
+import { type ModelInfo, TOOL_PROTOCOL, chutesModels } from "@roo-code/types"
 
 import { DEFAULT_HEADERS } from "../constants"
 
@@ -14,6 +14,7 @@ const ChutesModelSchema = z.object({
 	context_length: z.number(),
 	max_model_len: z.number(),
 	input_modalities: z.array(z.string()),
+	supported_features: z.array(z.string()).optional(),
 })
 
 const ChutesModelsResponseSchema = z.object({ data: z.array(ChutesModelSchema) })
@@ -37,12 +38,14 @@ export async function getChutesModels(apiKey?: string): Promise<Record<string, M
 			const contextWindow = m.context_length
 			const maxTokens = m.max_model_len
 			const supportsImages = m.input_modalities.includes("image")
+			const supportsNativeTools = m.supported_features?.includes("tools") ?? false
 
 			const info: ModelInfo = {
 				maxTokens,
 				contextWindow,
 				supportsImages,
 				supportsPromptCache: false,
+				supportsNativeTools,
 				inputPrice: 0,
 				outputPrice: 0,
 				description: `Chutes AI model: ${m.id}`,

+ 1 - 0
src/api/providers/fetchers/modelCache.ts

@@ -254,6 +254,7 @@ export async function initializeModelCacheRefresh(): Promise<void> {
 			{ provider: "openrouter", options: { provider: "openrouter" } },
 			{ provider: "glama", options: { provider: "glama" } },
 			{ provider: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } },
+			{ provider: "chutes", options: { provider: "chutes" } },
 		]
 
 		// Refresh each provider in background (fire and forget)