|
|
@@ -1,8 +1,17 @@
|
|
|
import i18next from "i18next"
|
|
|
|
|
|
-import type { ProviderSettings, OrganizationAllowList } from "@roo-code/types"
|
|
|
-
|
|
|
-import { isRouterName, RouterModels } from "@roo/api"
|
|
|
+import {
|
|
|
+ type ProviderSettings,
|
|
|
+ type OrganizationAllowList,
|
|
|
+ type ProviderName,
|
|
|
+ modelIdKeysByProvider,
|
|
|
+ isProviderName,
|
|
|
+ isDynamicProvider,
|
|
|
+ isFauxProvider,
|
|
|
+ isCustomProvider,
|
|
|
+} from "@roo-code/types"
|
|
|
+
|
|
|
+import type { RouterModels } from "@roo/api"
|
|
|
|
|
|
export function validateApiConfiguration(
|
|
|
apiConfiguration: ProviderSettings,
|
|
|
@@ -10,6 +19,7 @@ export function validateApiConfiguration(
|
|
|
organizationAllowList?: OrganizationAllowList,
|
|
|
): string | undefined {
|
|
|
const keysAndIdsPresentErrorMessage = validateModelsAndKeysProvided(apiConfiguration)
|
|
|
+
|
|
|
if (keysAndIdsPresentErrorMessage) {
|
|
|
return keysAndIdsPresentErrorMessage
|
|
|
}
|
|
|
@@ -18,11 +28,12 @@ export function validateApiConfiguration(
|
|
|
apiConfiguration,
|
|
|
organizationAllowList,
|
|
|
)
|
|
|
+
|
|
|
if (organizationAllowListError) {
|
|
|
return organizationAllowListError.message
|
|
|
}
|
|
|
|
|
|
- return validateModelId(apiConfiguration, routerModels)
|
|
|
+ return validateDynamicProviderModelId(apiConfiguration, routerModels)
|
|
|
}
|
|
|
|
|
|
function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): string | undefined {
|
|
|
@@ -161,9 +172,13 @@ function validateProviderAgainstOrganizationSettings(
|
|
|
): ValidationError | undefined {
|
|
|
if (organizationAllowList && !organizationAllowList.allowAll) {
|
|
|
const provider = apiConfiguration.apiProvider
|
|
|
- if (!provider) return undefined
|
|
|
+
|
|
|
+ if (!provider) {
|
|
|
+ return undefined
|
|
|
+ }
|
|
|
|
|
|
const providerConfig = organizationAllowList.providers[provider]
|
|
|
+
|
|
|
if (!providerConfig) {
|
|
|
return {
|
|
|
message: i18next.t("settings:validation.providerNotAllowed", { provider }),
|
|
|
@@ -188,47 +203,28 @@ function validateProviderAgainstOrganizationSettings(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: string): string | undefined {
|
|
|
- switch (provider) {
|
|
|
- case "openrouter":
|
|
|
- return apiConfiguration.openRouterModelId
|
|
|
- case "glama":
|
|
|
- return apiConfiguration.glamaModelId
|
|
|
- case "unbound":
|
|
|
- return apiConfiguration.unboundModelId
|
|
|
- case "requesty":
|
|
|
- return apiConfiguration.requestyModelId
|
|
|
- case "deepinfra":
|
|
|
- return apiConfiguration.deepInfraModelId
|
|
|
- case "litellm":
|
|
|
- return apiConfiguration.litellmModelId
|
|
|
- case "openai":
|
|
|
- return apiConfiguration.openAiModelId
|
|
|
- case "ollama":
|
|
|
- return apiConfiguration.ollamaModelId
|
|
|
- case "lmstudio":
|
|
|
- return apiConfiguration.lmStudioModelId
|
|
|
- case "vscode-lm":
|
|
|
- // vsCodeLmModelSelector is an object, not a string
|
|
|
- return apiConfiguration.vsCodeLmModelSelector?.id
|
|
|
- case "huggingface":
|
|
|
- return apiConfiguration.huggingFaceModelId
|
|
|
- case "io-intelligence":
|
|
|
- return apiConfiguration.ioIntelligenceModelId
|
|
|
- case "vercel-ai-gateway":
|
|
|
- return apiConfiguration.vercelAiGatewayModelId
|
|
|
- default:
|
|
|
- return apiConfiguration.apiModelId
|
|
|
+function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: ProviderName): string | undefined {
|
|
|
+ if (provider === "vscode-lm") {
|
|
|
+ return apiConfiguration.vsCodeLmModelSelector?.id
|
|
|
}
|
|
|
+
|
|
|
+ if (isCustomProvider(provider) || isFauxProvider(provider)) {
|
|
|
+ return apiConfiguration.apiModelId
|
|
|
+ }
|
|
|
+
|
|
|
+ return apiConfiguration[modelIdKeysByProvider[provider]]
|
|
|
}
|
|
|
+
|
|
|
/**
|
|
|
- * Validates an Amazon Bedrock ARN format and optionally checks if the region in the ARN matches the provided region
|
|
|
+ * Validates an Amazon Bedrock ARN format and optionally checks if the region in
|
|
|
+ * the ARN matches the provided region.
|
|
|
+ *
|
|
|
* @param arn The ARN string to validate
|
|
|
* @param region Optional region to check against the ARN's region
|
|
|
* @returns An object with validation results: { isValid, arnRegion, errorMessage }
|
|
|
*/
|
|
|
export function validateBedrockArn(arn: string, region?: string) {
|
|
|
- // Validate ARN format
|
|
|
+ // Validate ARN format.
|
|
|
const arnRegex = /^arn:aws:(?:bedrock|sagemaker):([^:]+):([^:]*):(?:([^/]+)\/([\w.\-:]+)|([^/]+))$/
|
|
|
const match = arn.match(arnRegex)
|
|
|
|
|
|
@@ -240,10 +236,10 @@ export function validateBedrockArn(arn: string, region?: string) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Extract region from ARN
|
|
|
+ // Extract region from ARN.
|
|
|
const arnRegion = match[1]
|
|
|
|
|
|
- // Check if region in ARN matches provided region (if specified)
|
|
|
+ // Check if region in ARN matches provided region (if specified).
|
|
|
if (region && arnRegion !== region) {
|
|
|
return {
|
|
|
isValid: true,
|
|
|
@@ -252,51 +248,22 @@ export function validateBedrockArn(arn: string, region?: string) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // ARN is valid and region matches (or no region was provided to check against)
|
|
|
+ // ARN is valid and region matches (or no region was provided to check against).
|
|
|
return { isValid: true, arnRegion, errorMessage: undefined }
|
|
|
}
|
|
|
|
|
|
-export function validateModelId(apiConfiguration: ProviderSettings, routerModels?: RouterModels): string | undefined {
|
|
|
+function validateDynamicProviderModelId(
|
|
|
+ apiConfiguration: ProviderSettings,
|
|
|
+ routerModels?: RouterModels,
|
|
|
+): string | undefined {
|
|
|
const provider = apiConfiguration.apiProvider ?? ""
|
|
|
|
|
|
- if (!isRouterName(provider)) {
|
|
|
+ // We only validate model ids from dynamic providers.
|
|
|
+ if (!isDynamicProvider(provider)) {
|
|
|
return undefined
|
|
|
}
|
|
|
|
|
|
- let modelId: string | undefined
|
|
|
-
|
|
|
- switch (provider) {
|
|
|
- case "openrouter":
|
|
|
- modelId = apiConfiguration.openRouterModelId
|
|
|
- break
|
|
|
- case "glama":
|
|
|
- modelId = apiConfiguration.glamaModelId
|
|
|
- break
|
|
|
- case "unbound":
|
|
|
- modelId = apiConfiguration.unboundModelId
|
|
|
- break
|
|
|
- case "requesty":
|
|
|
- modelId = apiConfiguration.requestyModelId
|
|
|
- break
|
|
|
- case "deepinfra":
|
|
|
- modelId = apiConfiguration.deepInfraModelId
|
|
|
- break
|
|
|
- case "ollama":
|
|
|
- modelId = apiConfiguration.ollamaModelId
|
|
|
- break
|
|
|
- case "lmstudio":
|
|
|
- modelId = apiConfiguration.lmStudioModelId
|
|
|
- break
|
|
|
- case "litellm":
|
|
|
- modelId = apiConfiguration.litellmModelId
|
|
|
- break
|
|
|
- case "io-intelligence":
|
|
|
- modelId = apiConfiguration.ioIntelligenceModelId
|
|
|
- break
|
|
|
- case "vercel-ai-gateway":
|
|
|
- modelId = apiConfiguration.vercelAiGatewayModelId
|
|
|
- break
|
|
|
- }
|
|
|
+ const modelId = getModelIdForProvider(apiConfiguration, provider)
|
|
|
|
|
|
if (!modelId) {
|
|
|
return i18next.t("settings:validation.modelId")
|
|
|
@@ -312,39 +279,44 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * Extracts model-specific validation errors from the API configuration
|
|
|
- * This is used to show model errors specifically in the model selector components
|
|
|
+ * Extracts model-specific validation errors from the API configuration.
|
|
|
+ * This is used to show model errors specifically in the model selector components.
|
|
|
*/
|
|
|
export function getModelValidationError(
|
|
|
apiConfiguration: ProviderSettings,
|
|
|
routerModels?: RouterModels,
|
|
|
organizationAllowList?: OrganizationAllowList,
|
|
|
): string | undefined {
|
|
|
- const modelId = getModelIdForProvider(apiConfiguration, apiConfiguration.apiProvider || "")
|
|
|
+ const modelId = isProviderName(apiConfiguration.apiProvider)
|
|
|
+ ? getModelIdForProvider(apiConfiguration, apiConfiguration.apiProvider)
|
|
|
+ : apiConfiguration.apiModelId
|
|
|
+
|
|
|
const configWithModelId = {
|
|
|
...apiConfiguration,
|
|
|
apiModelId: modelId || "",
|
|
|
}
|
|
|
|
|
|
const orgError = validateProviderAgainstOrganizationSettings(configWithModelId, organizationAllowList)
|
|
|
+
|
|
|
if (orgError && orgError.code === "MODEL_NOT_ALLOWED") {
|
|
|
return orgError.message
|
|
|
}
|
|
|
|
|
|
- return validateModelId(configWithModelId, routerModels)
|
|
|
+ return validateDynamicProviderModelId(configWithModelId, routerModels)
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * Validates API configuration but excludes model-specific errors
|
|
|
+ * Validates API configuration but excludes model-specific errors.
|
|
|
* This is used for the general API error display to prevent duplication
|
|
|
- * when model errors are shown in the model selector
|
|
|
+ * when model errors are shown in the model selector.
|
|
|
*/
|
|
|
export function validateApiConfigurationExcludingModelErrors(
|
|
|
apiConfiguration: ProviderSettings,
|
|
|
- _routerModels?: RouterModels, // keeping this for compatibility with the old function
|
|
|
+ _routerModels?: RouterModels, // Keeping this for compatibility with the old function.
|
|
|
organizationAllowList?: OrganizationAllowList,
|
|
|
): string | undefined {
|
|
|
const keysAndIdsPresentErrorMessage = validateModelsAndKeysProvided(apiConfiguration)
|
|
|
+
|
|
|
if (keysAndIdsPresentErrorMessage) {
|
|
|
return keysAndIdsPresentErrorMessage
|
|
|
}
|
|
|
@@ -354,11 +326,11 @@ export function validateApiConfigurationExcludingModelErrors(
|
|
|
organizationAllowList,
|
|
|
)
|
|
|
|
|
|
- // only return organization errors if they're not model-specific
|
|
|
+ // Inly return organization errors if they're not model-specific.
|
|
|
if (organizationAllowListError && organizationAllowListError.code === "PROVIDER_NOT_ALLOWED") {
|
|
|
return organizationAllowListError.message
|
|
|
}
|
|
|
|
|
|
- // skip model validation errors as they'll be shown in the model selector
|
|
|
+ // Skip model validation errors as they'll be shown in the model selector.
|
|
|
return undefined
|
|
|
}
|