|
|
@@ -11,10 +11,20 @@ import {
|
|
|
glamaDefaultModelId,
|
|
|
unboundDefaultModelId,
|
|
|
litellmDefaultModelId,
|
|
|
+ openAiNativeDefaultModelId,
|
|
|
+ anthropicDefaultModelId,
|
|
|
+ geminiDefaultModelId,
|
|
|
+ deepSeekDefaultModelId,
|
|
|
+ mistralDefaultModelId,
|
|
|
+ xaiDefaultModelId,
|
|
|
+ groqDefaultModelId,
|
|
|
+ chutesDefaultModelId,
|
|
|
+ bedrockDefaultModelId,
|
|
|
+ vertexDefaultModelId,
|
|
|
} from "@roo-code/types"
|
|
|
|
|
|
import { vscode } from "@src/utils/vscode"
|
|
|
-import { validateApiConfiguration } from "@src/utils/validate"
|
|
|
+import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate"
|
|
|
import { useAppTranslation } from "@src/i18n/TranslationContext"
|
|
|
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
|
|
|
import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel"
|
|
|
@@ -176,8 +186,11 @@ const ApiOptions = ({
|
|
|
)
|
|
|
|
|
|
useEffect(() => {
|
|
|
- const apiValidationResult = validateApiConfiguration(apiConfiguration, routerModels, organizationAllowList)
|
|
|
-
|
|
|
+ const apiValidationResult = validateApiConfigurationExcludingModelErrors(
|
|
|
+ apiConfiguration,
|
|
|
+ routerModels,
|
|
|
+ organizationAllowList,
|
|
|
+ )
|
|
|
setErrorMessage(apiValidationResult)
|
|
|
}, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage])
|
|
|
|
|
|
@@ -187,16 +200,20 @@ const ApiOptions = ({
|
|
|
|
|
|
const filteredModels = filterModels(models, selectedProvider, organizationAllowList)
|
|
|
|
|
|
- return filteredModels
|
|
|
+ const modelOptions = filteredModels
|
|
|
? Object.keys(filteredModels).map((modelId) => ({
|
|
|
value: modelId,
|
|
|
label: modelId,
|
|
|
}))
|
|
|
: []
|
|
|
+
|
|
|
+ return modelOptions
|
|
|
}, [selectedProvider, organizationAllowList])
|
|
|
|
|
|
const onProviderChange = useCallback(
|
|
|
(value: ProviderName) => {
|
|
|
+ setApiConfigurationField("apiProvider", value)
|
|
|
+
|
|
|
// It would be much easier to have a single attribute that stores
|
|
|
// the modelId, but we have a separate attribute for each of
|
|
|
// OpenRouter, Glama, Unbound, and Requesty.
|
|
|
@@ -204,46 +221,69 @@ const ApiOptions = ({
|
|
|
// modelId is not set then you immediately end up in an error state.
|
|
|
// To address that we set the modelId to the default value for th
|
|
|
// provider if it's not already set.
|
|
|
- switch (value) {
|
|
|
- case "openrouter":
|
|
|
- if (!apiConfiguration.openRouterModelId) {
|
|
|
- setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
|
|
|
- }
|
|
|
- break
|
|
|
- case "glama":
|
|
|
- if (!apiConfiguration.glamaModelId) {
|
|
|
- setApiConfigurationField("glamaModelId", glamaDefaultModelId)
|
|
|
- }
|
|
|
- break
|
|
|
- case "unbound":
|
|
|
- if (!apiConfiguration.unboundModelId) {
|
|
|
- setApiConfigurationField("unboundModelId", unboundDefaultModelId)
|
|
|
- }
|
|
|
- break
|
|
|
- case "requesty":
|
|
|
- if (!apiConfiguration.requestyModelId) {
|
|
|
- setApiConfigurationField("requestyModelId", requestyDefaultModelId)
|
|
|
- }
|
|
|
- break
|
|
|
- case "litellm":
|
|
|
- if (!apiConfiguration.litellmModelId) {
|
|
|
- setApiConfigurationField("litellmModelId", litellmDefaultModelId)
|
|
|
+ const validateAndResetModel = (
|
|
|
+ modelId: string | undefined,
|
|
|
+ field: keyof ProviderSettings,
|
|
|
+ defaultValue?: string,
|
|
|
+ ) => {
|
|
|
+ // in case we haven't set a default value for a provider
|
|
|
+ if (!defaultValue) return
|
|
|
+
|
|
|
+ // only set default if no model is set, but don't reset invalid models
|
|
|
+ // let users see and decide what to do with invalid model selections
|
|
|
+ const shouldSetDefault = !modelId
|
|
|
+
|
|
|
+ if (shouldSetDefault) {
|
|
|
+ setApiConfigurationField(field, defaultValue)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Define a mapping object that associates each provider with its model configuration
|
|
|
+ const PROVIDER_MODEL_CONFIG: Partial<
|
|
|
+ Record<
|
|
|
+ ProviderName,
|
|
|
+ {
|
|
|
+ field: keyof ProviderSettings
|
|
|
+ default?: string
|
|
|
}
|
|
|
- break
|
|
|
+ >
|
|
|
+ > = {
|
|
|
+ openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
|
|
|
+ glama: { field: "glamaModelId", default: glamaDefaultModelId },
|
|
|
+ unbound: { field: "unboundModelId", default: unboundDefaultModelId },
|
|
|
+ requesty: { field: "requestyModelId", default: requestyDefaultModelId },
|
|
|
+ litellm: { field: "litellmModelId", default: litellmDefaultModelId },
|
|
|
+ anthropic: { field: "apiModelId", default: anthropicDefaultModelId },
|
|
|
+ "openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId },
|
|
|
+ gemini: { field: "apiModelId", default: geminiDefaultModelId },
|
|
|
+ deepseek: { field: "apiModelId", default: deepSeekDefaultModelId },
|
|
|
+ mistral: { field: "apiModelId", default: mistralDefaultModelId },
|
|
|
+ xai: { field: "apiModelId", default: xaiDefaultModelId },
|
|
|
+ groq: { field: "apiModelId", default: groqDefaultModelId },
|
|
|
+ chutes: { field: "apiModelId", default: chutesDefaultModelId },
|
|
|
+ bedrock: { field: "apiModelId", default: bedrockDefaultModelId },
|
|
|
+ vertex: { field: "apiModelId", default: vertexDefaultModelId },
|
|
|
+ openai: { field: "openAiModelId" },
|
|
|
+ ollama: { field: "ollamaModelId" },
|
|
|
+ lmstudio: { field: "lmStudioModelId" },
|
|
|
}
|
|
|
|
|
|
- setApiConfigurationField("apiProvider", value)
|
|
|
+ const config = PROVIDER_MODEL_CONFIG[value]
|
|
|
+ if (config) {
|
|
|
+ validateAndResetModel(
|
|
|
+ apiConfiguration[config.field] as string | undefined,
|
|
|
+ config.field,
|
|
|
+ config.default,
|
|
|
+ )
|
|
|
+ }
|
|
|
},
|
|
|
- [
|
|
|
- setApiConfigurationField,
|
|
|
- apiConfiguration.openRouterModelId,
|
|
|
- apiConfiguration.glamaModelId,
|
|
|
- apiConfiguration.unboundModelId,
|
|
|
- apiConfiguration.requestyModelId,
|
|
|
- apiConfiguration.litellmModelId,
|
|
|
- ],
|
|
|
+ [setApiConfigurationField, apiConfiguration],
|
|
|
)
|
|
|
|
|
|
+ const modelValidationError = useMemo(() => {
|
|
|
+ return getModelValidationError(apiConfiguration, routerModels, organizationAllowList)
|
|
|
+ }, [apiConfiguration, routerModels, organizationAllowList])
|
|
|
+
|
|
|
const docs = useMemo(() => {
|
|
|
const provider = PROVIDERS.find(({ value }) => value === selectedProvider)
|
|
|
const name = provider?.label
|
|
|
@@ -303,6 +343,7 @@ const ApiOptions = ({
|
|
|
uriScheme={uriScheme}
|
|
|
fromWelcomeView={fromWelcomeView}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|
|
|
@@ -313,6 +354,7 @@ const ApiOptions = ({
|
|
|
routerModels={routerModels}
|
|
|
refetchRouterModels={refetchRouterModels}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|
|
|
@@ -323,6 +365,7 @@ const ApiOptions = ({
|
|
|
routerModels={routerModels}
|
|
|
uriScheme={uriScheme}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|
|
|
@@ -332,6 +375,7 @@ const ApiOptions = ({
|
|
|
setApiConfigurationField={setApiConfigurationField}
|
|
|
routerModels={routerModels}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|
|
|
@@ -368,6 +412,7 @@ const ApiOptions = ({
|
|
|
apiConfiguration={apiConfiguration}
|
|
|
setApiConfigurationField={setApiConfigurationField}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|
|
|
@@ -404,6 +449,7 @@ const ApiOptions = ({
|
|
|
apiConfiguration={apiConfiguration}
|
|
|
setApiConfigurationField={setApiConfigurationField}
|
|
|
organizationAllowList={organizationAllowList}
|
|
|
+ modelValidationError={modelValidationError}
|
|
|
/>
|
|
|
)}
|
|
|
|