Просмотр исходного кода

Fix dynamic provider model validation to prevent cross-contamination (#9054)

Daniel 1 месяц назад
Родитель
Коммит
d4aeca4205

+ 118 - 0
packages/types/src/providers/index.ts

@@ -31,3 +31,121 @@ export * from "./vercel-ai-gateway.js"
 export * from "./zai.js"
 export * from "./deepinfra.js"
 export * from "./minimax.js"
+
+import { anthropicDefaultModelId } from "./anthropic.js"
+import { bedrockDefaultModelId } from "./bedrock.js"
+import { cerebrasDefaultModelId } from "./cerebras.js"
+import { chutesDefaultModelId } from "./chutes.js"
+import { claudeCodeDefaultModelId } from "./claude-code.js"
+import { deepSeekDefaultModelId } from "./deepseek.js"
+import { doubaoDefaultModelId } from "./doubao.js"
+import { featherlessDefaultModelId } from "./featherless.js"
+import { fireworksDefaultModelId } from "./fireworks.js"
+import { geminiDefaultModelId } from "./gemini.js"
+import { glamaDefaultModelId } from "./glama.js"
+import { groqDefaultModelId } from "./groq.js"
+import { ioIntelligenceDefaultModelId } from "./io-intelligence.js"
+import { litellmDefaultModelId } from "./lite-llm.js"
+import { mistralDefaultModelId } from "./mistral.js"
+import { moonshotDefaultModelId } from "./moonshot.js"
+import { openRouterDefaultModelId } from "./openrouter.js"
+import { qwenCodeDefaultModelId } from "./qwen-code.js"
+import { requestyDefaultModelId } from "./requesty.js"
+import { rooDefaultModelId } from "./roo.js"
+import { sambaNovaDefaultModelId } from "./sambanova.js"
+import { unboundDefaultModelId } from "./unbound.js"
+import { vertexDefaultModelId } from "./vertex.js"
+import { vscodeLlmDefaultModelId } from "./vscode-llm.js"
+import { xaiDefaultModelId } from "./xai.js"
+import { vercelAiGatewayDefaultModelId } from "./vercel-ai-gateway.js"
+import { internationalZAiDefaultModelId, mainlandZAiDefaultModelId } from "./zai.js"
+import { deepInfraDefaultModelId } from "./deepinfra.js"
+import { minimaxDefaultModelId } from "./minimax.js"
+
+// Import the ProviderName type from provider-settings to avoid duplication
+import type { ProviderName } from "../provider-settings.js"
+
+/**
+ * Get the built-in default model ID for `provider`. User configuration is not consulted; this is used as a fallback when provider models are still loading.
+ * `options.isChina` (default `false`) is read only for "zai", where it picks `mainlandZAiDefaultModelId` over `internationalZAiDefaultModelId`.
+ * Returns "" for "openai", "ollama" and "lmstudio"; "anthropic", "gemini-cli", "human-relay", "fake-ai" and any unhandled value return `anthropicDefaultModelId`.
+ */
+export function getProviderDefaultModelId(
+	provider: ProviderName,
+	options: { isChina?: boolean } = { isChina: false },
+): string {
+	switch (provider) {
+		case "openrouter":
+			return openRouterDefaultModelId
+		case "requesty":
+			return requestyDefaultModelId
+		case "glama":
+			return glamaDefaultModelId
+		case "unbound":
+			return unboundDefaultModelId
+		case "litellm":
+			return litellmDefaultModelId
+		case "xai":
+			return xaiDefaultModelId
+		case "groq":
+			return groqDefaultModelId
+		case "huggingface":
+			return "meta-llama/Llama-3.3-70B-Instruct" // Hard-coded literal; no imported default constant is used for this provider
+		case "chutes":
+			return chutesDefaultModelId
+		case "bedrock":
+			return bedrockDefaultModelId
+		case "vertex":
+			return vertexDefaultModelId
+		case "gemini":
+			return geminiDefaultModelId
+		case "deepseek":
+			return deepSeekDefaultModelId
+		case "doubao":
+			return doubaoDefaultModelId
+		case "moonshot":
+			return moonshotDefaultModelId
+		case "minimax":
+			return minimaxDefaultModelId
+		case "zai":
+			return options?.isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId // Only case that reads `options`
+		case "openai-native":
+			return "gpt-4o" // Hard-coded literal. NOTE(review): useSelectedModel previously fell back to openAiNativeDefaultModelId — confirm "gpt-4o" matches it
+		case "mistral":
+			return mistralDefaultModelId
+		case "openai":
+			return "" // OpenAI provider uses custom model configuration
+		case "ollama":
+			return "" // Ollama uses dynamic model selection
+		case "lmstudio":
+			return "" // LMStudio uses dynamic model selection
+		case "deepinfra":
+			return deepInfraDefaultModelId
+		case "vscode-lm":
+			return vscodeLlmDefaultModelId
+		case "claude-code":
+			return claudeCodeDefaultModelId
+		case "cerebras":
+			return cerebrasDefaultModelId
+		case "sambanova":
+			return sambaNovaDefaultModelId
+		case "fireworks":
+			return fireworksDefaultModelId
+		case "featherless":
+			return featherlessDefaultModelId
+		case "io-intelligence":
+			return ioIntelligenceDefaultModelId
+		case "roo":
+			return rooDefaultModelId
+		case "qwen-code":
+			return qwenCodeDefaultModelId
+		case "vercel-ai-gateway":
+			return vercelAiGatewayDefaultModelId
+		case "anthropic":
+		case "gemini-cli":
+		case "human-relay":
+		case "fake-ai":
+		default: // The four cases above and any value not matched earlier share the Anthropic default
+			return anthropicDefaultModelId
+	}
+}

+ 46 - 10
webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts

@@ -93,7 +93,7 @@ describe("useSelectedModel", () => {
 			})
 		})
 
