فهرست منبع

Don't immediately show an model ID error when changing API providers (#2888)

Chris Estreich 8 ماه پیش
والد
کامیت
1543713c32

+ 13 - 8
webview-ui/src/components/chat/ChatView.tsx

@@ -4,6 +4,9 @@ import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRe
 import { useDeepCompareEffect, useEvent, useMount } from "react-use"
 import { Virtuoso, type VirtuosoHandle } from "react-virtuoso"
 import styled from "styled-components"
+import removeMd from "remove-markdown"
+import { Trans } from "react-i18next"
+
 import {
 	ClineAsk,
 	ClineMessage,
@@ -16,11 +19,19 @@ import { findLast } from "@roo/shared/array"
 import { combineApiRequests } from "@roo/shared/combineApiRequests"
 import { combineCommandSequences } from "@roo/shared/combineCommandSequences"
 import { getApiMetrics } from "@roo/shared/getApiMetrics"
+import { AudioType } from "@roo/shared/WebviewMessage"
+import { getAllModes } from "@roo/shared/modes"
+
 import { useExtensionState } from "@src/context/ExtensionStateContext"
 import { vscode } from "@src/utils/vscode"
+import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
+import { validateCommand } from "@src/utils/command-validation"
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+
+import TelemetryBanner from "../common/TelemetryBanner"
 import HistoryPreview from "../history/HistoryPreview"
 import RooHero from "../welcome/RooHero"
-import { normalizeApiConfiguration } from "../settings/ApiOptions"
+
 import Announcement from "./Announcement"
 import BrowserSessionRow from "./BrowserSessionRow"
 import ChatRow from "./ChatRow"
@@ -28,13 +39,7 @@ import ChatTextArea from "./ChatTextArea"
 import TaskHeader from "./TaskHeader"
 import AutoApproveMenu from "./AutoApproveMenu"
 import SystemPromptWarning from "./SystemPromptWarning"
-import { AudioType } from "@roo/shared/WebviewMessage"
-import { validateCommand } from "@src/utils/command-validation"
-import { getAllModes } from "@roo/shared/modes"
-import TelemetryBanner from "../common/TelemetryBanner"
-import { useAppTranslation } from "@/i18n/TranslationContext"
-import removeMd from "remove-markdown"
-import { Trans } from "react-i18next"
+
 interface ChatViewProps {
 	isHidden: boolean
 	showAnnouncement: boolean

+ 5 - 5
webview-ui/src/components/chat/TaskHeader.tsx

@@ -6,14 +6,14 @@ import { CloudUpload, CloudDownload } from "lucide-react"
 
 import { ClineMessage } from "@roo/shared/ExtensionMessage"
 
-import { getMaxTokensForModel } from "@/utils/model-utils"
-import { formatLargeNumber } from "@/utils/format"
-import { cn } from "@/lib/utils"
-import { Button } from "@/components/ui"
+import { getMaxTokensForModel } from "@src/utils/model-utils"
+import { formatLargeNumber } from "@src/utils/format"
+import { cn } from "@src/lib/utils"
+import { Button } from "@src/components/ui"
 import { useExtensionState } from "@src/context/ExtensionStateContext"
+import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
 
 import Thumbnails from "../common/Thumbnails"
-import { normalizeApiConfiguration } from "../settings/ApiOptions"
 
 import { TaskActions } from "./TaskActions"
 import { ContextWindowProgress } from "./ContextWindowProgress"

+ 55 - 138
webview-ui/src/components/settings/ApiOptions.tsx

@@ -11,48 +11,34 @@ import { ExternalLinkIcon } from "@radix-ui/react-icons"
 import {
 	ApiConfiguration,
 	ModelInfo,
-	anthropicDefaultModelId,
-	anthropicModels,
 	azureOpenAiDefaultApiVersion,
-	bedrockDefaultModelId,
-	bedrockModels,
-	deepSeekDefaultModelId,
-	deepSeekModels,
-	geminiDefaultModelId,
-	geminiModels,
 	glamaDefaultModelId,
 	glamaDefaultModelInfo,
 	mistralDefaultModelId,
-	mistralModels,
 	openAiModelInfoSaneDefaults,
-	openAiNativeDefaultModelId,
-	openAiNativeModels,
 	openRouterDefaultModelId,
 	openRouterDefaultModelInfo,
-	vertexDefaultModelId,
-	vertexModels,
 	unboundDefaultModelId,
 	unboundDefaultModelInfo,
 	requestyDefaultModelId,
 	requestyDefaultModelInfo,
-	xaiDefaultModelId,
-	xaiModels,
 	ApiProvider,
-	vscodeLlmModels,
-	vscodeLlmDefaultModelId,
 } from "@roo/shared/api"
 import { ExtensionMessage } from "@roo/shared/ExtensionMessage"
+import { AWS_REGIONS } from "@roo/shared/aws_regions"
 
-import { vscode } from "@/utils/vscode"
-import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@/utils/validate"
+import { vscode } from "@src/utils/vscode"
+import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@src/utils/validate"
+import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
 import {
 	useOpenRouterModelProviders,
 	OPENROUTER_DEFAULT_PROVIDER_NAME,
-} from "@/components/ui/hooks/useOpenRouterModelProviders"
-import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, SelectSeparator, Button } from "@/components/ui"
-import { MODELS_BY_PROVIDER, PROVIDERS, VERTEX_REGIONS, REASONING_MODELS } from "./constants"
-import { AWS_REGIONS } from "@roo/shared/aws_regions"
+} from "@src/components/ui/hooks/useOpenRouterModelProviders"
+import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, Button } from "@src/components/ui"
+
 import { VSCodeButtonLink } from "../common/VSCodeButtonLink"
