Просмотр исходного кода

Merge pull request #1550 from Smartsheet-JB-Brown/jbbrown/aws_custom_arn_for_intelligent_prompt_routing

Users need the ability to use custom ARNs (Amazon Resource Names) with AWS Bedrock for intelligent prompt routing.
Matt Rubens 9 месяцев назад
Родитель
Сommit
9eab941d5c

+ 75 - 0
src/api/providers/__tests__/bedrock-custom-arn.test.ts

@@ -0,0 +1,75 @@
+import { AwsBedrockHandler } from "../bedrock"
+import { ApiHandlerOptions } from "../../../shared/api"
+
+// Mock the AWS SDK
+jest.mock("@aws-sdk/client-bedrock-runtime", () => {
+	const mockSend = jest.fn().mockImplementation(() => {
+		return Promise.resolve({
+			output: new TextEncoder().encode(JSON.stringify({ content: "Test response" })),
+		})
+	})
+
+	return {
+		BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
+			send: mockSend,
+			config: {
+				region: "us-east-1",
+			},
+		})),
+		ConverseCommand: jest.fn(),
+		ConverseStreamCommand: jest.fn(),
+	}
+})
+
+describe("AwsBedrockHandler with custom ARN", () => {
+	const mockOptions: ApiHandlerOptions = {
+		apiModelId: "custom-arn",
+		awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
+		awsRegion: "us-east-1",
+	}
+
+	it("should use the custom ARN as the model ID", async () => {
+		const handler = new AwsBedrockHandler(mockOptions)
+		const model = handler.getModel()
+
+		expect(model.id).toBe(mockOptions.awsCustomArn)
+		expect(model.info).toHaveProperty("maxTokens")
+		expect(model.info).toHaveProperty("contextWindow")
+		expect(model.info).toHaveProperty("supportsPromptCache")
+	})
+
+	it("should extract region from ARN and use it for client configuration", () => {
+		// Test with matching region
+		const handler1 = new AwsBedrockHandler(mockOptions)
+		expect((handler1 as any).client.config.region).toBe("us-east-1")
+
+		// Test with mismatched region
+		const mismatchOptions = {
+			...mockOptions,
+			awsRegion: "us-west-2",
+		}
+		const handler2 = new AwsBedrockHandler(mismatchOptions)
+		// Should use the ARN region, not the provided region
+		expect((handler2 as any).client.config.region).toBe("us-east-1")
+	})
+
+	it("should validate ARN format", async () => {
+		// Invalid ARN format
+		const invalidOptions = {
+			...mockOptions,
+			awsCustomArn: "invalid-arn-format",
+		}
+
+		const handler = new AwsBedrockHandler(invalidOptions)
+
+		// completePrompt should throw an error for invalid ARN
+		await expect(handler.completePrompt("test")).rejects.toThrow("Invalid ARN format")
+	})
+
+	it("should complete a prompt successfully with valid ARN", async () => {
+		const handler = new AwsBedrockHandler(mockOptions)
+		const response = await handler.completePrompt("test prompt")
+
+		expect(response).toBe("Test response")
+	})
+})

+ 29 - 0
src/api/providers/__tests__/bedrock.test.ts

@@ -315,5 +315,34 @@ describe("AwsBedrockHandler", () => {
 			expect(modelInfo.info.maxTokens).toBe(5000)
 			expect(modelInfo.info.maxTokens).toBe(5000)
 			expect(modelInfo.info.contextWindow).toBe(128_000)
 			expect(modelInfo.info.contextWindow).toBe(128_000)
 		})
 		})
