Browse Source

Improves model info detection for custom Bedrock ARNs (#3799)

* Improves model info detection for custom Bedrock ARNs

Adds heuristics to better estimate model capabilities when using unknown or custom model ARNs, including context window and max tokens. Allows user overrides for key model parameters via provider settings, improving flexibility and reliability for non-standard model integrations.

Fixes #3712

* Improves JSON syntax error handling in import flow

Provides more informative error messages for JSON syntax
errors by extracting the error position and formatting it
for clarity during import. Enhances user feedback when
invalid JSON is encountered.

* Fixed failing tests

* Delete pnpm-lock.yaml

* Added Rory's cache fix from PR #3099

PR #3099 has an important fix, which @JBBrown alerted me to.

It was a one liner so I pulled it in.

This brings up a question: can we merge PRs in the GitHub UI?

* Add Claude 4 and Opus 4 to model IDs

Kept previous parameters, did not see any changes in those.

* Fixed types being moved and me breaking the merge.

* Fix merge

---------

Co-authored-by: Matt Rubens <[email protected]>
Adam Hill 🦿 9 months ago
parent
commit
c4dab9e9b2

+ 2 - 0
packages/types/src/provider-settings.ts

@@ -100,6 +100,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
 	awsProfile: z.string().optional(),
 	awsUseProfile: z.boolean().optional(),
 	awsCustomArn: z.string().optional(),
+	awsModelContextWindow: z.number().optional(),
 	awsBedrockEndpointEnabled: z.boolean().optional(),
 	awsBedrockEndpoint: z.string().optional(),
 })
@@ -285,6 +286,7 @@ export const PROVIDER_SETTINGS_KEYS = keysOf<ProviderSettings>()([
 	"awsProfile",
 	"awsUseProfile",
 	"awsCustomArn",
+	"awsModelContextWindow",
 	"awsBedrockEndpointEnabled",
 	"awsBedrockEndpoint",
 	// Google Vertex

+ 2 - 0
packages/types/src/providers/bedrock.ts

@@ -355,6 +355,8 @@ export const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 
 export const BEDROCK_MAX_TOKENS = 4096
 
+export const BEDROCK_DEFAULT_CONTEXT = 128_000
+
 export const BEDROCK_REGION_INFO: Record<
 	string,
 	{

+ 71 - 4
src/api/providers/bedrock.ts

@@ -19,6 +19,7 @@ import {
 	bedrockDefaultPromptRouterModelId,
 	BEDROCK_DEFAULT_TEMPERATURE,
 	BEDROCK_MAX_TOKENS,
+	BEDROCK_DEFAULT_CONTEXT,
 	BEDROCK_REGION_INFO,
 } from "@roo-code/types"
 
@@ -192,6 +193,65 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		this.client = new BedrockRuntimeClient(clientConfig)
 	}
 
+	// Helper to estimate ModelInfo from a custom/unknown modelId string (e.g. a
+	// custom ARN) when the id is not present in bedrockModels.
+	// NOTE: patterns are matched with String.includes() in object-literal order,
+	// so more specific patterns MUST precede their prefixes ("claude-4-opus"
+	// before "claude-4"); otherwise the specific entries are unreachable.
+	private guessModelInfoFromId(modelId: string): Partial<ModelInfo> {
+		// Pattern -> estimated capabilities, ordered most-specific first.
+		const modelConfigMap: Record<string, Partial<ModelInfo>> = {
+			"claude-4-opus": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-opus": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-haiku": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-7": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-5": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-4": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+		}
+
+		// Case-insensitive substring match against the ordered patterns;
+		// first (most specific) hit wins.
+		const id = modelId.toLowerCase()
+		for (const [pattern, config] of Object.entries(modelConfigMap)) {
+			if (id.includes(pattern)) {
+				return config
+			}
+		}
+
+		// Unknown model: conservative defaults (no image/prompt-cache support).
+		return {
+			maxTokens: BEDROCK_MAX_TOKENS,
+			contextWindow: BEDROCK_DEFAULT_CONTEXT,
+			supportsImages: false,
+			supportsPromptCache: false,
+		}
+	}
+
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
@@ -640,16 +700,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 				info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
 			}
 		} else {
+			// Use heuristics for model info, then allow overrides from ProviderSettings
+			const guessed = this.guessModelInfoFromId(modelId)
 			model = {
 				id: bedrockDefaultModelId,
-				info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
+				info: {
+					...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
+					...guessed,
+				},
 			}
 		}
 
-		// If modelMaxTokens is explicitly set in options, override the default
+		// Always allow user to override detected/guessed maxTokens and contextWindow
 		if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
 			model.info.maxTokens = this.options.modelMaxTokens
 		}
+		if (this.options.awsModelContextWindow && this.options.awsModelContextWindow > 0) {
+			model.info.contextWindow = this.options.awsModelContextWindow
+		}
 
 		return model
 	}
@@ -684,8 +752,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 			}
 		}
 
-		modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS
-
+		// Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides)
 		return modelConfig as { id: BedrockModelId | string; info: ModelInfo }
 	}