-		it("should use only specific provider info when base model info is missing", () => {
+		it("should fall back to default when configured model doesn't exist in available models", () => {
 			const specificProviderInfo: ModelInfo = {
 				maxTokens: 8192,
 				contextWindow: 16384,
@@ -106,7 +106,18 @@ describe("useSelectedModel", () => {
 
 			mockUseRouterModels.mockReturnValue({
 				data: {
-					openrouter: {},
+					openrouter: {
+						"anthropic/claude-sonnet-4.5": {
+							maxTokens: 8192,
+							contextWindow: 200_000,
+							supportsImages: true,
+							supportsPromptCache: true,
+							inputPrice: 3.0,
+							outputPrice: 15.0,
+							cacheWritesPrice: 3.75,
+							cacheReadsPrice: 0.3,
+						},
+					},
 					requesty: {},
 					glama: {},
 					unbound: {},
@@ -127,15 +138,29 @@ describe("useSelectedModel", () => {
 
 			const apiConfiguration: ProviderSettings = {
 				apiProvider: "openrouter",
-				openRouterModelId: "test-model",
+				openRouterModelId: "test-model", // This model doesn't exist in available models
 				openRouterSpecificProvider: "test-provider",
 			}
 
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper })
 
-			expect(result.current.id).toBe("test-model")
-			expect(result.current.info).toEqual(specificProviderInfo)
+			// Should fall back to provider default since "test-model" doesn't exist
+			expect(result.current.id).toBe("anthropic/claude-sonnet-4.5")
+			// Should still use specific provider info for the default model if specified
+			expect(result.current.info).toEqual({
+				...{
+					maxTokens: 8192,
+					contextWindow: 200_000,
+					supportsImages: true,
+					supportsPromptCache: true,
+					inputPrice: 3.0,
+					outputPrice: 15.0,
+					cacheWritesPrice: 3.75,
+					cacheReadsPrice: 0.3,
+				},
+				...specificProviderInfo,
+			})
 		})
 
 		it("should demonstrate the merging behavior validates the comment about missing fields", () => {
@@ -244,12 +269,12 @@ describe("useSelectedModel", () => {
 			expect(result.current.info).toEqual(baseModelInfo)
 		})
 
-		it("should fall back to default when both base and specific provider info are missing", () => {
+		it("should fall back to default when configured model and provider don't exist", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: {
 					openrouter: {
-						"anthropic/claude-sonnet-4": {
-							// Default model
+						"anthropic/claude-sonnet-4.5": {
+							// Default model - using correct default model name
 							maxTokens: 8192,
 							contextWindow: 200_000,
 							supportsImages: true,
@@ -285,8 +310,19 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper })
 
-			expect(result.current.id).toBe("non-existent-model")
-			expect(result.current.info).toBeUndefined()
+			// Should fall back to provider default since "non-existent-model" doesn't exist
+			expect(result.current.id).toBe("anthropic/claude-sonnet-4.5")
+			// Should use base model info since provider doesn't exist
+			expect(result.current.info).toEqual({
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+				inputPrice: 3.0,
+				outputPrice: 15.0,
+				cacheWritesPrice: 3.75,
+				cacheReadsPrice: 0.3,
+			})
 		})
 	})
 

+ 67 - 69
webview-ui/src/components/ui/hooks/useSelectedModel.ts

@@ -2,62 +2,33 @@ import {
 	type ProviderName,
 	type ProviderSettings,
 	type ModelInfo,
-	anthropicDefaultModelId,
 	anthropicModels,
-	bedrockDefaultModelId,
 	bedrockModels,
-	cerebrasDefaultModelId,
 	cerebrasModels,
-	deepSeekDefaultModelId,
 	deepSeekModels,
-	moonshotDefaultModelId,
 	moonshotModels,
-	minimaxDefaultModelId,
 	minimaxModels,
-	geminiDefaultModelId,
 	geminiModels,
-	mistralDefaultModelId,
 	mistralModels,
 	openAiModelInfoSaneDefaults,
-	openAiNativeDefaultModelId,
 	openAiNativeModels,
-	vertexDefaultModelId,
 	vertexModels,
-	xaiDefaultModelId,
 	xaiModels,
 	groqModels,
-	groqDefaultModelId,
-	chutesDefaultModelId,
 	vscodeLlmModels,
 	vscodeLlmDefaultModelId,
-	openRouterDefaultModelId,
-	requestyDefaultModelId,
-	glamaDefaultModelId,
-	unboundDefaultModelId,
-	litellmDefaultModelId,
-	claudeCodeDefaultModelId,
 	claudeCodeModels,
 	sambaNovaModels,
-	sambaNovaDefaultModelId,
 	doubaoModels,
-	doubaoDefaultModelId,
-	internationalZAiDefaultModelId,
-	mainlandZAiDefaultModelId,
 	internationalZAiModels,
 	mainlandZAiModels,
 	fireworksModels,
-	fireworksDefaultModelId,
 	featherlessModels,
-	featherlessDefaultModelId,
-	ioIntelligenceDefaultModelId,
 	ioIntelligenceModels,
-	rooDefaultModelId,
-	qwenCodeDefaultModelId,
 	qwenCodeModels,
-	vercelAiGatewayDefaultModelId,
 	BEDROCK_1M_CONTEXT_MODEL_IDS,
-	deepInfraDefaultModelId,
 	isDynamicProvider,
+	getProviderDefaultModelId,
 } from "@roo-code/types"
 
 import type { ModelRecord, RouterModels } from "@roo/api"
@@ -67,6 +38,18 @@ import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
 import { useLmStudioModels } from "./useLmStudioModels"
 import { useOllamaModels } from "./useOllamaModels"
 
+/**
+ * Helper to get a validated model ID for dynamic providers: returns `configuredId` only when it is non-empty and `availableModels[configuredId]` is truthy; otherwise (no configured ID, `availableModels` undefined, or no such entry) returns `defaultModelId`.
+ * NOTE(review): the check is a plain property read, so if `availableModels` is a plain object an inherited key such as "constructor" would also pass — confirm no model ID can collide.
+ */
+function getValidatedModelId(
+	configuredId: string | undefined,
+	availableModels: ModelRecord | undefined,
+	defaultModelId: string,
+): string {
+	return configuredId && availableModels?.[configuredId] ? configuredId : defaultModelId
+}
+
 export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const provider = apiConfiguration?.apiProvider || "anthropic"
 	const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
@@ -90,10 +73,17 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const needLmStudio = typeof lmStudioModelId !== "undefined"
 	const needOllama = typeof ollamaModelId !== "undefined"
 
+	const hasValidRouterData = needRouterModels
+		? routerModels.data &&
+			routerModels.data[provider] !== undefined &&
+			typeof routerModels.data[provider] === "object" &&
+			!routerModels.isLoading
+		: true
+
 	const isReady =
 		(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
 		(!needOllama || typeof ollamaModels.data !== "undefined") &&
-		(!needRouterModels || typeof routerModels.data !== "undefined") &&
+		hasValidRouterData &&
 		(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")
 
 	const { id, info } =
@@ -106,7 +96,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 					lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
 					ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
 				})
-			: { id: anthropicDefaultModelId, info: undefined }
+			: { id: getProviderDefaultModelId(provider), info: undefined }
 
 	return {
 		provider,
@@ -143,10 +133,11 @@ function getSelectedModel({
 	// the `undefined` case are used to show the invalid selection to prevent
 	// users from seeing the default model if their selection is invalid
 	// this gives a better UX than showing the default model
+	const defaultModelId = getProviderDefaultModelId(provider)
 	switch (provider) {
 		case "openrouter": {
-			const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId
-			let info = routerModels.openrouter[id]
+			const id = getValidatedModelId(apiConfiguration.openRouterModelId, routerModels.openrouter, defaultModelId)
+			let info = routerModels.openrouter?.[id]
 			const specificProvider = apiConfiguration.openRouterSpecificProvider
 
 			if (specificProvider && openRouterModelProviders[specificProvider]) {
@@ -161,32 +152,32 @@ function getSelectedModel({
 			return { id, info }
 		}
 		case "requesty": {
-			const id = apiConfiguration.requestyModelId ?? requestyDefaultModelId
-			const info = routerModels.requesty[id]
+			const id = getValidatedModelId(apiConfiguration.requestyModelId, routerModels.requesty, defaultModelId)
+			const info = routerModels.requesty?.[id]
 			return { id, info }
 		}
 		case "glama": {
-			const id = apiConfiguration.glamaModelId ?? glamaDefaultModelId
-			const info = routerModels.glama[id]
+			const id = getValidatedModelId(apiConfiguration.glamaModelId, routerModels.glama, defaultModelId)
+			const info = routerModels.glama?.[id]
 			return { id, info }
 		}
 		case "unbound": {
-			const id = apiConfiguration.unboundModelId ?? unboundDefaultModelId
-			const info = routerModels.unbound[id]
+			const id = getValidatedModelId(apiConfiguration.unboundModelId, routerModels.unbound, defaultModelId)
+			const info = routerModels.unbound?.[id]
 			return { id, info }
 		}
 		case "litellm": {
-			const id = apiConfiguration.litellmModelId ?? litellmDefaultModelId
-			const info = routerModels.litellm[id]
+			const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId)
+			const info = routerModels.litellm?.[id]
 			return { id, info }
 		}
 		case "xai": {
-			const id = apiConfiguration.apiModelId ?? xaiDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = xaiModels[id as keyof typeof xaiModels]
 			return info ? { id, info } : { id, info: undefined }
 		}
 		case "groq": {
-			const id = apiConfiguration.apiModelId ?? groqDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = groqModels[id as keyof typeof groqModels]
 			return { id, info }
 		}
@@ -201,12 +192,12 @@ function getSelectedModel({
 			return { id, info }
 		}
 		case "chutes": {
-			const id = apiConfiguration.apiModelId ?? chutesDefaultModelId
-			const info = routerModels.chutes[id]
+			const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.chutes, defaultModelId)
+			const info = routerModels.chutes?.[id]
 			return { id, info }
 		}
 		case "bedrock": {
-			const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const baseInfo = bedrockModels[id as keyof typeof bedrockModels]
 
 			// Special case for custom ARN.
@@ -230,50 +221,50 @@ function getSelectedModel({
 			return { id, info: baseInfo }
 		}
 		case "vertex": {
-			const id = apiConfiguration.apiModelId ?? vertexDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = vertexModels[id as keyof typeof vertexModels]
 			return { id, info }
 		}
 		case "gemini": {
-			const id = apiConfiguration.apiModelId ?? geminiDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = geminiModels[id as keyof typeof geminiModels]
 			return { id, info }
 		}
 		case "deepseek": {
-			const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = deepSeekModels[id as keyof typeof deepSeekModels]
 			return { id, info }
 		}
 		case "doubao": {
-			const id = apiConfiguration.apiModelId ?? doubaoDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = doubaoModels[id as keyof typeof doubaoModels]
 			return { id, info }
 		}
 		case "moonshot": {
-			const id = apiConfiguration.apiModelId ?? moonshotDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = moonshotModels[id as keyof typeof moonshotModels]
 			return { id, info }
 		}
 		case "minimax": {
-			const id = apiConfiguration.apiModelId ?? minimaxDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = minimaxModels[id as keyof typeof minimaxModels]
 			return { id, info }
 		}
 		case "zai": {
 			const isChina = apiConfiguration.zaiApiLine === "china_coding"
 			const models = isChina ? mainlandZAiModels : internationalZAiModels
-			const defaultModelId = isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId
+			const defaultModelId = getProviderDefaultModelId(provider, { isChina })
 			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = models[id as keyof typeof models]
 			return { id, info }
 		}
 		case "openai-native": {
-			const id = apiConfiguration.apiModelId ?? openAiNativeDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = openAiNativeModels[id as keyof typeof openAiNativeModels]
 			return { id, info }
 		}
 		case "mistral": {
-			const id = apiConfiguration.apiModelId ?? mistralDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = mistralModels[id as keyof typeof mistralModels]
 			return { id, info }
 		}
@@ -307,7 +298,7 @@ function getSelectedModel({
 			}
 		}
 		case "deepinfra": {
-			const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId
+			const id = getValidatedModelId(apiConfiguration.deepInfraModelId, routerModels.deepinfra, defaultModelId)
 			const info = routerModels.deepinfra?.[id]
 			return { id, info }
 		}
@@ -321,49 +312,56 @@ function getSelectedModel({
 		}
 		case "claude-code": {
 			// Claude Code models extend anthropic models but with images and prompt caching disabled
-			const id = apiConfiguration.apiModelId ?? claudeCodeDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = claudeCodeModels[id as keyof typeof claudeCodeModels]
 			return { id, info: { ...openAiModelInfoSaneDefaults, ...info } }
 		}
 		case "cerebras": {
-			const id = apiConfiguration.apiModelId ?? cerebrasDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = cerebrasModels[id as keyof typeof cerebrasModels]
 			return { id, info }
 		}
 		case "sambanova": {
-			const id = apiConfiguration.apiModelId ?? sambaNovaDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = sambaNovaModels[id as keyof typeof sambaNovaModels]
 			return { id, info }
 		}
 		case "fireworks": {
-			const id = apiConfiguration.apiModelId ?? fireworksDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = fireworksModels[id as keyof typeof fireworksModels]
 			return { id, info }
 		}
 		case "featherless": {
-			const id = apiConfiguration.apiModelId ?? featherlessDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = featherlessModels[id as keyof typeof featherlessModels]
 			return { id, info }
 		}
 		case "io-intelligence": {
-			const id = apiConfiguration.ioIntelligenceModelId ?? ioIntelligenceDefaultModelId
+			const id = getValidatedModelId(
+				apiConfiguration.ioIntelligenceModelId,
+				routerModels["io-intelligence"],
+				defaultModelId,
+			)
 			const info =
 				routerModels["io-intelligence"]?.[id] ?? ioIntelligenceModels[id as keyof typeof ioIntelligenceModels]
 			return { id, info }
 		}
 		case "roo": {
-			// Roo is a dynamic provider - models are loaded from API
-			const id = apiConfiguration.apiModelId ?? rooDefaultModelId
-			const info = routerModels.roo[id]
+			const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.roo, defaultModelId)
+			const info = routerModels.roo?.[id]
 			return { id, info }
 		}
 		case "qwen-code": {
-			const id = apiConfiguration.apiModelId ?? qwenCodeDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const info = qwenCodeModels[id as keyof typeof qwenCodeModels]
 			return { id, info }
 		}
 		case "vercel-ai-gateway": {
-			const id = apiConfiguration.vercelAiGatewayModelId ?? vercelAiGatewayDefaultModelId
+			const id = getValidatedModelId(
+				apiConfiguration.vercelAiGatewayModelId,
+				routerModels["vercel-ai-gateway"],
+				defaultModelId,
+			)
 			const info = routerModels["vercel-ai-gateway"]?.[id]
 			return { id, info }
 		}
@@ -372,7 +370,7 @@ function getSelectedModel({
 		// case "fake-ai":
 		default: {
 			provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai"
-			const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
+			const id = apiConfiguration.apiModelId ?? defaultModelId
 			const baseInfo = anthropicModels[id as keyof typeof anthropicModels]
 
 			// Apply 1M context beta tier pricing for Claude Sonnet 4