+
+		it("should use custom ARN when provided", () => {
+			const customArnHandler = new AwsBedrockHandler({
+				apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+				awsCustomArn: "arn:aws:bedrock:us-east-1::foundation-model/custom-model",
+			})
+			const modelInfo = customArnHandler.getModel()
+			expect(modelInfo.id).toBe("arn:aws:bedrock:us-east-1::foundation-model/custom-model")
+			expect(modelInfo.info.maxTokens).toBe(4096)
+			expect(modelInfo.info.contextWindow).toBe(128_000)
+			expect(modelInfo.info.supportsPromptCache).toBe(false)
+		})
+
+		it("should use default model when custom-arn is selected but no ARN is provided", () => {
+			const customArnHandler = new AwsBedrockHandler({
+				apiModelId: "custom-arn",
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-east-1",
+				// No awsCustomArn provided
+			})
+			const modelInfo = customArnHandler.getModel()
+			// Should fall back to default model
+			expect(modelInfo.id).not.toBe("custom-arn")
+			expect(modelInfo.info).toBeDefined()
+		})
 	})
 	})
 })
 })

+ 460 - 29
src/api/providers/bedrock.ts

@@ -11,6 +11,47 @@ import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, be
 import { ApiStream } from "../transform/stream"
 import { ApiStream } from "../transform/stream"
 import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
 import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
 import { BaseProvider } from "./base-provider"
 import { BaseProvider } from "./base-provider"
+import { logger } from "../../utils/logging"
+
+/**
+ * Validates an AWS Bedrock ARN format and optionally checks if the region in the ARN matches the provided region
+ * @param arn The ARN string to validate
+ * @param region Optional region to check against the ARN's region
+ * @returns An object with validation results: { isValid, arnRegion, errorMessage }
+ */
+function validateBedrockArn(arn: string, region?: string) {
+	// Validate ARN format
+	const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/
+	const match = arn.match(arnRegex)
+
+	if (!match) {
+		return {
+			isValid: false,
+			arnRegion: undefined,
+			errorMessage:
+				"Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name",
+		}
+	}
+
+	// Extract region from ARN
+	const arnRegion = match[1]
+
+	// Check if region in ARN matches provided region (if specified)
+	if (region && arnRegion !== region) {
+		return {
+			isValid: true,
+			arnRegion,
+			errorMessage: `Warning: The region in your ARN (${arnRegion}) does not match your selected region (${region}). This may cause access issues. The provider will use the region from the ARN.`,
+		}
+	}
+
+	// ARN is valid and region matches (or no region was provided to check against)
+	return {
+		isValid: true,
+		arnRegion,
+		errorMessage: undefined,
+	}
+}
 
 
 const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 
 
@@ -55,8 +96,31 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		super()
 		super()
 		this.options = options
 		this.options = options
 
 
