Przeglądaj źródła

fix: ambiguous model id error (#4306)

Dicha Zelianivan Arkana 6 miesięcy temu
rodzic
commit
b8505fe175

+ 1 - 1
webview-ui/src/components/settings/ApiErrorMessage.tsx

@@ -6,7 +6,7 @@ interface ApiErrorMessageProps {
 }
 
 export const ApiErrorMessage = ({ errorMessage, children }: ApiErrorMessageProps) => (
-	<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm">
+	<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm" data-testid="api-error-message">
 		<div className="flex flex-row items-center gap-1">
 			<div className="codicon codicon-close" />
 			<div>{errorMessage}</div>

+ 84 - 38
webview-ui/src/components/settings/ApiOptions.tsx

@@ -11,10 +11,20 @@ import {
 	glamaDefaultModelId,
 	unboundDefaultModelId,
 	litellmDefaultModelId,
+	openAiNativeDefaultModelId,
+	anthropicDefaultModelId,
+	geminiDefaultModelId,
+	deepSeekDefaultModelId,
+	mistralDefaultModelId,
+	xaiDefaultModelId,
+	groqDefaultModelId,
+	chutesDefaultModelId,
+	bedrockDefaultModelId,
+	vertexDefaultModelId,
 } from "@roo-code/types"
 
 import { vscode } from "@src/utils/vscode"
-import { validateApiConfiguration } from "@src/utils/validate"
+import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate"
 import { useAppTranslation } from "@src/i18n/TranslationContext"
 import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
 import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel"
@@ -176,8 +186,11 @@ const ApiOptions = ({
 	)
 
 	useEffect(() => {
-		const apiValidationResult = validateApiConfiguration(apiConfiguration, routerModels, organizationAllowList)
-
+		const apiValidationResult = validateApiConfigurationExcludingModelErrors(
+			apiConfiguration,
+			routerModels,
+			organizationAllowList,
+		)
 		setErrorMessage(apiValidationResult)
 	}, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage])
 
@@ -187,16 +200,20 @@ const ApiOptions = ({
 
 		const filteredModels = filterModels(models, selectedProvider, organizationAllowList)
 
-		return filteredModels
+		const modelOptions = filteredModels
 			? Object.keys(filteredModels).map((modelId) => ({
 					value: modelId,
 					label: modelId,
 				}))
 			: []
+
+		return modelOptions
 	}, [selectedProvider, organizationAllowList])
 
 	const onProviderChange = useCallback(
 		(value: ProviderName) => {
+			setApiConfigurationField("apiProvider", value)
+
 			// It would be much easier to have a single attribute that stores
 			// the modelId, but we have a separate attribute for each of
 			// OpenRouter, Glama, Unbound, and Requesty.
@@ -204,46 +221,69 @@ const ApiOptions = ({
 			// modelId is not set then you immediately end up in an error state.
 			// To address that we set the modelId to the default value for th
 			// provider if it's not already set.
-			switch (value) {
-				case "openrouter":
-					if (!apiConfiguration.openRouterModelId) {
-						setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
-					}
-					break
-				case "glama":
-					if (!apiConfiguration.glamaModelId) {
-						setApiConfigurationField("glamaModelId", glamaDefaultModelId)
-					}
-					break
-				case "unbound":
-					if (!apiConfiguration.unboundModelId) {
-						setApiConfigurationField("unboundModelId", unboundDefaultModelId)
-					}
-					break
-				case "requesty":
-					if (!apiConfiguration.requestyModelId) {
-						setApiConfigurationField("requestyModelId", requestyDefaultModelId)
-					}
-					break
-				case "litellm":
-					if (!apiConfiguration.litellmModelId) {
-						setApiConfigurationField("litellmModelId", litellmDefaultModelId)
+			const validateAndResetModel = (
+				modelId: string | undefined,
+				field: keyof ProviderSettings,
+				defaultValue?: string,
+			) => {
+				// in case we haven't set a default value for a provider
+				if (!defaultValue) return
+
+				// only set default if no model is set, but don't reset invalid models
+				// let users see and decide what to do with invalid model selections
+				const shouldSetDefault = !modelId
+
+				if (shouldSetDefault) {
+					setApiConfigurationField(field, defaultValue)
+				}
+			}
+
+			// Define a mapping object that associates each provider with its model configuration
+			const PROVIDER_MODEL_CONFIG: Partial<
+				Record<
+					ProviderName,
+					{
+						field: keyof ProviderSettings
+						default?: string
 					}
-					break
+				>
+			> = {
+				openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
+				glama: { field: "glamaModelId", default: glamaDefaultModelId },
+				unbound: { field: "unboundModelId", default: unboundDefaultModelId },
+				requesty: { field: "requestyModelId", default: requestyDefaultModelId },
+				litellm: { field: "litellmModelId", default: litellmDefaultModelId },
+				anthropic: { field: "apiModelId", default: anthropicDefaultModelId },
+				"openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId },
+				gemini: { field: "apiModelId", default: geminiDefaultModelId },
+				deepseek: { field: "apiModelId", default: deepSeekDefaultModelId },
+				mistral: { field: "apiModelId", default: mistralDefaultModelId },
+				xai: { field: "apiModelId", default: xaiDefaultModelId },
+				groq: { field: "apiModelId", default: groqDefaultModelId },
+				chutes: { field: "apiModelId", default: chutesDefaultModelId },
+				bedrock: { field: "apiModelId", default: bedrockDefaultModelId },
+				vertex: { field: "apiModelId", default: vertexDefaultModelId },
+				openai: { field: "openAiModelId" },
+				ollama: { field: "ollamaModelId" },
+				lmstudio: { field: "lmStudioModelId" },
 			}
 
-			setApiConfigurationField("apiProvider", value)
+			const config = PROVIDER_MODEL_CONFIG[value]
+			if (config) {
+				validateAndResetModel(
+					apiConfiguration[config.field] as string | undefined,
+					config.field,
+					config.default,
+				)
+			}
 		},
-		[
-			setApiConfigurationField,
-			apiConfiguration.openRouterModelId,
-			apiConfiguration.glamaModelId,
-			apiConfiguration.unboundModelId,
-			apiConfiguration.requestyModelId,
-			apiConfiguration.litellmModelId,
-		],
+		[setApiConfigurationField, apiConfiguration],
 	)
 
+	const modelValidationError = useMemo(() => {
+		return getModelValidationError(apiConfiguration, routerModels, organizationAllowList)
+	}, [apiConfiguration, routerModels, organizationAllowList])
+
 	const docs = useMemo(() => {
 		const provider = PROVIDERS.find(({ value }) => value === selectedProvider)
 		const name = provider?.label
@@ -303,6 +343,7 @@ const ApiOptions = ({
 					uriScheme={uriScheme}
 					fromWelcomeView={fromWelcomeView}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 
@@ -313,6 +354,7 @@ const ApiOptions = ({
 					routerModels={routerModels}
 					refetchRouterModels={refetchRouterModels}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 
@@ -323,6 +365,7 @@ const ApiOptions = ({
 					routerModels={routerModels}
 					uriScheme={uriScheme}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 
@@ -332,6 +375,7 @@ const ApiOptions = ({
 					setApiConfigurationField={setApiConfigurationField}
 					routerModels={routerModels}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 
@@ -368,6 +412,7 @@ const ApiOptions = ({
 					apiConfiguration={apiConfiguration}
 					setApiConfigurationField={setApiConfigurationField}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 
@@ -404,6 +449,7 @@ const ApiOptions = ({
 					apiConfiguration={apiConfiguration}
 					setApiConfigurationField={setApiConfigurationField}
 					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
 				/>
 			)}
 

+ 11 - 2
webview-ui/src/components/settings/ModelPicker.tsx

@@ -23,6 +23,7 @@ import {
 } from "@src/components/ui"
 
 import { ModelInfoView } from "./ModelInfoView"
+import { ApiErrorMessage } from "./ApiErrorMessage"
 
 type ModelIdKey = keyof Pick<
 	ProviderSettings,
@@ -38,6 +39,7 @@ interface ModelPickerProps {
 	apiConfiguration: ProviderSettings
 	setApiConfigurationField: <K extends keyof ProviderSettings>(field: K, value: ProviderSettings[K]) => void
 	organizationAllowList: OrganizationAllowList
+	errorMessage?: string
 }
 
 export const ModelPicker = ({
@@ -49,6 +51,7 @@ export const ModelPicker = ({
 	apiConfiguration,
 	setApiConfigurationField,
 	organizationAllowList,
+	errorMessage,
 }: ModelPickerProps) => {
 	const { t } = useAppTranslation()
 
@@ -119,7 +122,8 @@ export const ModelPicker = ({
 							variant="combobox"
 							role="combobox"
 							aria-expanded={open}
-							className="w-full justify-between">
+							className="w-full justify-between"
+							data-testid="model-picker-button">
 							<div>{selectedModelId ?? t("settings:common.select")}</div>
 							<ChevronsUpDown className="opacity-50" />
 						</Button>
@@ -154,7 +158,11 @@ export const ModelPicker = ({
 								</CommandEmpty>
 								<CommandGroup>
 									{modelIds.map((model) => (
-										<CommandItem key={model} value={model} onSelect={onSelect}>
+										<CommandItem
+											key={model}
+											value={model}
+											onSelect={onSelect}
+											data-testid={`model-option-${model}`}>
 											{model}
 											<Check
 												className={cn(
@@ -177,6 +185,7 @@ export const ModelPicker = ({
 					</PopoverContent>
 				</Popover>
 			</div>
+			{errorMessage && <ApiErrorMessage errorMessage={errorMessage} />}
 			{selectedModelId && selectedModelInfo && (
 				<ModelInfoView
 					apiProvider={apiConfiguration.apiProvider}

+ 110 - 3
webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx

@@ -73,7 +73,7 @@ describe("ModelPicker", () => {
 
 		await act(async () => {
 			// Open the popover by clicking the button.
-			const button = screen.getByRole("combobox")
+			const button = screen.getByTestId("model-picker-button")
 			fireEvent.click(button)
 		})
 
@@ -91,7 +91,7 @@ describe("ModelPicker", () => {
 		// Need to find and click the CommandItem to trigger onSelect
 		await act(async () => {
 			// Find the CommandItem for model2 and click it
-			const modelItem = screen.getByText("model2")
+			const modelItem = screen.getByTestId("model-option-model2")
 			fireEvent.click(modelItem)
 		})
 
@@ -104,7 +104,7 @@ describe("ModelPicker", () => {
 
 		await act(async () => {
 			// Open the popover by clicking the button.
-			const button = screen.getByRole("combobox")
+			const button = screen.getByTestId("model-picker-button")
 			fireEvent.click(button)
 		})
 
@@ -136,4 +136,111 @@ describe("ModelPicker", () => {
 		// Verify the API config was updated with the custom model ID
 		expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, customModelId)
 	})
+
+	describe("Error Message Display", () => {
+		it("displays error message when errorMessage prop is provided", async () => {
+			const errorMessage = "Model not available for your organization"
+			const propsWithError = {
+				...defaultProps,
+				errorMessage,
+			}
+
+			await act(async () => {
+				render(
+					<QueryClientProvider client={queryClient}>
+						<ModelPicker {...propsWithError} />
+					</QueryClientProvider>,
+				)
+			})
+
+			// Check that the error message is displayed
+			expect(screen.getByTestId("api-error-message")).toBeInTheDocument()
+			expect(screen.getByText(errorMessage)).toBeInTheDocument()
+		})
+
+		it("does not display error message when errorMessage prop is undefined", async () => {
+			await act(async () => renderModelPicker())
+
+			// Check that no error message is displayed
+			expect(screen.queryByTestId("api-error-message")).not.toBeInTheDocument()
+		})
+
+		it("displays error message below the model selector", async () => {
+			const errorMessage = "Invalid model selected"
+			const propsWithError = {
+				...defaultProps,
+				errorMessage,
+			}
+
+			await act(async () => {
+				render(
+					<QueryClientProvider client={queryClient}>
+						<ModelPicker {...propsWithError} />
+					</QueryClientProvider>,
+				)
+			})
+
+			// Check that both the model selector and error message are present
+			const modelSelector = screen.getByTestId("model-picker-button")
+			const errorContainer = screen.getByTestId("api-error-message")
+			const errorElement = screen.getByText(errorMessage)
+
+			expect(modelSelector).toBeInTheDocument()
+			expect(errorContainer).toBeInTheDocument()
+			expect(errorElement).toBeInTheDocument()
+			expect(errorElement).toBeVisible()
+		})
+
+		it("updates error message when errorMessage prop changes", async () => {
+			const initialError = "Initial error"
+			const updatedError = "Updated error"
+
+			const { rerender } = render(
+				<QueryClientProvider client={queryClient}>
+					<ModelPicker {...defaultProps} errorMessage={initialError} />
+				</QueryClientProvider>,
+			)
+
+			// Check initial error is displayed
+			expect(screen.getByTestId("api-error-message")).toBeInTheDocument()
+			expect(screen.getByText(initialError)).toBeInTheDocument()
+
+			// Update the error message
+			rerender(
+				<QueryClientProvider client={queryClient}>
+					<ModelPicker {...defaultProps} errorMessage={updatedError} />
+				</QueryClientProvider>,
+			)
+
+			// Check that the error message has been updated
+			expect(screen.getByTestId("api-error-message")).toBeInTheDocument()
+			expect(screen.queryByText(initialError)).not.toBeInTheDocument()
+			expect(screen.getByText(updatedError)).toBeInTheDocument()
+		})
+
+		it("removes error message when errorMessage prop becomes undefined", async () => {
+			const errorMessage = "Temporary error"
+
+			const { rerender } = render(
+				<QueryClientProvider client={queryClient}>
+					<ModelPicker {...defaultProps} errorMessage={errorMessage} />
+				</QueryClientProvider>,
+			)
+
+			// Check error is initially displayed
+			expect(screen.getByTestId("api-error-message")).toBeInTheDocument()
+			expect(screen.getByText(errorMessage)).toBeInTheDocument()
+
+			// Remove the error message
+			rerender(
+				<QueryClientProvider client={queryClient}>
+					<ModelPicker {...defaultProps} errorMessage={undefined} />
+				</QueryClientProvider>,
+			)
+
+			// Check that the error message has been removed
+			expect(screen.queryByTestId("api-error-message")).not.toBeInTheDocument()
+			expect(screen.queryByText(errorMessage)).not.toBeInTheDocument()
+		})
+	})
 })

+ 3 - 0
webview-ui/src/components/settings/providers/Glama.tsx

@@ -18,6 +18,7 @@ type GlamaProps = {
 	routerModels?: RouterModels
 	uriScheme?: string
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
 export const Glama = ({
@@ -26,6 +27,7 @@ export const Glama = ({
 	routerModels,
 	uriScheme,
 	organizationAllowList,
+	modelValidationError,
 }: GlamaProps) => {
 	const { t } = useAppTranslation()
 
@@ -67,6 +69,7 @@ export const Glama = ({
 				serviceName="Glama"
 				serviceUrl="https://glama.ai/models"
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 		</>
 	)

+ 8 - 1
webview-ui/src/components/settings/providers/LiteLLM.tsx

@@ -18,9 +18,15 @@ type LiteLLMProps = {
 	apiConfiguration: ProviderSettings
 	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
-export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, organizationAllowList }: LiteLLMProps) => {
+export const LiteLLM = ({
+	apiConfiguration,
+	setApiConfigurationField,
+	organizationAllowList,
+	modelValidationError,
+}: LiteLLMProps) => {
 	const { t } = useAppTranslation()
 	const { routerModels } = useExtensionState()
 	const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle")
@@ -143,6 +149,7 @@ export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, organizati
 				serviceUrl="https://docs.litellm.ai/"
 				setApiConfigurationField={setApiConfigurationField}
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 		</>
 	)

+ 3 - 0
webview-ui/src/components/settings/providers/OpenAICompatible.tsx

@@ -27,12 +27,14 @@ type OpenAICompatibleProps = {
 	apiConfiguration: ProviderSettings
 	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
 export const OpenAICompatible = ({
 	apiConfiguration,
 	setApiConfigurationField,
 	organizationAllowList,
+	modelValidationError,
 }: OpenAICompatibleProps) => {
 	const { t } = useAppTranslation()
 
@@ -144,6 +146,7 @@ export const OpenAICompatible = ({
 				serviceName="OpenAI"
 				serviceUrl="https://platform.openai.com"
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 			<R1FormatSetting
 				onChange={handleInputChange("openAiR1FormatEnabled", noTransform)}

+ 3 - 0
webview-ui/src/components/settings/providers/OpenRouter.tsx

@@ -30,6 +30,7 @@ type OpenRouterProps = {
 	uriScheme: string | undefined
 	fromWelcomeView?: boolean
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
 export const OpenRouter = ({
@@ -40,6 +41,7 @@ export const OpenRouter = ({
 	uriScheme,
 	fromWelcomeView,
 	organizationAllowList,
+	modelValidationError,
 }: OpenRouterProps) => {
 	const { t } = useAppTranslation()
 
@@ -135,6 +137,7 @@ export const OpenRouter = ({
 				serviceName="OpenRouter"
 				serviceUrl="https://openrouter.ai/models"
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 			{openRouterModelProviders && Object.keys(openRouterModelProviders).length > 0 && (
 				<div>

+ 3 - 0
webview-ui/src/components/settings/providers/Requesty.tsx

@@ -20,6 +20,7 @@ type RequestyProps = {
 	routerModels?: RouterModels
 	refetchRouterModels: () => void
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
 export const Requesty = ({
@@ -28,6 +29,7 @@ export const Requesty = ({
 	routerModels,
 	refetchRouterModels,
 	organizationAllowList,
+	modelValidationError,
 }: RequestyProps) => {
 	const { t } = useAppTranslation()
 
@@ -96,6 +98,7 @@ export const Requesty = ({
 				serviceName="Requesty"
 				serviceUrl="https://requesty.ai"
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 		</>
 	)

+ 3 - 0
webview-ui/src/components/settings/providers/Unbound.tsx

@@ -19,6 +19,7 @@ type UnboundProps = {
 	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
 	routerModels?: RouterModels
 	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
 }
 
 export const Unbound = ({
@@ -26,6 +27,7 @@ export const Unbound = ({
 	setApiConfigurationField,
 	routerModels,
 	organizationAllowList,
+	modelValidationError,
 }: UnboundProps) => {
 	const { t } = useAppTranslation()
 	const [didRefetch, setDidRefetch] = useState<boolean>()
@@ -176,6 +178,7 @@ export const Unbound = ({
 				serviceUrl="https://api.getunbound.ai/models"
 				setApiConfigurationField={setApiConfigurationField}
 				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
 			/>
 		</>
 	)

+ 2 - 12
webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.test.ts

@@ -284,18 +284,8 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper })
 
-			expect(result.current.id).toBe("anthropic/claude-sonnet-4")
-			expect(result.current.info).toEqual({
-				maxTokens: 8192,
-				contextWindow: 200_000,
-				supportsImages: true,
-				supportsComputerUse: true,
-				supportsPromptCache: true,
-				inputPrice: 3.0,
-				outputPrice: 15.0,
-				cacheWritesPrice: 3.75,
-				cacheReadsPrice: 0.3,
-			})
+			expect(result.current.id).toBe("non-existent-model")
+			expect(result.current.info).toBeUndefined()
 		})
 	})
 

+ 19 - 26
webview-ui/src/components/ui/hooks/useSelectedModel.ts

@@ -75,7 +75,10 @@ function getSelectedModel({
 	apiConfiguration: ProviderSettings
 	routerModels: RouterModels
 	openRouterModelProviders: Record<string, ModelInfo>
-}): { id: string; info: ModelInfo } {
+}): { id: string; info: ModelInfo | undefined } {
+	// the `undefined` case are used to show the invalid selection to prevent
+	// users from seeing the default model if their selection is invalid
+	// this gives a better UX than showing the default model
 	switch (provider) {
 		case "openrouter": {
 			const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId
@@ -91,50 +94,42 @@ function getSelectedModel({
 					: openRouterModelProviders[specificProvider]
 			}
 
-			return info
-				? { id, info }
-				: { id: openRouterDefaultModelId, info: routerModels.openrouter[openRouterDefaultModelId] }
+			return { id, info }
 		}
 		case "requesty": {
 			const id = apiConfiguration.requestyModelId ?? requestyDefaultModelId
 			const info = routerModels.requesty[id]
-			return info
-				? { id, info }
-				: { id: requestyDefaultModelId, info: routerModels.requesty[requestyDefaultModelId] }
+			return { id, info }
 		}
 		case "glama": {
 			const id = apiConfiguration.glamaModelId ?? glamaDefaultModelId
 			const info = routerModels.glama[id]
-			return info ? { id, info } : { id: glamaDefaultModelId, info: routerModels.glama[glamaDefaultModelId] }
+			return { id, info }
 		}
 		case "unbound": {
 			const id = apiConfiguration.unboundModelId ?? unboundDefaultModelId
 			const info = routerModels.unbound[id]
-			return info
-				? { id, info }
-				: { id: unboundDefaultModelId, info: routerModels.unbound[unboundDefaultModelId] }
+			return { id, info }
 		}
 		case "litellm": {
 			const id = apiConfiguration.litellmModelId ?? litellmDefaultModelId
 			const info = routerModels.litellm[id]
-			return info
-				? { id, info }
-				: { id: litellmDefaultModelId, info: routerModels.litellm[litellmDefaultModelId] }
+			return { id, info }
 		}
 		case "xai": {
 			const id = apiConfiguration.apiModelId ?? xaiDefaultModelId
 			const info = xaiModels[id as keyof typeof xaiModels]
-			return info ? { id, info } : { id: xaiDefaultModelId, info: xaiModels[xaiDefaultModelId] }
+			return info ? { id, info } : { id, info: undefined }
 		}
 		case "groq": {
 			const id = apiConfiguration.apiModelId ?? groqDefaultModelId
 			const info = groqModels[id as keyof typeof groqModels]
-			return info ? { id, info } : { id: groqDefaultModelId, info: groqModels[groqDefaultModelId] }
+			return { id, info }
 		}
 		case "chutes": {
 			const id = apiConfiguration.apiModelId ?? chutesDefaultModelId
 			const info = chutesModels[id as keyof typeof chutesModels]
-			return info ? { id, info } : { id: chutesDefaultModelId, info: chutesModels[chutesDefaultModelId] }
+			return { id, info }
 		}
 		case "bedrock": {
 			const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId
@@ -148,34 +143,32 @@ function getSelectedModel({
 				}
 			}
 
-			return info ? { id, info } : { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
+			return { id, info }
 		}
 		case "vertex": {
 			const id = apiConfiguration.apiModelId ?? vertexDefaultModelId
 			const info = vertexModels[id as keyof typeof vertexModels]
-			return info ? { id, info } : { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
+			return { id, info }
 		}
 		case "gemini": {
 			const id = apiConfiguration.apiModelId ?? geminiDefaultModelId
 			const info = geminiModels[id as keyof typeof geminiModels]
-			return info ? { id, info } : { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
+			return { id, info }
 		}
 		case "deepseek": {
 			const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId
 			const info = deepSeekModels[id as keyof typeof deepSeekModels]
-			return info ? { id, info } : { id: deepSeekDefaultModelId, info: deepSeekModels[deepSeekDefaultModelId] }
+			return { id, info }
 		}
 		case "openai-native": {
 			const id = apiConfiguration.apiModelId ?? openAiNativeDefaultModelId
 			const info = openAiNativeModels[id as keyof typeof openAiNativeModels]
-			return info
-				? { id, info }
-				: { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
+			return { id, info }
 		}
 		case "mistral": {
 			const id = apiConfiguration.apiModelId ?? mistralDefaultModelId
 			const info = mistralModels[id as keyof typeof mistralModels]
-			return info ? { id, info } : { id: mistralDefaultModelId, info: mistralModels[mistralDefaultModelId] }
+			return { id, info }
 		}
 		case "openai": {
 			const id = apiConfiguration.openAiModelId ?? ""
@@ -206,7 +199,7 @@ function getSelectedModel({
 		default: {
 			const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
 			const info = anthropicModels[id as keyof typeof anthropicModels]
-			return info ? { id, info } : { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
+			return { id, info }
 		}
 	}
 }

+ 187 - 0
webview-ui/src/utils/__tests__/validate.test.ts

@@ -0,0 +1,187 @@
+import { ProviderSettings, OrganizationAllowList } from "@roo-code/types"
+import { RouterModels } from "@roo/api"
+
+import { getModelValidationError, validateApiConfigurationExcludingModelErrors } from "../validate"
+
+describe("Model Validation Functions", () => {
+	const mockRouterModels: RouterModels = {
+		openrouter: {
+			"valid-model": {
+				maxTokens: 8192,
+				contextWindow: 200000,
+				supportsImages: true,
+				supportsPromptCache: false,
+				inputPrice: 3.0,
+				outputPrice: 15.0,
+			},
+			"another-valid-model": {
+				maxTokens: 4096,
+				contextWindow: 100000,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 1.0,
+				outputPrice: 5.0,
+			},
+		},
+		glama: {
+			"valid-model": {
+				maxTokens: 8192,
+				contextWindow: 200000,
+				supportsImages: true,
+				supportsPromptCache: false,
+				inputPrice: 3.0,
+				outputPrice: 15.0,
+			},
+		},
+		requesty: {},
+		unbound: {},
+		litellm: {},
+	}
+
+	const allowAllOrganization: OrganizationAllowList = {
+		allowAll: true,
+		providers: {},
+	}
+
+	const restrictiveOrganization: OrganizationAllowList = {
+		allowAll: false,
+		providers: {
+			openrouter: {
+				allowAll: false,
+				models: ["valid-model"],
+			},
+		},
+	}
+
+	describe("getModelValidationError", () => {
+		it("returns undefined for valid OpenRouter model", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterModelId: "valid-model",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBeUndefined()
+		})
+
+		it("returns error for invalid OpenRouter model", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterModelId: "invalid-model",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBe("validation.modelAvailability")
+		})
+
+		it("returns error for model not allowed by organization", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterModelId: "another-valid-model",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, restrictiveOrganization)
+			expect(result).toContain("model")
+		})
+
+		it("returns undefined for valid Glama model", () => {
+			const config: ProviderSettings = {
+				apiProvider: "glama",
+				glamaModelId: "valid-model",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBeUndefined()
+		})
+
+		it("returns error for invalid Glama model", () => {
+			const config: ProviderSettings = {
+				apiProvider: "glama",
+				glamaModelId: "invalid-model",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBeUndefined()
+		})
+
+		it("returns undefined for OpenAI models when no router models provided", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openai",
+				openAiModelId: "gpt-4",
+			}
+
+			const result = getModelValidationError(config, undefined, allowAllOrganization)
+			expect(result).toBeUndefined()
+		})
+
+		it("handles empty model IDs gracefully", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterModelId: "",
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBe("validation.modelId")
+		})
+
+		it("handles undefined model IDs gracefully", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				// openRouterModelId is undefined
+			}
+
+			const result = getModelValidationError(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBe("validation.modelId")
+		})
+	})
+
+	describe("validateApiConfigurationExcludingModelErrors", () => {
+		it("returns undefined when configuration is valid", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterApiKey: "valid-key",
+				openRouterModelId: "valid-model",
+			}
+
+			const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBeUndefined()
+		})
+
+		it("returns error for missing API key", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterModelId: "valid-model",
+				// Missing openRouterApiKey
+			}
+
+			const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBe("validation.apiKey")
+		})
+
+		it("excludes model-specific errors", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterApiKey: "valid-key",
+				openRouterModelId: "invalid-model", // This should be ignored
+			}
+
+			const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization)
+			expect(result).toBeUndefined() // Should not return model validation error
+		})
+
+		it("excludes model-specific organization errors", () => {
+			const config: ProviderSettings = {
+				apiProvider: "openrouter",
+				openRouterApiKey: "valid-key",
+				openRouterModelId: "another-valid-model", // Not allowed by restrictive org
+			}
+
+			const result = validateApiConfigurationExcludingModelErrors(
+				config,
+				mockRouterModels,
+				restrictiveOrganization,
+			)
+			expect(result).toBeUndefined() // Should exclude model-specific org errors
+		})
+	})
+})

+ 72 - 9
webview-ui/src/utils/validate.ts

@@ -14,12 +14,12 @@ export function validateApiConfiguration(
 		return keysAndIdsPresentErrorMessage
 	}
 
-	const organizationAllowListErrorMessage = validateProviderAgainstOrganizationSettings(
+	const organizationAllowListError = validateProviderAgainstOrganizationSettings(
 		apiConfiguration,
 		organizationAllowList,
 	)
-	if (organizationAllowListErrorMessage) {
-		return organizationAllowListErrorMessage
+	if (organizationAllowListError) {
+		return organizationAllowListError.message
 	}
 
 	return validateModelId(apiConfiguration, routerModels)
@@ -107,17 +107,25 @@ function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): stri
 	return undefined
 }
 
+type ValidationError = {
+	message: string
+	code: 'PROVIDER_NOT_ALLOWED' | 'MODEL_NOT_ALLOWED'
+}
+
 function validateProviderAgainstOrganizationSettings(
 	apiConfiguration: ProviderSettings,
 	organizationAllowList?: OrganizationAllowList,
-): string | undefined {
+): ValidationError | undefined {
 	if (organizationAllowList && !organizationAllowList.allowAll) {
 		const provider = apiConfiguration.apiProvider
 		if (!provider) return undefined
 
 		const providerConfig = organizationAllowList.providers[provider]
 		if (!providerConfig) {
-			return i18next.t("settings:validation.providerNotAllowed", { provider })
+			return {
+				message: i18next.t("settings:validation.providerNotAllowed", { provider }),
+				code: 'PROVIDER_NOT_ALLOWED'
+			}
 		}
 
 		if (!providerConfig.allowAll) {
@@ -125,10 +133,13 @@ function validateProviderAgainstOrganizationSettings(
 			const allowedModels = providerConfig.models || []
 
 			if (modelId && !allowedModels.includes(modelId)) {
-				return i18next.t("settings:validation.modelNotAllowed", {
-					model: modelId,
-					provider,
-				})
+				return {
+					message: i18next.t("settings:validation.modelNotAllowed", {
+						model: modelId,
+						provider,
+					}),
+					code: 'MODEL_NOT_ALLOWED'
+				}
 			}
 		}
 	}
@@ -233,3 +244,55 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels
 
 	return undefined
 }
+
+/**
+ * Extracts model-specific validation errors from the API configuration
+ * This is used to show model errors specifically in the model selector components
+ */
+export function getModelValidationError(
+	apiConfiguration: ProviderSettings,
+	routerModels?: RouterModels,
+	organizationAllowList?: OrganizationAllowList,
+): string | undefined {
+	const modelId = getModelIdForProvider(apiConfiguration, apiConfiguration.apiProvider || "")
+	const configWithModelId = {
+		...apiConfiguration,
+		apiModelId: modelId || "",
+	}
+
+	const orgError = validateProviderAgainstOrganizationSettings(configWithModelId, organizationAllowList)
+	if (orgError && orgError.code === 'MODEL_NOT_ALLOWED') {
+		return orgError.message
+	}
+
+	return validateModelId(configWithModelId, routerModels)
+}
+
+/**
+ * Validates API configuration but excludes model-specific errors
+ * This is used for the general API error display to prevent duplication
+ * when model errors are shown in the model selector
+ */
+export function validateApiConfigurationExcludingModelErrors(
+	apiConfiguration: ProviderSettings,
+	_routerModels?: RouterModels, // keeping this for compatibility with the old function
+	organizationAllowList?: OrganizationAllowList,
+): string | undefined {
+	const keysAndIdsPresentErrorMessage = validateModelsAndKeysProvided(apiConfiguration)
+	if (keysAndIdsPresentErrorMessage) {
+		return keysAndIdsPresentErrorMessage
+	}
+
+	const organizationAllowListError = validateProviderAgainstOrganizationSettings(
+		apiConfiguration,
+		organizationAllowList,
+	)
+
+	// only return organization errors if they're not model-specific
+	if (organizationAllowListError && organizationAllowListError.code === 'PROVIDER_NOT_ALLOWED') {
+		return organizationAllowListError.message
+	}
+
+	// skip model validation errors as they'll be shown in the model selector
+	return undefined
+}