
feat: Add DeepInfra as a model provider in Roo Code (#7677)

Thach Nguyen, 4 months ago
Parent
Commit
72502d8f1a

+ 6 - 0
.changeset/petite-rats-admire.md

@@ -0,0 +1,6 @@
+---
+"roo-cline": minor
+"@roo-code/types": patch
+---
+
+Added DeepInfra provider with dynamic model fetching and prompt caching

+ 1 - 0
packages/types/src/global-settings.ts

@@ -192,6 +192,7 @@ export const SECRET_STATE_KEYS = [
 	"groqApiKey",
 	"chutesApiKey",
 	"litellmApiKey",
+	"deepInfraApiKey",
 	"codeIndexOpenAiKey",
 	"codeIndexQdrantApiKey",
 	"codebaseIndexOpenAiCompatibleApiKey",

+ 12 - 0
packages/types/src/provider-settings.ts

@@ -48,6 +48,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"deepseek",
+	"deepinfra",
 	"doubao",
 	"qwen-code",
 	"unbound",
@@ -236,6 +237,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
 	deepSeekApiKey: z.string().optional(),
 })
 
+const deepInfraSchema = apiModelIdProviderModelSchema.extend({
+	deepInfraBaseUrl: z.string().optional(),
+	deepInfraApiKey: z.string().optional(),
+	deepInfraModelId: z.string().optional(),
+})
+
 const doubaoSchema = apiModelIdProviderModelSchema.extend({
 	doubaoBaseUrl: z.string().optional(),
 	doubaoApiKey: z.string().optional(),
@@ -349,6 +356,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
 	mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
 	deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
+	deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
 	doubaoSchema.merge(z.object({ apiProvider: z.literal("doubao") })),
 	moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
 	unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
@@ -389,6 +397,7 @@ export const providerSettingsSchema = z.object({
 	...openAiNativeSchema.shape,
 	...mistralSchema.shape,
 	...deepSeekSchema.shape,
+	...deepInfraSchema.shape,
 	...doubaoSchema.shape,
 	...moonshotSchema.shape,
 	...unboundSchema.shape,
@@ -438,6 +447,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"huggingFaceModelId",
 	"ioIntelligenceModelId",
 	"vercelAiGatewayModelId",
+	"deepInfraModelId",
 ]
 
 export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -559,6 +569,7 @@ export const MODELS_BY_PROVIDER: Record<
 	openrouter: { id: "openrouter", label: "OpenRouter", models: [] },
 	requesty: { id: "requesty", label: "Requesty", models: [] },
 	unbound: { id: "unbound", label: "Unbound", models: [] },
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
 }
 
@@ -569,6 +580,7 @@ export const dynamicProviders = [
 	"openrouter",
 	"requesty",
 	"unbound",
+	"deepinfra",
 	"vercel-ai-gateway",
 ] as const satisfies readonly ProviderName[]
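
For reference, a configuration object for the new provider is expected to validate against the `deepinfra` branch of the discriminated union above. A minimal sketch, assuming the schema is re-exported from `@roo-code/types` (field names taken from `deepInfraSchema`; the key and model values are placeholders):

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// Hypothetical settings object; only apiProvider drives the discrimination,
// the deepInfra* fields are all optional per the schema above.
const candidate = {
	apiProvider: "deepinfra",
	deepInfraApiKey: "YOUR_DEEPINFRA_KEY",
	deepInfraBaseUrl: "https://api.deepinfra.com/v1/openai",
	deepInfraModelId: "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo",
}

const result = providerSettingsSchemaDiscriminated.safeParse(candidate)
if (!result.success) {
	console.error(result.error.issues)
}
```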
 

+ 14 - 0
packages/types/src/providers/deepinfra.ts

@@ -0,0 +1,14 @@
+import type { ModelInfo } from "../model.js"
+
+// Default fallback values for DeepInfra when model metadata is not yet loaded.
+export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
+
+export const deepInfraDefaultModelInfo: ModelInfo = {
+	maxTokens: 16384,
+	contextWindow: 262144,
+	supportsImages: false,
+	supportsPromptCache: false,
+	inputPrice: 0.3,
+	outputPrice: 1.2,
+	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
+}

+ 1 - 0
packages/types/src/providers/index.ts

@@ -29,3 +29,4 @@ export * from "./vscode-llm.js"
 export * from "./xai.js"
 export * from "./vercel-ai-gateway.js"
 export * from "./zai.js"
+export * from "./deepinfra.js"

+ 3 - 0
src/api/index.ts

@@ -39,6 +39,7 @@ import {
 	RooHandler,
 	FeatherlessHandler,
 	VercelAiGatewayHandler,
+	DeepInfraHandler,
 } from "./providers"
 import { NativeOllamaHandler } from "./providers/native-ollama"
 
@@ -138,6 +139,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "deepinfra":
+			return new DeepInfraHandler(options)
 		case "huggingface":
 			return new HuggingFaceHandler(options)
 		case "chutes":

+ 147 - 0
src/api/providers/deepinfra.ts

@@ -0,0 +1,147 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { RouterProvider } from "./router-provider"
+import { getModelParams } from "../transform/model-params"
+import { getModels } from "./fetchers/modelCache"
+
+export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options: {
+				...options,
+				openAiHeaders: {
+					"X-Deepinfra-Source": "roo-code",
+					"X-Deepinfra-Version": `2025-08-25`,
+				},
+			},
+			name: "deepinfra",
+			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
+			apiKey: options.deepInfraApiKey || "not-provided",
+			modelId: options.deepInfraModelId,
+			defaultModelId: deepInfraDefaultModelId,
+			defaultModelInfo: deepInfraDefaultModelInfo,
+		})
+	}
+
+	public override async fetchModel() {
+		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+		return this.getModel()
+	}
+
+	override getModel() {
+		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
+		const info = this.models[id] ?? deepInfraDefaultModelInfo
+
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		_metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		// Ensure we have up-to-date model metadata
+		await this.fetchModel()
+		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
+		let prompt_cache_key = undefined
+		if (info.supportsPromptCache && _metadata?.taskId) {
+			prompt_cache_key = _metadata.taskId
+		}
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort,
+			prompt_cache_key,
+		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
+
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()
+
+		let lastUsage: OpenAI.CompletionUsage | undefined
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield { type: "text", text: delta.content }
+			}
+
+			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
+				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
+			}
+
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
+		}
+
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, info)
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		await this.fetchModel()
+		const { id: modelId, info } = this.getModel()
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const resp = await this.client.chat.completions.create(requestOptions)
+		return resp.choices[0]?.message?.content || ""
+	}
+
+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const totalCost = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: 0
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+}
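
A rough usage sketch for the handler added above: the message shape follows the Anthropic SDK types accepted by `createMessage`, while the surrounding harness (env var, logging) is illustrative only.

```ts
import { DeepInfraHandler } from "./src/api/providers/deepinfra"

async function demo() {
	const handler = new DeepInfraHandler({
		deepInfraApiKey: process.env.DEEPINFRA_API_KEY,
		deepInfraModelId: "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo",
	})

	// createMessage refreshes model metadata via fetchModel(), then yields
	// text / reasoning / usage chunks as they stream back.
	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Say hello." },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log("\ntotal cost:", chunk.totalCost)
	}
}

demo()
```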

+ 71 - 0
src/api/providers/fetchers/deepinfra.ts

@@ -0,0 +1,71 @@
+import axios from "axios"
+import { z } from "zod"
+
+import { type ModelInfo } from "@roo-code/types"
+
+import { DEFAULT_HEADERS } from "../constants"
+
+// DeepInfra models endpoint follows OpenAI /models shape with an added metadata object.
+
+const DeepInfraModelSchema = z.object({
+	id: z.string(),
+	object: z.literal("model").optional(),
+	owned_by: z.string().optional(),
+	created: z.number().optional(),
+	root: z.string().optional(),
+	metadata: z
+		.object({
+			description: z.string().optional(),
+			context_length: z.number().optional(),
+			max_tokens: z.number().optional(),
+			tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
+			pricing: z
+				.object({
+					input_tokens: z.number().optional(),
+					output_tokens: z.number().optional(),
+					cache_read_tokens: z.number().optional(),
+				})
+				.optional(),
+		})
+		.optional(),
+})
+
+const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })
+
+export async function getDeepInfraModels(
+	apiKey?: string,
+	baseUrl: string = "https://api.deepinfra.com/v1/openai",
+): Promise<Record<string, ModelInfo>> {
+	const headers: Record<string, string> = { ...DEFAULT_HEADERS }
+	if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`
+
+	const url = `${baseUrl.replace(/\/$/, "")}/models`
+	const models: Record<string, ModelInfo> = {}
+
+	const response = await axios.get(url, { headers })
+	const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
+	const data = parsed.success ? parsed.data.data : response.data?.data || []
+
+	for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
+		const meta = m.metadata || {}
+		const tags = meta.tags || []
+
+		const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
+		const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)
+
+		const info: ModelInfo = {
+			maxTokens,
+			contextWindow,
+			supportsImages: tags.includes("vision"),
+			supportsPromptCache: tags.includes("prompt_cache"),
+			inputPrice: meta.pricing?.input_tokens,
+			outputPrice: meta.pricing?.output_tokens,
+			cacheReadsPrice: meta.pricing?.cache_read_tokens,
+			description: meta.description,
+		}
+
+		models[m.id] = info
+	}
+
+	return models
+}
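
The fetcher returns a plain `Record<string, ModelInfo>` keyed by model id; models without metadata fall back to an 8192-token context window and `maxTokens = ceil(contextWindow * 0.2)`. A quick sketch of calling it directly (the API key is optional per the signature above):

```ts
import { getDeepInfraModels } from "./src/api/providers/fetchers/deepinfra"

const models = await getDeepInfraModels(process.env.DEEPINFRA_API_KEY)

for (const [id, info] of Object.entries(models)) {
	// supportsPromptCache is derived from the "prompt_cache" tag,
	// supportsImages from the "vision" tag.
	console.log(id, info.contextWindow, info.supportsPromptCache)
}
```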

+ 4 - 0
src/api/providers/fetchers/modelCache.ts

@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 
 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -79,6 +80,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 			case "lmstudio":
 				models = await getLMStudioModels(options.baseUrl)
 				break
+			case "deepinfra":
+				models = await getDeepInfraModels(options.apiKey, options.baseUrl)
+				break
 			case "io-intelligence":
 				models = await getIOIntelligenceModels(options.apiKey)
 				break

+ 1 - 0
src/api/providers/index.ts

@@ -33,3 +33,4 @@ export { FireworksHandler } from "./fireworks"
 export { RooHandler } from "./roo"
 export { FeatherlessHandler } from "./featherless"
 export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
+export { DeepInfraHandler } from "./deepinfra"

+ 4 - 0
src/core/webview/__tests__/ClineProvider.spec.ts

@@ -2680,6 +2680,7 @@ describe("ClineProvider - Router Models", () => {
 		expect(mockPostMessage).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,
@@ -2719,6 +2720,7 @@ describe("ClineProvider - Router Models", () => {
 			.mockResolvedValueOnce(mockModels) // glama success
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail
 			.mockResolvedValueOnce(mockModels) // vercel-ai-gateway success
+			.mockResolvedValueOnce(mockModels) // deepinfra success
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail
 
 		await messageHandler({ type: "requestRouterModels" })
@@ -2727,6 +2729,7 @@ describe("ClineProvider - Router Models", () => {
 		expect(mockPostMessage).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: {},
 				glama: mockModels,
@@ -2838,6 +2841,7 @@ describe("ClineProvider - Router Models", () => {
 		expect(mockPostMessage).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,

+ 13 - 0
src/core/webview/__tests__/webviewMessageHandler.spec.ts

@@ -174,6 +174,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		})
 
 		// Verify getModels was called for each provider
+		expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" })
@@ -189,6 +190,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,
@@ -277,6 +279,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,
@@ -306,6 +309,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockResolvedValueOnce(mockModels) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
 			.mockResolvedValueOnce(mockModels) // vercel-ai-gateway
+			.mockResolvedValueOnce(mockModels) // deepinfra
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
 
 		await webviewMessageHandler(mockClineProvider, {
@@ -316,6 +320,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: {},
 				glama: mockModels,
@@ -358,6 +363,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockRejectedValueOnce(new Error("Glama API error")) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
 			.mockRejectedValueOnce(new Error("Vercel AI Gateway error")) // vercel-ai-gateway
+			.mockRejectedValueOnce(new Error("DeepInfra API error")) // deepinfra
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
 
 		await webviewMessageHandler(mockClineProvider, {
@@ -393,6 +399,13 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			values: { provider: "unbound" },
 		})
 
+		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+			type: "singleRouterModelFetchResponse",
+			success: false,
+			error: "DeepInfra API error",
+			values: { provider: "deepinfra" },
+		})
+
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "singleRouterModelFetchResponse",
 			success: false,

+ 9 - 0
src/core/webview/webviewMessageHandler.ts

@@ -550,6 +550,7 @@ export const webviewMessageHandler = async (
 				litellm: {},
 				ollama: {},
 				lmstudio: {},
+				deepinfra: {},
 			}
 
 			const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
@@ -577,6 +578,14 @@ export const webviewMessageHandler = async (
 				{ key: "glama", options: { provider: "glama" } },
 				{ key: "unbound", options: { provider: "unbound", apiKey: apiConfiguration.unboundApiKey } },
 				{ key: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } },
+				{
+					key: "deepinfra",
+					options: {
+						provider: "deepinfra",
+						apiKey: apiConfiguration.deepInfraApiKey,
+						baseUrl: apiConfiguration.deepInfraBaseUrl,
+					},
+				},
 			]
 
 			// Add IO Intelligence if API key is provided

+ 2 - 0
src/shared/ProfileValidator.ts

@@ -90,6 +90,8 @@ export class ProfileValidator {
 				return profile.requestyModelId
 			case "io-intelligence":
 				return profile.ioIntelligenceModelId
+			case "deepinfra":
+				return profile.deepInfraModelId
 			case "human-relay":
 			case "fake-ai":
 			default:

+ 2 - 0
src/shared/api.ts

@@ -27,6 +27,7 @@ const routerNames = [
 	"ollama",
 	"lmstudio",
 	"io-intelligence",
+	"deepinfra",
 	"vercel-ai-gateway",
 ] as const
 
@@ -151,5 +152,6 @@ export type GetModelsOptions =
 	| { provider: "litellm"; apiKey: string; baseUrl: string }
 	| { provider: "ollama"; baseUrl?: string }
 	| { provider: "lmstudio"; baseUrl?: string }
+	| { provider: "deepinfra"; apiKey?: string; baseUrl?: string }
 	| { provider: "io-intelligence"; apiKey: string }
 	| { provider: "vercel-ai-gateway" }

+ 18 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -36,6 +36,7 @@ import {
 	ioIntelligenceDefaultModelId,
 	rooDefaultModelId,
 	vercelAiGatewayDefaultModelId,
+	deepInfraDefaultModelId,
 } from "@roo-code/types"
 
 import { vscode } from "@src/utils/vscode"
@@ -93,6 +94,7 @@ import {
 	Fireworks,
 	Featherless,
 	VercelAiGateway,
+	DeepInfra,
 } from "./providers"
 
 import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants"
@@ -226,6 +228,8 @@ const ApiOptions = ({
 				vscode.postMessage({ type: "requestVsCodeLmModels" })
 			} else if (selectedProvider === "litellm") {
 				vscode.postMessage({ type: "requestRouterModels" })
+			} else if (selectedProvider === "deepinfra") {
+				vscode.postMessage({ type: "requestRouterModels" })
 			}
 		},
 		250,
@@ -238,6 +242,8 @@ const ApiOptions = ({
 			apiConfiguration?.lmStudioBaseUrl,
 			apiConfiguration?.litellmBaseUrl,
 			apiConfiguration?.litellmApiKey,
+			apiConfiguration?.deepInfraApiKey,
+			apiConfiguration?.deepInfraBaseUrl,
 			customHeaders,
 		],
 	)
@@ -305,6 +311,7 @@ const ApiOptions = ({
 					}
 				>
 			> = {
+				deepinfra: { field: "deepInfraModelId", default: deepInfraDefaultModelId },
 				openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
 				glama: { field: "glamaModelId", default: glamaDefaultModelId },
 				unbound: { field: "unboundModelId", default: unboundDefaultModelId },
@@ -487,6 +494,17 @@ const ApiOptions = ({
 				/>
 			)}
 
+			{selectedProvider === "deepinfra" && (
+				<DeepInfra
+					apiConfiguration={apiConfiguration}
+					setApiConfigurationField={setApiConfigurationField}
+					routerModels={routerModels}
+					refetchRouterModels={refetchRouterModels}
+					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
+				/>
+			)}
+
 			{selectedProvider === "anthropic" && (
 				<Anthropic apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} />
 			)}

+ 1 - 0
webview-ui/src/components/settings/ModelPicker.tsx

@@ -34,6 +34,7 @@ type ModelIdKey = keyof Pick<
 	| "requestyModelId"
 	| "openAiModelId"
 	| "litellmModelId"
+	| "deepInfraModelId"
 	| "ioIntelligenceModelId"
 	| "vercelAiGatewayModelId"
 >

+ 1 - 0
webview-ui/src/components/settings/constants.ts

@@ -48,6 +48,7 @@ export const MODELS_BY_PROVIDER: Partial<Record<ProviderName, Record<string, Mod
 
 export const PROVIDERS = [
 	{ value: "openrouter", label: "OpenRouter" },
+	{ value: "deepinfra", label: "DeepInfra" },
 	{ value: "anthropic", label: "Anthropic" },
 	{ value: "claude-code", label: "Claude Code" },
 	{ value: "cerebras", label: "Cerebras" },

+ 94 - 0
webview-ui/src/components/settings/providers/DeepInfra.tsx

@@ -0,0 +1,94 @@
+import { useCallback, useEffect, useState } from "react"
+import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+
+import { OrganizationAllowList, type ProviderSettings, deepInfraDefaultModelId } from "@roo-code/types"
+
+import type { RouterModels } from "@roo/api"
+
+import { vscode } from "@src/utils/vscode"
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+import { Button } from "@src/components/ui"
+
+import { inputEventTransform } from "../transforms"
+import { ModelPicker } from "../ModelPicker"
+
+type DeepInfraProps = {
+	apiConfiguration: ProviderSettings
+	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
+	routerModels?: RouterModels
+	refetchRouterModels: () => void
+	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
+}
+
+export const DeepInfra = ({
+	apiConfiguration,
+	setApiConfigurationField,
+	routerModels,
+	refetchRouterModels,
+	organizationAllowList,
+	modelValidationError,
+}: DeepInfraProps) => {
+	const { t } = useAppTranslation()
+
+	const [didRefetch, setDidRefetch] = useState<boolean>()
+
+	const handleInputChange = useCallback(
+		<K extends keyof ProviderSettings, E>(
+			field: K,
+			transform: (event: E) => ProviderSettings[K] = inputEventTransform,
+		) =>
+			(event: E | Event) => {
+				setApiConfigurationField(field, transform(event as E))
+			},
+		[setApiConfigurationField],
+	)
+
+	useEffect(() => {
+		// When base URL or API key changes, trigger a silent refresh of models
+		// The outer ApiOptions debounces and sends requestRouterModels; this keeps UI responsive
+	}, [apiConfiguration.deepInfraBaseUrl, apiConfiguration.deepInfraApiKey])
+
+	return (
+		<>
+			<VSCodeTextField
+				value={apiConfiguration?.deepInfraApiKey || ""}
+				type="password"
+				onInput={handleInputChange("deepInfraApiKey")}
+				placeholder={t("settings:placeholders.apiKey")}
+				className="w-full">
+				<label className="block font-medium mb-1">{t("settings:providers.apiKey")}</label>
+			</VSCodeTextField>
+
+			<Button
+				variant="outline"
+				onClick={() => {
+					vscode.postMessage({ type: "flushRouterModels", text: "deepinfra" })
+					refetchRouterModels()
+					setDidRefetch(true)
+				}}>
+				<div className="flex items-center gap-2">
+					<span className="codicon codicon-refresh" />
+					{t("settings:providers.refreshModels.label")}
+				</div>
+			</Button>
+			{didRefetch && (
+				<div className="flex items-center text-vscode-errorForeground">
+					{t("settings:providers.refreshModels.hint")}
+				</div>
+			)}
+
+			<ModelPicker
+				apiConfiguration={apiConfiguration}
+				setApiConfigurationField={setApiConfigurationField}
+				defaultModelId={deepInfraDefaultModelId}
+				models={routerModels?.deepinfra ?? {}}
+				modelIdKey="deepInfraModelId"
+				serviceName="Deep Infra"
+				serviceUrl="https://deepinfra.com/models"
+				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
+			/>
+		</>
+	)
+}

+ 1 - 0
webview-ui/src/components/settings/providers/index.ts

@@ -29,3 +29,4 @@ export { LiteLLM } from "./LiteLLM"
 export { Fireworks } from "./Fireworks"
 export { Featherless } from "./Featherless"
 export { VercelAiGateway } from "./VercelAiGateway"
+export { DeepInfra } from "./DeepInfra"

+ 6 - 0
webview-ui/src/components/ui/hooks/useSelectedModel.ts

@@ -56,6 +56,7 @@ import {
 	qwenCodeModels,
 	vercelAiGatewayDefaultModelId,
 	BEDROCK_CLAUDE_SONNET_4_MODEL_ID,
+	deepInfraDefaultModelId,
 } from "@roo-code/types"
 
 import type { ModelRecord, RouterModels } from "@roo/api"
@@ -268,6 +269,11 @@ function getSelectedModel({
 				info: info || undefined,
 			}
 		}
+		case "deepinfra": {
+			const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId
+			const info = routerModels.deepinfra?.[id]
+			return { id, info }
+		}
 		case "vscode-lm": {
 			const id = apiConfiguration?.vsCodeLmModelSelector
 				? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`

+ 1 - 0
webview-ui/src/utils/__tests__/validate.test.ts

@@ -39,6 +39,7 @@ describe("Model Validation Functions", () => {
 		litellm: {},
 		ollama: {},
 		lmstudio: {},
+		deepinfra: {},
 		"io-intelligence": {},
 		"vercel-ai-gateway": {},
 	}

+ 10 - 0
webview-ui/src/utils/validate.ts

@@ -47,6 +47,11 @@ function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): stri
 				return i18next.t("settings:validation.apiKey")
 			}
 			break
+		case "deepinfra":
+			if (!apiConfiguration.deepInfraApiKey) {
+				return i18next.t("settings:validation.apiKey")
+			}
+			break
 		case "litellm":
 			if (!apiConfiguration.litellmApiKey) {
 				return i18next.t("settings:validation.apiKey")
@@ -193,6 +198,8 @@ function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: str
 			return apiConfiguration.unboundModelId
 		case "requesty":
 			return apiConfiguration.requestyModelId
+		case "deepinfra":
+			return apiConfiguration.deepInfraModelId
 		case "litellm":
 			return apiConfiguration.litellmModelId
 		case "openai":
@@ -271,6 +278,9 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels
 		case "requesty":
 			modelId = apiConfiguration.requestyModelId
 			break
+		case "deepinfra":
+			modelId = apiConfiguration.deepInfraModelId
+			break
 		case "ollama":
 			modelId = apiConfiguration.ollamaModelId
 			break