+		// Extract region from custom ARN if provided
+		let region = this.options.awsRegion || "us-east-1"
+
+		// If using custom ARN, extract region from the ARN
+		if (this.options.awsCustomArn) {
+			const validation = validateBedrockArn(this.options.awsCustomArn, region)
+
+			if (validation.isValid && validation.arnRegion) {
+				// If there's a region mismatch warning, log it and use the ARN region
+				if (validation.errorMessage) {
+					logger.info(
+						`Region mismatch: Selected region is ${region}, but ARN region is ${validation.arnRegion}. Using ARN region.`,
+						{
+							ctx: "bedrock",
+							selectedRegion: region,
+							arnRegion: validation.arnRegion,
+						},
+					)
+					region = validation.arnRegion
+				}
+			}
+		}
+
 		const clientConfig: BedrockRuntimeClientConfig = {
 		const clientConfig: BedrockRuntimeClientConfig = {
-			region: this.options.awsRegion || "us-east-1",
+			region: region,
 		}
 		}
 
 
 		if (this.options.awsUseProfile && this.options.awsProfile) {
 		if (this.options.awsUseProfile && this.options.awsProfile) {
@@ -81,7 +145,41 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 
 
 		// Handle cross-region inference
 		// Handle cross-region inference
 		let modelId: string
 		let modelId: string
-		if (this.options.awsUseCrossRegionInference) {
+
+		// For custom ARNs, use the ARN directly without modification
+		if (this.options.awsCustomArn) {
+			modelId = modelConfig.id
+
+			// Validate ARN format and check region match
+			const clientRegion = this.client.config.region as string
+			const validation = validateBedrockArn(modelId, clientRegion)
+
+			if (!validation.isValid) {
+				logger.error("Invalid ARN format", {
+					ctx: "bedrock",
+					modelId,
+					errorMessage: validation.errorMessage,
+				})
+				yield {
+					type: "text",
+					text: `Error: ${validation.errorMessage}`,
+				}
+				yield { type: "usage", inputTokens: 0, outputTokens: 0 }
+				throw new Error("Invalid ARN format")
+			}
+
+			// Extract region from ARN
+			const arnRegion = validation.arnRegion!
+
+			// Log warning if there's a region mismatch
+			if (validation.errorMessage) {
+				logger.warn(validation.errorMessage, {
+					ctx: "bedrock",
+					arnRegion,
+					clientRegion,
+				})
+			}
+		} else if (this.options.awsUseCrossRegionInference) {
 			let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
 			let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
 			switch (regionPrefix) {
 			switch (regionPrefix) {
 				case "us-":
 				case "us-":
@@ -107,7 +205,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 			messages: formattedMessages,
 			messages: formattedMessages,
 			system: [{ text: systemPrompt }],
 			system: [{ text: systemPrompt }],
 			inferenceConfig: {
 			inferenceConfig: {
-				maxTokens: modelConfig.info.maxTokens || 5000,
+				maxTokens: modelConfig.info.maxTokens || 4096,
 				temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
 				temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
 				topP: 0.1,
 				topP: 0.1,
 				...(this.options.awsUsePromptCache
 				...(this.options.awsUsePromptCache
@@ -121,6 +219,16 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		}
 		}
 
 
 		try {
 		try {
+			// Log the payload for debugging custom ARN issues
+			if (this.options.awsCustomArn) {
+				logger.debug("Using custom ARN for Bedrock request", {
+					ctx: "bedrock",
+					customArn: this.options.awsCustomArn,
+					clientRegion: this.client.config.region,
+					payload: JSON.stringify(payload, null, 2),
+				})
+			}
+
 			const command = new ConverseStreamCommand(payload)
 			const command = new ConverseStreamCommand(payload)
 			const response = await this.client.send(command)
 			const response = await this.client.send(command)
 
 
@@ -134,7 +242,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 				try {
 				try {
 					streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
 					streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
 				} catch (e) {
 				} catch (e) {
-					console.error("Failed to parse stream event:", e)
+					logger.error("Failed to parse stream event", {
+						ctx: "bedrock",
+						error: e instanceof Error ? e : String(e),
+						chunk: typeof chunk === "string" ? chunk : "binary data",
+					})
 					continue
 					continue
 				}
 				}
 
 
@@ -177,39 +289,257 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 				}
 				}
 			}
 			}
 		} catch (error: unknown) {
 		} catch (error: unknown) {
-			console.error("Bedrock Runtime API Error:", error)
-			// Only access stack if error is an Error object
-			if (error instanceof Error) {
-				console.error("Error stack:", error.stack)
-				yield {
-					type: "text",
-					text: `Error: ${error.message}`,
+			logger.error("Bedrock Runtime API Error", {
+				ctx: "bedrock",
+				error: error instanceof Error ? error : String(error),
+			})
+
+			// Enhanced error handling for custom ARN issues
+			if (this.options.awsCustomArn) {
+				logger.error("Error occurred with custom ARN", {
+					ctx: "bedrock",
+					customArn: this.options.awsCustomArn,
+				})
+
+				// Check for common ARN-related errors
+				if (error instanceof Error) {
+					const errorMessage = error.message.toLowerCase()
+
+					// Access denied errors
+					if (
+						errorMessage.includes("access") &&
+						(errorMessage.includes("model") || errorMessage.includes("denied"))
+					) {
+						logger.error("Permissions issue with custom ARN", {
+							ctx: "bedrock",
+							customArn: this.options.awsCustomArn,
+							errorType: "access_denied",
+							clientRegion: this.client.config.region,
+						})
+						yield {
+							type: "text",
+							text: `Error: You don't have access to the model with the specified ARN. Please verify:
+
+1. The ARN is correct and points to a valid model
+2. Your AWS credentials have permission to access this model (check IAM policies)
+3. The region in the ARN (${this.client.config.region}) matches the region where the model is deployed
+4. If using a provisioned model, ensure it's active and not in a failed state
+5. If using a custom model, ensure your account has been granted access to it`,
+						}
+					}
+					// Model not found errors
+					else if (errorMessage.includes("not found") || errorMessage.includes("does not exist")) {
+						logger.error("Invalid ARN or non-existent model", {
+							ctx: "bedrock",
+							customArn: this.options.awsCustomArn,
+							errorType: "not_found",
+						})
+						yield {
+							type: "text",
+							text: `Error: The specified ARN does not exist or is invalid. Please check:
+
+1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name)
+2. The model exists in the specified region
+3. The account ID in the ARN is correct
+4. The resource type is one of: foundation-model, provisioned-model, or default-prompt-router`,
+						}
+					}
+					// Throttling errors
+					else if (
+						errorMessage.includes("throttl") ||
+						errorMessage.includes("rate") ||
+						errorMessage.includes("limit")
+					) {
+						logger.error("Throttling or rate limit issue with Bedrock", {
+							ctx: "bedrock",
+							customArn: this.options.awsCustomArn,
+							errorType: "throttling",
+						})
+						yield {
+							type: "text",
+							text: `Error: Request was throttled or rate limited. Please try:
+
+1. Reducing the frequency of requests
+2. If using a provisioned model, check its throughput settings
+3. Contact AWS support to request a quota increase if needed`,
+						}
+					}
+					// Other errors
+					else {
+						logger.error("Unspecified error with custom ARN", {
+							ctx: "bedrock",
+							customArn: this.options.awsCustomArn,
+							errorStack: error.stack,
+							errorMessage: error.message,
+						})
+						yield {
+							type: "text",
+							text: `Error with custom ARN: ${error.message}
+
+Please check:
+1. Your AWS credentials are valid and have the necessary permissions
+2. The ARN format is correct
+3. The region in the ARN matches the region where you're making the request`,
+						}
+					}
+				} else {
+					yield {
+						type: "text",
+						text: `Unknown error occurred with custom ARN. Please check your AWS credentials and ARN format.`,
+					}
 				}
 				}
-				yield {
-					type: "usage",
-					inputTokens: 0,
-					outputTokens: 0,
+			} else {
+				// Standard error handling for non-ARN cases
+				if (error instanceof Error) {
+					logger.error("Standard Bedrock error", {
+						ctx: "bedrock",
+						errorStack: error.stack,
+						errorMessage: error.message,
+					})
+					yield {
+						type: "text",
+						text: `Error: ${error.message}`,
+					}
+				} else {
+					logger.error("Unknown Bedrock error", {
+						ctx: "bedrock",
+						error: String(error),
+					})
+					yield {
+						type: "text",
+						text: "An unknown error occurred",
+					}
 				}
 				}
+			}
+
+			// Always yield usage info
+			yield {
+				type: "usage",
+				inputTokens: 0,
+				outputTokens: 0,
+			}
+
+			// Re-throw the error
+			if (error instanceof Error) {
 				throw error
 				throw error
 			} else {
 			} else {
-				const unknownError = new Error("An unknown error occurred")
-				yield {
-					type: "text",
-					text: unknownError.message,
-				}
-				yield {
-					type: "usage",
-					inputTokens: 0,
-					outputTokens: 0,
-				}
-				throw unknownError
+				throw new Error("An unknown error occurred")
 			}
 			}
 		}
 		}
 	}
 	}
 
 
 	override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
 	override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
+		// If custom ARN is provided, use it
+		if (this.options.awsCustomArn) {
+			// Custom ARNs should not be modified with region prefixes
+			// as they already contain the full resource path
+
+			// Check if the ARN contains information about the model type
+			// This helps set appropriate token limits for models behind prompt routers
+			const arnLower = this.options.awsCustomArn.toLowerCase()
+
+			// Determine model info based on ARN content
+			let modelInfo: ModelInfo
+
+			if (arnLower.includes("claude-3-7-sonnet") || arnLower.includes("claude-3.7-sonnet")) {
+				// Claude 3.7 Sonnet has 8192 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 8192,
+					contextWindow: 200_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+					supportsComputerUse: true,
+				}
+			} else if (arnLower.includes("claude-3-5-sonnet") || arnLower.includes("claude-3.5-sonnet")) {
+				// Claude 3.5 Sonnet has 8192 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 8192,
+					contextWindow: 200_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+					supportsComputerUse: true,
+				}
+			} else if (arnLower.includes("claude-3-opus") || arnLower.includes("claude-3.0-opus")) {
+				// Claude 3 Opus has 4096 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 4096,
+					contextWindow: 200_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+				}
+			} else if (arnLower.includes("claude-3-haiku") || arnLower.includes("claude-3.0-haiku")) {
+				// Claude 3 Haiku has 4096 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 4096,
+					contextWindow: 200_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+				}
+			} else if (arnLower.includes("claude-3-5-haiku") || arnLower.includes("claude-3.5-haiku")) {
+				// Claude 3.5 Haiku has 8192 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 8192,
+					contextWindow: 200_000,
+					supportsPromptCache: false,
+					supportsImages: false,
+				}
+			} else if (arnLower.includes("claude")) {
+				// Generic Claude model with conservative token limit
+				modelInfo = {
+					maxTokens: 4096,
+					contextWindow: 128_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+				}
+			} else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) {
+				// Llama 3 models typically have 8192 tokens in Bedrock
+				modelInfo = {
+					maxTokens: 8192,
+					contextWindow: 128_000,
+					supportsPromptCache: false,
+					supportsImages: arnLower.includes("90b") || arnLower.includes("11b"),
+				}
+			} else if (arnLower.includes("nova-pro")) {
+				// Amazon Nova Pro
+				modelInfo = {
+					maxTokens: 5000,
+					contextWindow: 300_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+				}
+			} else {
+				// Default for unknown models or prompt routers
+				modelInfo = {
+					maxTokens: 4096,
+					contextWindow: 128_000,
+					supportsPromptCache: false,
+					supportsImages: true,
+				}
+			}
+
+			// If modelMaxTokens is explicitly set in options, override the default
+			if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
+				modelInfo.maxTokens = this.options.modelMaxTokens
+			}
+
+			return {
+				id: this.options.awsCustomArn,
+				info: modelInfo,
+			}
+		}
+
 		const modelId = this.options.apiModelId
 		const modelId = this.options.apiModelId
 		if (modelId) {
 		if (modelId) {
+			// Special case for custom ARN option
+			if (modelId === "custom-arn") {
+				// This should not happen as we should have awsCustomArn set
+				// but just in case, return a default model
+				return {
+					id: bedrockDefaultModelId,
+					info: bedrockModels[bedrockDefaultModelId],
+				}
+			}
+
 			// For tests, allow any model ID
 			// For tests, allow any model ID
 			if (process.env.NODE_ENV === "test") {
 			if (process.env.NODE_ENV === "test") {
 				return {
 				return {
@@ -239,7 +569,43 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 
 
 			// Handle cross-region inference
 			// Handle cross-region inference
 			let modelId: string
 			let modelId: string
-			if (this.options.awsUseCrossRegionInference) {
+
+			// For custom ARNs, use the ARN directly without modification
+			if (this.options.awsCustomArn) {
+				modelId = modelConfig.id
+				logger.debug("Using custom ARN in completePrompt", {
+					ctx: "bedrock",
+					customArn: this.options.awsCustomArn,
+				})
+
+				// Validate ARN format and check region match
+				const clientRegion = this.client.config.region as string
+				const validation = validateBedrockArn(modelId, clientRegion)
+
+				if (!validation.isValid) {
+					logger.error("Invalid ARN format in completePrompt", {
+						ctx: "bedrock",
+						modelId,
+						errorMessage: validation.errorMessage,
+					})
+					throw new Error(
+						validation.errorMessage ||
+							"Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name",
+					)
+				}
+
+				// Extract region from ARN
+				const arnRegion = validation.arnRegion!
+
+				// Log warning if there's a region mismatch
+				if (validation.errorMessage) {
+					logger.warn(validation.errorMessage, {
+						ctx: "bedrock",
+						arnRegion,
+						clientRegion,
+					})
+				}
+			} else if (this.options.awsUseCrossRegionInference) {
 				let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
 				let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
 				switch (regionPrefix) {
 				switch (regionPrefix) {
 					case "us-":
 					case "us-":
@@ -265,12 +631,21 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 					},
 					},
 				]),
 				]),
 				inferenceConfig: {
 				inferenceConfig: {
-					maxTokens: modelConfig.info.maxTokens || 5000,
+					maxTokens: modelConfig.info.maxTokens || 4096,
 					temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
 					temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
 					topP: 0.1,
 					topP: 0.1,
 				},
 				},
 			}
 			}
 
 
+			// Log the payload for debugging custom ARN issues
+			if (this.options.awsCustomArn) {
+				logger.debug("Bedrock completePrompt request details", {
+					ctx: "bedrock",
+					clientRegion: this.client.config.region,
+					payload: JSON.stringify(payload, null, 2),
+				})
+			}
+
 			const command = new ConverseCommand(payload)
 			const command = new ConverseCommand(payload)
 			const response = await this.client.send(command)
 			const response = await this.client.send(command)
 
 
@@ -282,11 +657,67 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 						return output.content
 						return output.content
 					}
 					}
 				} catch (parseError) {
 				} catch (parseError) {
-					console.error("Failed to parse Bedrock response:", parseError)
+					logger.error("Failed to parse Bedrock response", {
+						ctx: "bedrock",
+						error: parseError instanceof Error ? parseError : String(parseError),
+					})
 				}
 				}
 			}
 			}
 			return ""
 			return ""
 		} catch (error) {
 		} catch (error) {
+			// Enhanced error handling for custom ARN issues
+			if (this.options.awsCustomArn) {
+				logger.error("Error occurred with custom ARN in completePrompt", {
+					ctx: "bedrock",
+					customArn: this.options.awsCustomArn,
+					error: error instanceof Error ? error : String(error),
+				})
+
+				if (error instanceof Error) {
+					const errorMessage = error.message.toLowerCase()
+
+					// Access denied errors
+					if (
+						errorMessage.includes("access") &&
+						(errorMessage.includes("model") || errorMessage.includes("denied"))
+					) {
+						throw new Error(
+							`Bedrock custom ARN error: You don't have access to the model with the specified ARN. Please verify:
+1. The ARN is correct and points to a valid model
+2. Your AWS credentials have permission to access this model (check IAM policies)
+3. The region in the ARN matches the region where the model is deployed
+4. If using a provisioned model, ensure it's active and not in a failed state`,
+						)
+					}
+					// Model not found errors
+					else if (errorMessage.includes("not found") || errorMessage.includes("does not exist")) {
+						throw new Error(
+							`Bedrock custom ARN error: The specified ARN does not exist or is invalid. Please check:
+1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name)
+2. The model exists in the specified region
+3. The account ID in the ARN is correct
+4. The resource type is one of: foundation-model, provisioned-model, or default-prompt-router`,
+						)
+					}
+					// Throttling errors
+					else if (
+						errorMessage.includes("throttl") ||
+						errorMessage.includes("rate") ||
+						errorMessage.includes("limit")
+					) {
+						throw new Error(
+							`Bedrock custom ARN error: Request was throttled or rate limited. Please try:
+1. Reducing the frequency of requests
+2. If using a provisioned model, check its throughput settings
+3. Contact AWS support to request a quota increase if needed`,
+						)
+					} else {
+						throw new Error(`Bedrock custom ARN error: ${error.message}`)
+					}
+				}
+			}
+
+			// Standard error handling
 			if (error instanceof Error) {
 			if (error instanceof Error) {
 				throw new Error(`Bedrock completion error: ${error.message}`)
 				throw new Error(`Bedrock completion error: ${error.message}`)
 			}
 			}

+ 2 - 0
src/shared/api.ts

@@ -39,6 +39,7 @@ export interface ApiHandlerOptions {
 	awspromptCacheId?: string
 	awspromptCacheId?: string
 	awsProfile?: string
 	awsProfile?: string
 	awsUseProfile?: boolean
 	awsUseProfile?: boolean
+	awsCustomArn?: string
 	vertexKeyFile?: string
 	vertexKeyFile?: string
 	vertexJsonCredentials?: string
 	vertexJsonCredentials?: string
 	vertexProjectId?: string
 	vertexProjectId?: string
@@ -99,6 +100,7 @@ export const API_CONFIG_KEYS: GlobalStateKey[] = [
 	// "awspromptCacheId", // NOT exist on GlobalStateKey
 	// "awspromptCacheId", // NOT exist on GlobalStateKey
 	"awsProfile",
 	"awsProfile",
 	"awsUseProfile",
 	"awsUseProfile",
+	"awsCustomArn",
 	"vertexKeyFile",
 	"vertexKeyFile",
 	"vertexJsonCredentials",
 	"vertexJsonCredentials",
 	"vertexProjectId",
 	"vertexProjectId",

+ 1 - 0
src/shared/globalState.ts

@@ -28,6 +28,7 @@ export const GLOBAL_STATE_KEYS = [
 	"awsUseCrossRegionInference",
 	"awsUseCrossRegionInference",
 	"awsProfile",
 	"awsProfile",
 	"awsUseProfile",
 	"awsUseProfile",
+	"awsCustomArn",
 	"vertexKeyFile",
 	"vertexKeyFile",
 	"vertexJsonCredentials",
 	"vertexJsonCredentials",
 	"vertexProjectId",
 	"vertexProjectId",

+ 85 - 4
webview-ui/src/components/settings/ApiOptions.tsx

@@ -41,7 +41,7 @@ import { VSCodeButtonLink } from "../common/VSCodeButtonLink"
 import { ModelInfoView } from "./ModelInfoView"
 import { ModelInfoView } from "./ModelInfoView"
 import { ModelPicker } from "./ModelPicker"
 import { ModelPicker } from "./ModelPicker"
 import { TemperatureControl } from "./TemperatureControl"
 import { TemperatureControl } from "./TemperatureControl"
-import { validateApiConfiguration, validateModelId } from "@/utils/validate"
+import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@/utils/validate"
 import { ApiErrorMessage } from "./ApiErrorMessage"
 import { ApiErrorMessage } from "./ApiErrorMessage"
 import { ThinkingBudget } from "./ThinkingBudget"
 import { ThinkingBudget } from "./ThinkingBudget"
 
 
@@ -1267,14 +1267,82 @@ const ApiOptions = ({
 						</label>
 						</label>
 						<Dropdown
 						<Dropdown
 							id="model-id"
 							id="model-id"
-							value={selectedModelId}
+							value={selectedModelId === "custom-arn" ? "custom-arn" : selectedModelId}
 							onChange={(value) => {
 							onChange={(value) => {
-								setApiConfigurationField("apiModelId", typeof value == "string" ? value : value?.value)
+								const modelValue = typeof value == "string" ? value : value?.value
+								setApiConfigurationField("apiModelId", modelValue)
+
+								// Clear custom ARN if not using custom ARN option
+								if (modelValue !== "custom-arn" && selectedProvider === "bedrock") {
+									setApiConfigurationField("awsCustomArn", "")
+								}
 							}}
 							}}
-							options={selectedProviderModelOptions}
+							options={[
+								...selectedProviderModelOptions,
+								...(selectedProvider === "bedrock"
+									? [{ value: "custom-arn", label: "Use custom ARN..." }]
+									: []),
+							]}
 							className="w-full"
 							className="w-full"
 						/>
 						/>
 					</div>
 					</div>
+
+					{selectedProvider === "bedrock" && selectedModelId === "custom-arn" && (
+						<>
+							<VSCodeTextField
+								value={apiConfiguration?.awsCustomArn || ""}
+								onInput={(e) => {
+									const value = (e.target as HTMLInputElement).value
+									setApiConfigurationField("awsCustomArn", value)
+								}}
+								placeholder="Enter ARN (e.g. arn:aws:bedrock:us-east-1:123456789012:foundation-model/my-model)"
+								className="w-full">
+								<span className="font-medium">Custom ARN</span>
+							</VSCodeTextField>
+							<div className="text-sm text-vscode-descriptionForeground -mt-2">
+								Enter a valid AWS Bedrock ARN for the model you want to use. Format examples:
+								<ul className="list-disc pl-5 mt-1">
+									<li>
+										arn:aws:bedrock:us-east-1:123456789012:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0
+									</li>
+									<li>
+										arn:aws:bedrock:us-west-2:123456789012:provisioned-model/my-provisioned-model
+									</li>
+									<li>
+										arn:aws:bedrock:us-east-1:123456789012:default-prompt-router/anthropic.claude:1
+									</li>
+								</ul>
+								Make sure the region in the ARN matches your selected AWS Region above.
+							</div>
+							{apiConfiguration?.awsCustomArn &&
+								(() => {
+									const validation = validateBedrockArn(
+										apiConfiguration.awsCustomArn,
+										apiConfiguration.awsRegion,
+									)
+
+									if (!validation.isValid) {
+										return (
+											<div className="text-sm text-vscode-errorForeground mt-2">
+												{validation.errorMessage ||
+													"Invalid ARN format. Please check the examples above."}
+											</div>
+										)
+									}
+
+									if (validation.errorMessage) {
+										return (
+											<div className="text-sm text-vscode-errorForeground mt-2">
+												{validation.errorMessage}
+											</div>
+										)
+									}
+
+									return null
+								})()}
+							=======
+						</>
+					)}
 					<ModelInfoView
 					<ModelInfoView
 						selectedModelId={selectedModelId}
 						selectedModelId={selectedModelId}
 						modelInfo={selectedModelInfo}
 						modelInfo={selectedModelInfo}
@@ -1333,6 +1401,19 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 		case "anthropic":
 		case "anthropic":
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 		case "bedrock":
 		case "bedrock":
+			// Special case for custom ARN
+			if (modelId === "custom-arn") {
+				return {
+					selectedProvider: provider,
+					selectedModelId: "custom-arn",
+					selectedModelInfo: {
+						maxTokens: 5000,
+						contextWindow: 128_000,
+						supportsPromptCache: false,
+						supportsImages: true,
+					},
+				}
+			}
 			return getProviderData(bedrockModels, bedrockDefaultModelId)
 			return getProviderData(bedrockModels, bedrockDefaultModelId)
 		case "vertex":
 		case "vertex":
 			return getProviderData(vertexModels, vertexDefaultModelId)
 			return getProviderData(vertexModels, vertexDefaultModelId)

+ 38 - 0
webview-ui/src/utils/validate.ts

@@ -80,6 +80,44 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 
 
 	return undefined
 	return undefined
 }
 }
+/**
+ * Validates an AWS Bedrock ARN format and optionally checks if the region in the ARN matches the provided region
+ * @param arn The ARN string to validate
+ * @param region Optional region to check against the ARN's region
+ * @returns An object with validation results: { isValid, arnRegion, errorMessage }
+ */
+export function validateBedrockArn(arn: string, region?: string) {
+	// Validate ARN format
+	const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/
+	const match = arn.match(arnRegex)
+
+	if (!match) {
+		return {
+			isValid: false,
+			arnRegion: undefined,
+			errorMessage: "Invalid ARN format. Please check the format requirements.",
+		}
+	}
+
+	// Extract region from ARN
+	const arnRegion = match[1]
+
+	// Check if region in ARN matches provided region (if specified)
+	if (region && arnRegion !== region) {
+		return {
+			isValid: true,
+			arnRegion,
+			errorMessage: `Warning: The region in your ARN (${arnRegion}) does not match your selected region (${region}). This may cause access issues. The provider will use the region from the ARN.`,
+		}
+	}
+
+	// ARN is valid and region matches (or no region was provided to check against)
+	return {
+		isValid: true,
+		arnRegion,
+		errorMessage: undefined,
+	}
+}
 
 
 export function validateModelId(
 export function validateModelId(
 	apiConfiguration?: ApiConfiguration,
 	apiConfiguration?: ApiConfiguration,