+
+import { MODELS_BY_PROVIDER, PROVIDERS, VERTEX_REGIONS, REASONING_MODELS } from "./constants"
 import { ModelInfoView } from "./ModelInfoView"
 import { ModelPicker } from "./ModelPicker"
 import { TemperatureControl } from "./TemperatureControl"
@@ -281,6 +267,7 @@ const ApiOptions = ({
 	// Helper function to get the documentation URL and name for the currently selected provider
 	const getSelectedProviderDocUrl = (): { url: string; name: string } | undefined => {
 		const displayName = getProviderDisplayName(selectedProvider)
+
 		if (!displayName) {
 			return undefined
 		}
@@ -294,6 +281,49 @@ const ApiOptions = ({
 		}
 	}
 
+	const onApiProviderChange = useCallback(
+		(value: ApiProvider) => {
+			// It would be much easier to have a single attribute that stores
+			// the modelId, but we have a separate attribute for each of
+			// OpenRouter, Glama, Unbound, and Requesty.
+			// If you switch to one of these providers and the corresponding
+			// modelId is not set then you immediately end up in an error state.
+			// To address that we set the modelId to the default value for th
+			// provider if it's not already set.
+			switch (value) {
+				case "openrouter":
+					if (!apiConfiguration.openRouterModelId) {
+						setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
+					}
+					break
+				case "glama":
+					if (!apiConfiguration.glamaModelId) {
+						setApiConfigurationField("glamaModelId", glamaDefaultModelId)
+					}
+					break
+				case "unbound":
+					if (!apiConfiguration.unboundModelId) {
+						setApiConfigurationField("unboundModelId", unboundDefaultModelId)
+					}
+					break
+				case "requesty":
+					if (!apiConfiguration.requestyModelId) {
+						setApiConfigurationField("requestyModelId", requestyDefaultModelId)
+					}
+					break
+			}
+
+			setApiConfigurationField("apiProvider", value)
+		},
+		[
+			setApiConfigurationField,
+			apiConfiguration.openRouterModelId,
+			apiConfiguration.glamaModelId,
+			apiConfiguration.unboundModelId,
+			apiConfiguration.requestyModelId,
+		],
+	)
+
 	return (
 		<div className="flex flex-col gap-3">
 			<div className="flex flex-col gap-1 relative">
@@ -312,16 +342,12 @@ const ApiOptions = ({
 						</div>
 					)}
 				</div>
-				<Select
-					value={selectedProvider}
-					onValueChange={(value) => setApiConfigurationField("apiProvider", value as ApiProvider)}>
+				<Select value={selectedProvider} onValueChange={(value) => onApiProviderChange(value as ApiProvider)}>
 					<SelectTrigger className="w-full">
 						<SelectValue placeholder={t("settings:common.select")} />
 					</SelectTrigger>
 					<SelectContent>
-						<SelectItem value="openrouter">OpenRouter</SelectItem>
-						<SelectSeparator />
-						{PROVIDERS.filter((p) => p.value !== "openrouter").map(({ value, label }) => (
+						{PROVIDERS.map(({ value, label }) => (
 							<SelectItem key={value} value={value}>
 								{label}
 							</SelectItem>
@@ -1738,113 +1764,4 @@ const ApiOptions = ({
 	)
 }
 
-export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
-	const provider = apiConfiguration?.apiProvider || "anthropic"
-	const modelId = apiConfiguration?.apiModelId
-	const getProviderData = (models: Record<string, ModelInfo>, defaultId: string) => {
-		let selectedModelId: string
-		let selectedModelInfo: ModelInfo
-
-		if (modelId && modelId in models) {
-			selectedModelId = modelId
-			selectedModelInfo = models[modelId]
-		} else {
-			selectedModelId = defaultId
-			selectedModelInfo = models[defaultId]
-		}
-
-		return { selectedProvider: provider, selectedModelId, selectedModelInfo }
-	}
-
-	switch (provider) {
-		case "anthropic":
-			return getProviderData(anthropicModels, anthropicDefaultModelId)
-		case "xai":
-			return getProviderData(xaiModels, xaiDefaultModelId)
-		case "bedrock":
-			// Special case for custom ARN
-			if (modelId === "custom-arn") {
-				return {
-					selectedProvider: provider,
-					selectedModelId: "custom-arn",
-					selectedModelInfo: {
-						maxTokens: 5000,
-						contextWindow: 128_000,
-						supportsPromptCache: false,
-						supportsImages: true,
-					},
-				}
-			}
-			return getProviderData(bedrockModels, bedrockDefaultModelId)
-		case "vertex":
-			return getProviderData(vertexModels, vertexDefaultModelId)
-		case "gemini":
-			return getProviderData(geminiModels, geminiDefaultModelId)
-		case "deepseek":
-			return getProviderData(deepSeekModels, deepSeekDefaultModelId)
-		case "openai-native":
-			return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
-		case "mistral":
-			return getProviderData(mistralModels, mistralDefaultModelId)
-		case "openrouter":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
-				selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
-			}
-		case "glama":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId,
-				selectedModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo,
-			}
-		case "unbound":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.unboundModelId || unboundDefaultModelId,
-				selectedModelInfo: apiConfiguration?.unboundModelInfo || unboundDefaultModelInfo,
-			}
-		case "requesty":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId,
-				selectedModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo,
-			}
-		case "openai":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.openAiModelId || "",
-				selectedModelInfo: apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults,
-			}
-		case "ollama":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.ollamaModelId || "",
-				selectedModelInfo: openAiModelInfoSaneDefaults,
-			}
-		case "lmstudio":
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.lmStudioModelId || "",
-				selectedModelInfo: openAiModelInfoSaneDefaults,
-			}
-		case "vscode-lm":
-			const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId
-			const modelInfo = {
-				...openAiModelInfoSaneDefaults,
-				...vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels],
-				supportsImages: false, // VSCode LM API currently doesn't support images.
-			}
-			return {
-				selectedProvider: provider,
-				selectedModelId: apiConfiguration?.vsCodeLmModelSelector
-					? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`
-					: "",
-				selectedModelInfo: modelInfo,
-			}
-		default:
-			return getProviderData(anthropicModels, anthropicDefaultModelId)
-	}
-}
-
 export default memo(ApiOptions)

+ 5 - 8
webview-ui/src/components/settings/ModelPicker.tsx

@@ -5,8 +5,9 @@ import { ChevronsUpDown, Check, X } from "lucide-react"
 
 import { ProviderSettings, ModelInfo } from "@roo/schemas"
 
-import { useAppTranslation } from "@/i18n/TranslationContext"
-import { cn } from "@/lib/utils"
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
+import { cn } from "@src/lib/utils"
 import {
 	Command,
 	CommandEmpty,
@@ -18,9 +19,8 @@ import {
 	PopoverContent,
 	PopoverTrigger,
 	Button,
-} from "@/components/ui"
+} from "@src/components/ui"
 
-import { normalizeApiConfiguration } from "./ApiOptions"
 import { ThinkingBudget } from "./ThinkingBudget"
 import { ModelInfoView } from "./ModelInfoView"
 
@@ -205,10 +205,7 @@ export const ModelPicker = ({
 						serviceLink: <VSCodeLink href={serviceUrl} className="text-sm" />,
 						defaultModelLink: <VSCodeLink onClick={() => onSelect(defaultModelId)} className="text-sm" />,
 					}}
-					values={{
-						serviceName,
-						defaultModelId,
-					}}
+					values={{ serviceName, defaultModelId }}
 				/>
 			</div>
 		</>

+ 141 - 0
webview-ui/src/utils/normalizeApiConfiguration.ts

@@ -0,0 +1,141 @@
+import {
+	ApiConfiguration,
+	ModelInfo,
+	anthropicDefaultModelId,
+	anthropicModels,
+	bedrockDefaultModelId,
+	bedrockModels,
+	deepSeekDefaultModelId,
+	deepSeekModels,
+	geminiDefaultModelId,
+	geminiModels,
+	glamaDefaultModelId,
+	glamaDefaultModelInfo,
+	mistralDefaultModelId,
+	mistralModels,
+	openAiModelInfoSaneDefaults,
+	openAiNativeDefaultModelId,
+	openAiNativeModels,
+	openRouterDefaultModelId,
+	openRouterDefaultModelInfo,
+	vertexDefaultModelId,
+	vertexModels,
+	unboundDefaultModelId,
+	unboundDefaultModelInfo,
+	requestyDefaultModelId,
+	requestyDefaultModelInfo,
+	xaiDefaultModelId,
+	xaiModels,
+	vscodeLlmModels,
+	vscodeLlmDefaultModelId,
+} from "@roo/shared/api"
+
+export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
+	const provider = apiConfiguration?.apiProvider || "anthropic"
+	const modelId = apiConfiguration?.apiModelId
+
+	const getProviderData = (models: Record<string, ModelInfo>, defaultId: string) => {
+		let selectedModelId: string
+		let selectedModelInfo: ModelInfo
+
+		if (modelId && modelId in models) {
+			selectedModelId = modelId
+			selectedModelInfo = models[modelId]
+		} else {
+			selectedModelId = defaultId
+			selectedModelInfo = models[defaultId]
+		}
+
+		return { selectedProvider: provider, selectedModelId, selectedModelInfo }
+	}
+
+	switch (provider) {
+		case "anthropic":
+			return getProviderData(anthropicModels, anthropicDefaultModelId)
+		case "xai":
+			return getProviderData(xaiModels, xaiDefaultModelId)
+		case "bedrock":
+			// Special case for custom ARN
+			if (modelId === "custom-arn") {
+				return {
+					selectedProvider: provider,
+					selectedModelId: "custom-arn",
+					selectedModelInfo: {
+						maxTokens: 5000,
+						contextWindow: 128_000,
+						supportsPromptCache: false,
+						supportsImages: true,
+					},
+				}
+			}
+			return getProviderData(bedrockModels, bedrockDefaultModelId)
+		case "vertex":
+			return getProviderData(vertexModels, vertexDefaultModelId)
+		case "gemini":
+			return getProviderData(geminiModels, geminiDefaultModelId)
+		case "deepseek":
+			return getProviderData(deepSeekModels, deepSeekDefaultModelId)
+		case "openai-native":
+			return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
+		case "mistral":
+			return getProviderData(mistralModels, mistralDefaultModelId)
+		case "openrouter":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
+				selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
+			}
+		case "glama":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId,
+				selectedModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo,
+			}
+		case "unbound":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.unboundModelId || unboundDefaultModelId,
+				selectedModelInfo: apiConfiguration?.unboundModelInfo || unboundDefaultModelInfo,
+			}
+		case "requesty":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId,
+				selectedModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo,
+			}
+		case "openai":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.openAiModelId || "",
+				selectedModelInfo: apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults,
+			}
+		case "ollama":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.ollamaModelId || "",
+				selectedModelInfo: openAiModelInfoSaneDefaults,
+			}
+		case "lmstudio":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.lmStudioModelId || "",
+				selectedModelInfo: openAiModelInfoSaneDefaults,
+			}
+		case "vscode-lm":
+			const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId
+			const modelInfo = {
+				...openAiModelInfoSaneDefaults,
+				...vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels],
+				supportsImages: false, // VSCode LM API currently doesn't support images.
+			}
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.vsCodeLmModelSelector
+					? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`
+					: "",
+				selectedModelInfo: modelInfo,
+			}
+		default:
+			return getProviderData(anthropicModels, anthropicDefaultModelId)
+	}
+}