|
|
@@ -2,62 +2,33 @@ import {
|
|
|
type ProviderName,
|
|
|
type ProviderSettings,
|
|
|
type ModelInfo,
|
|
|
- anthropicDefaultModelId,
|
|
|
anthropicModels,
|
|
|
- bedrockDefaultModelId,
|
|
|
bedrockModels,
|
|
|
- cerebrasDefaultModelId,
|
|
|
cerebrasModels,
|
|
|
- deepSeekDefaultModelId,
|
|
|
deepSeekModels,
|
|
|
- moonshotDefaultModelId,
|
|
|
moonshotModels,
|
|
|
- minimaxDefaultModelId,
|
|
|
minimaxModels,
|
|
|
- geminiDefaultModelId,
|
|
|
geminiModels,
|
|
|
- mistralDefaultModelId,
|
|
|
mistralModels,
|
|
|
openAiModelInfoSaneDefaults,
|
|
|
- openAiNativeDefaultModelId,
|
|
|
openAiNativeModels,
|
|
|
- vertexDefaultModelId,
|
|
|
vertexModels,
|
|
|
- xaiDefaultModelId,
|
|
|
xaiModels,
|
|
|
groqModels,
|
|
|
- groqDefaultModelId,
|
|
|
- chutesDefaultModelId,
|
|
|
vscodeLlmModels,
|
|
|
vscodeLlmDefaultModelId,
|
|
|
- openRouterDefaultModelId,
|
|
|
- requestyDefaultModelId,
|
|
|
- glamaDefaultModelId,
|
|
|
- unboundDefaultModelId,
|
|
|
- litellmDefaultModelId,
|
|
|
- claudeCodeDefaultModelId,
|
|
|
claudeCodeModels,
|
|
|
sambaNovaModels,
|
|
|
- sambaNovaDefaultModelId,
|
|
|
doubaoModels,
|
|
|
- doubaoDefaultModelId,
|
|
|
- internationalZAiDefaultModelId,
|
|
|
- mainlandZAiDefaultModelId,
|
|
|
internationalZAiModels,
|
|
|
mainlandZAiModels,
|
|
|
fireworksModels,
|
|
|
- fireworksDefaultModelId,
|
|
|
featherlessModels,
|
|
|
- featherlessDefaultModelId,
|
|
|
- ioIntelligenceDefaultModelId,
|
|
|
ioIntelligenceModels,
|
|
|
- rooDefaultModelId,
|
|
|
- qwenCodeDefaultModelId,
|
|
|
qwenCodeModels,
|
|
|
- vercelAiGatewayDefaultModelId,
|
|
|
BEDROCK_1M_CONTEXT_MODEL_IDS,
|
|
|
- deepInfraDefaultModelId,
|
|
|
isDynamicProvider,
|
|
|
+ getProviderDefaultModelId,
|
|
|
} from "@roo-code/types"
|
|
|
|
|
|
import type { ModelRecord, RouterModels } from "@roo/api"
|
|
|
@@ -67,6 +38,18 @@ import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
|
|
|
import { useLmStudioModels } from "./useLmStudioModels"
|
|
|
import { useOllamaModels } from "./useOllamaModels"
|
|
|
|
|
|
+/**
|
|
|
+ * Helper to get a validated model ID for dynamic providers.
|
|
|
+ * Returns the configured model ID if it exists in the available models, otherwise returns the default.
|
|
|
+ */
|
|
|
+function getValidatedModelId(
|
|
|
+ configuredId: string | undefined,
|
|
|
+ availableModels: ModelRecord | undefined,
|
|
|
+ defaultModelId: string,
|
|
|
+): string {
|
|
|
+ return configuredId && availableModels?.[configuredId] ? configuredId : defaultModelId
|
|
|
+}
|
|
|
+
|
|
|
export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
|
|
|
const provider = apiConfiguration?.apiProvider || "anthropic"
|
|
|
const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
|
|
|
@@ -90,10 +73,17 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
|
|
|
const needLmStudio = typeof lmStudioModelId !== "undefined"
|
|
|
const needOllama = typeof ollamaModelId !== "undefined"
|
|
|
|
|
|
+ const hasValidRouterData = needRouterModels
|
|
|
+ ? routerModels.data &&
|
|
|
+ routerModels.data[provider] !== undefined &&
|
|
|
+ typeof routerModels.data[provider] === "object" &&
|
|
|
+ !routerModels.isLoading
|
|
|
+ : true
|
|
|
+
|
|
|
const isReady =
|
|
|
(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
|
|
|
(!needOllama || typeof ollamaModels.data !== "undefined") &&
|
|
|
- (!needRouterModels || typeof routerModels.data !== "undefined") &&
|
|
|
+ hasValidRouterData &&
|
|
|
(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")
|
|
|
|
|
|
const { id, info } =
|
|
|
@@ -106,7 +96,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
|
|
|
lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
|
|
|
ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
|
|
|
})
|
|
|
- : { id: anthropicDefaultModelId, info: undefined }
|
|
|
+ : { id: getProviderDefaultModelId(provider), info: undefined }
|
|
|
|
|
|
return {
|
|
|
provider,
|
|
|
@@ -143,10 +133,11 @@ function getSelectedModel({
|
|
|
// the `undefined` case are used to show the invalid selection to prevent
|
|
|
// users from seeing the default model if their selection is invalid
|
|
|
// this gives a better UX than showing the default model
|
|
|
+ const defaultModelId = getProviderDefaultModelId(provider)
|
|
|
switch (provider) {
|
|
|
case "openrouter": {
|
|
|
- const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId
|
|
|
- let info = routerModels.openrouter[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.openRouterModelId, routerModels.openrouter, defaultModelId)
|
|
|
+ let info = routerModels.openrouter?.[id]
|
|
|
const specificProvider = apiConfiguration.openRouterSpecificProvider
|
|
|
|
|
|
if (specificProvider && openRouterModelProviders[specificProvider]) {
|
|
|
@@ -161,32 +152,32 @@ function getSelectedModel({
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "requesty": {
|
|
|
- const id = apiConfiguration.requestyModelId ?? requestyDefaultModelId
|
|
|
- const info = routerModels.requesty[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.requestyModelId, routerModels.requesty, defaultModelId)
|
|
|
+ const info = routerModels.requesty?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "glama": {
|
|
|
- const id = apiConfiguration.glamaModelId ?? glamaDefaultModelId
|
|
|
- const info = routerModels.glama[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.glamaModelId, routerModels.glama, defaultModelId)
|
|
|
+ const info = routerModels.glama?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "unbound": {
|
|
|
- const id = apiConfiguration.unboundModelId ?? unboundDefaultModelId
|
|
|
- const info = routerModels.unbound[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.unboundModelId, routerModels.unbound, defaultModelId)
|
|
|
+ const info = routerModels.unbound?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "litellm": {
|
|
|
- const id = apiConfiguration.litellmModelId ?? litellmDefaultModelId
|
|
|
- const info = routerModels.litellm[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId)
|
|
|
+ const info = routerModels.litellm?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "xai": {
|
|
|
- const id = apiConfiguration.apiModelId ?? xaiDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = xaiModels[id as keyof typeof xaiModels]
|
|
|
return info ? { id, info } : { id, info: undefined }
|
|
|
}
|
|
|
case "groq": {
|
|
|
- const id = apiConfiguration.apiModelId ?? groqDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = groqModels[id as keyof typeof groqModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
@@ -201,12 +192,12 @@ function getSelectedModel({
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "chutes": {
|
|
|
- const id = apiConfiguration.apiModelId ?? chutesDefaultModelId
|
|
|
- const info = routerModels.chutes[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.chutes, defaultModelId)
|
|
|
+ const info = routerModels.chutes?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "bedrock": {
|
|
|
- const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const baseInfo = bedrockModels[id as keyof typeof bedrockModels]
|
|
|
|
|
|
// Special case for custom ARN.
|
|
|
@@ -230,50 +221,50 @@ function getSelectedModel({
|
|
|
return { id, info: baseInfo }
|
|
|
}
|
|
|
case "vertex": {
|
|
|
- const id = apiConfiguration.apiModelId ?? vertexDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = vertexModels[id as keyof typeof vertexModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "gemini": {
|
|
|
- const id = apiConfiguration.apiModelId ?? geminiDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = geminiModels[id as keyof typeof geminiModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "deepseek": {
|
|
|
- const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = deepSeekModels[id as keyof typeof deepSeekModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "doubao": {
|
|
|
- const id = apiConfiguration.apiModelId ?? doubaoDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = doubaoModels[id as keyof typeof doubaoModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "moonshot": {
|
|
|
- const id = apiConfiguration.apiModelId ?? moonshotDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = moonshotModels[id as keyof typeof moonshotModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "minimax": {
|
|
|
- const id = apiConfiguration.apiModelId ?? minimaxDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = minimaxModels[id as keyof typeof minimaxModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "zai": {
|
|
|
const isChina = apiConfiguration.zaiApiLine === "china_coding"
|
|
|
const models = isChina ? mainlandZAiModels : internationalZAiModels
|
|
|
- const defaultModelId = isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId
|
|
|
+ const defaultModelId = getProviderDefaultModelId(provider, { isChina })
|
|
|
const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = models[id as keyof typeof models]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "openai-native": {
|
|
|
- const id = apiConfiguration.apiModelId ?? openAiNativeDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = openAiNativeModels[id as keyof typeof openAiNativeModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "mistral": {
|
|
|
- const id = apiConfiguration.apiModelId ?? mistralDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = mistralModels[id as keyof typeof mistralModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
@@ -307,7 +298,7 @@ function getSelectedModel({
|
|
|
}
|
|
|
}
|
|
|
case "deepinfra": {
|
|
|
- const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId
|
|
|
+ const id = getValidatedModelId(apiConfiguration.deepInfraModelId, routerModels.deepinfra, defaultModelId)
|
|
|
const info = routerModels.deepinfra?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
@@ -321,49 +312,56 @@ function getSelectedModel({
|
|
|
}
|
|
|
case "claude-code": {
|
|
|
// Claude Code models extend anthropic models but with images and prompt caching disabled
|
|
|
- const id = apiConfiguration.apiModelId ?? claudeCodeDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = claudeCodeModels[id as keyof typeof claudeCodeModels]
|
|
|
return { id, info: { ...openAiModelInfoSaneDefaults, ...info } }
|
|
|
}
|
|
|
case "cerebras": {
|
|
|
- const id = apiConfiguration.apiModelId ?? cerebrasDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = cerebrasModels[id as keyof typeof cerebrasModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "sambanova": {
|
|
|
- const id = apiConfiguration.apiModelId ?? sambaNovaDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = sambaNovaModels[id as keyof typeof sambaNovaModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "fireworks": {
|
|
|
- const id = apiConfiguration.apiModelId ?? fireworksDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = fireworksModels[id as keyof typeof fireworksModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "featherless": {
|
|
|
- const id = apiConfiguration.apiModelId ?? featherlessDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = featherlessModels[id as keyof typeof featherlessModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "io-intelligence": {
|
|
|
- const id = apiConfiguration.ioIntelligenceModelId ?? ioIntelligenceDefaultModelId
|
|
|
+ const id = getValidatedModelId(
|
|
|
+ apiConfiguration.ioIntelligenceModelId,
|
|
|
+ routerModels["io-intelligence"],
|
|
|
+ defaultModelId,
|
|
|
+ )
|
|
|
const info =
|
|
|
routerModels["io-intelligence"]?.[id] ?? ioIntelligenceModels[id as keyof typeof ioIntelligenceModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "roo": {
|
|
|
- // Roo is a dynamic provider - models are loaded from API
|
|
|
- const id = apiConfiguration.apiModelId ?? rooDefaultModelId
|
|
|
- const info = routerModels.roo[id]
|
|
|
+ const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.roo, defaultModelId)
|
|
|
+ const info = routerModels.roo?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "qwen-code": {
|
|
|
- const id = apiConfiguration.apiModelId ?? qwenCodeDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const info = qwenCodeModels[id as keyof typeof qwenCodeModels]
|
|
|
return { id, info }
|
|
|
}
|
|
|
case "vercel-ai-gateway": {
|
|
|
- const id = apiConfiguration.vercelAiGatewayModelId ?? vercelAiGatewayDefaultModelId
|
|
|
+ const id = getValidatedModelId(
|
|
|
+ apiConfiguration.vercelAiGatewayModelId,
|
|
|
+ routerModels["vercel-ai-gateway"],
|
|
|
+ defaultModelId,
|
|
|
+ )
|
|
|
const info = routerModels["vercel-ai-gateway"]?.[id]
|
|
|
return { id, info }
|
|
|
}
|
|
|
@@ -372,7 +370,7 @@ function getSelectedModel({
|
|
|
// case "fake-ai":
|
|
|
default: {
|
|
|
provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai"
|
|
|
- const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
|
|
|
+ const id = apiConfiguration.apiModelId ?? defaultModelId
|
|
|
const baseInfo = anthropicModels[id as keyof typeof anthropicModels]
|
|
|
|
|
|
// Apply 1M context beta tier pricing for Claude Sonnet 4
|