|
@@ -3,66 +3,42 @@ import {
|
|
|
ConverseStreamCommand,
|
|
ConverseStreamCommand,
|
|
|
ConverseCommand,
|
|
ConverseCommand,
|
|
|
BedrockRuntimeClientConfig,
|
|
BedrockRuntimeClientConfig,
|
|
|
- ConverseStreamCommandOutput,
|
|
|
|
|
} from "@aws-sdk/client-bedrock-runtime"
|
|
} from "@aws-sdk/client-bedrock-runtime"
|
|
|
import { fromIni } from "@aws-sdk/credential-providers"
|
|
import { fromIni } from "@aws-sdk/credential-providers"
|
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
|
import { SingleCompletionHandler } from "../"
|
|
import { SingleCompletionHandler } from "../"
|
|
|
import {
|
|
import {
|
|
|
- ApiHandlerOptions,
|
|
|
|
|
BedrockModelId,
|
|
BedrockModelId,
|
|
|
- ModelInfo,
|
|
|
|
|
|
|
+ ModelInfo as SharedModelInfo,
|
|
|
bedrockDefaultModelId,
|
|
bedrockDefaultModelId,
|
|
|
bedrockModels,
|
|
bedrockModels,
|
|
|
bedrockDefaultPromptRouterModelId,
|
|
bedrockDefaultPromptRouterModelId,
|
|
|
} from "../../shared/api"
|
|
} from "../../shared/api"
|
|
|
|
|
+import { ProviderSettings } from "../../schemas"
|
|
|
import { ApiStream } from "../transform/stream"
|
|
import { ApiStream } from "../transform/stream"
|
|
|
-import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
|
|
|
|
|
import { BaseProvider } from "./base-provider"
|
|
import { BaseProvider } from "./base-provider"
|
|
|
import { logger } from "../../utils/logging"
|
|
import { logger } from "../../utils/logging"
|
|
|
-
|
|
|
|
|
-/**
|
|
|
|
|
- * Validates an AWS Bedrock ARN format and optionally checks if the region in the ARN matches the provided region
|
|
|
|
|
- * @param arn The ARN string to validate
|
|
|
|
|
- * @param region Optional region to check against the ARN's region
|
|
|
|
|
- * @returns An object with validation results: { isValid, arnRegion, errorMessage }
|
|
|
|
|
- */
|
|
|
|
|
-function validateBedrockArn(arn: string, region?: string) {
|
|
|
|
|
- // Validate ARN format
|
|
|
|
|
- const arnRegex =
|
|
|
|
|
- /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router|prompt-router|inference-profile|application-inference-profile)\/(.+)$/
|
|
|
|
|
- const match = arn.match(arnRegex)
|
|
|
|
|
-
|
|
|
|
|
- if (!match) {
|
|
|
|
|
- return {
|
|
|
|
|
- isValid: false,
|
|
|
|
|
- arnRegion: undefined,
|
|
|
|
|
- errorMessage:
|
|
|
|
|
- "Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name",
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Extract region from ARN
|
|
|
|
|
- const arnRegion = match[1]
|
|
|
|
|
-
|
|
|
|
|
- // Check if region in ARN matches provided region (if specified)
|
|
|
|
|
- if (region && arnRegion !== region) {
|
|
|
|
|
- return {
|
|
|
|
|
- isValid: true,
|
|
|
|
|
- arnRegion,
|
|
|
|
|
- errorMessage: `Warning: The region in your ARN (${arnRegion}) does not match your selected region (${region}). This may cause access issues. The provider will use the region from the ARN.`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // ARN is valid and region matches (or no region was provided to check against)
|
|
|
|
|
- return {
|
|
|
|
|
- isValid: true,
|
|
|
|
|
- arnRegion,
|
|
|
|
|
- errorMessage: undefined,
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
|
|
+import { Message, SystemContentBlock } from "@aws-sdk/client-bedrock-runtime"
|
|
|
|
|
+// New cache-related imports
|
|
|
|
|
+import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
|
|
|
|
|
+import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types"
|
|
|
|
|
+import { AWS_BEDROCK_REGION_INFO } from "../../shared/aws_regions"
|
|
|
|
|
|
|
|
const BEDROCK_DEFAULT_TEMPERATURE = 0.3
|
|
const BEDROCK_DEFAULT_TEMPERATURE = 0.3
|
|
|
|
|
+const BEDROCK_MAX_TOKENS = 4096
|
|
|
|
|
+
|
|
|
|
|
+/************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * TYPES
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
+
|
|
|
|
|
+// Define interface for Bedrock inference config
|
|
|
|
|
+interface BedrockInferenceConfig {
|
|
|
|
|
+ maxTokens: number
|
|
|
|
|
+ temperature: number
|
|
|
|
|
+ topP: number
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
// Define types for stream events based on AWS SDK
|
|
// Define types for stream events based on AWS SDK
|
|
|
export interface StreamEvent {
|
|
export interface StreamEvent {
|
|
@@ -90,11 +66,17 @@ export interface StreamEvent {
|
|
|
inputTokens: number
|
|
inputTokens: number
|
|
|
outputTokens: number
|
|
outputTokens: number
|
|
|
totalTokens?: number // Made optional since we don't use it
|
|
totalTokens?: number // Made optional since we don't use it
|
|
|
|
|
+ // New cache-related fields
|
|
|
|
|
+ cacheReadInputTokens?: number
|
|
|
|
|
+ cacheWriteInputTokens?: number
|
|
|
|
|
+ cacheReadInputTokenCount?: number
|
|
|
|
|
+ cacheWriteInputTokenCount?: number
|
|
|
}
|
|
}
|
|
|
metrics?: {
|
|
metrics?: {
|
|
|
latencyMs: number
|
|
latencyMs: number
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ // New trace field for prompt router
|
|
|
trace?: {
|
|
trace?: {
|
|
|
promptRouter?: {
|
|
promptRouter?: {
|
|
|
invokedModelId?: string
|
|
invokedModelId?: string
|
|
@@ -102,49 +84,85 @@ export interface StreamEvent {
|
|
|
inputTokens: number
|
|
inputTokens: number
|
|
|
outputTokens: number
|
|
outputTokens: number
|
|
|
totalTokens?: number // Made optional since we don't use it
|
|
totalTokens?: number // Made optional since we don't use it
|
|
|
|
|
+ // New cache-related fields
|
|
|
|
|
+ cacheReadTokens?: number
|
|
|
|
|
+ cacheWriteTokens?: number
|
|
|
|
|
+ cacheReadInputTokenCount?: number
|
|
|
|
|
+ cacheWriteInputTokenCount?: number
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Type for usage information in stream events
|
|
|
|
|
+export type UsageType = {
|
|
|
|
|
+ inputTokens?: number
|
|
|
|
|
+ outputTokens?: number
|
|
|
|
|
+ cacheReadInputTokens?: number
|
|
|
|
|
+ cacheWriteInputTokens?: number
|
|
|
|
|
+ cacheReadInputTokenCount?: number
|
|
|
|
|
+ cacheWriteInputTokenCount?: number
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * PROVIDER
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
+
|
|
|
export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
|
|
export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
|
|
|
- protected options: ApiHandlerOptions
|
|
|
|
|
|
|
+ protected options: ProviderSettings
|
|
|
private client: BedrockRuntimeClient
|
|
private client: BedrockRuntimeClient
|
|
|
|
|
+ private arnInfo: any
|
|
|
|
|
|
|
|
- private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = {
|
|
|
|
|
- id: "",
|
|
|
|
|
- info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false },
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- constructor(options: ApiHandlerOptions) {
|
|
|
|
|
|
|
+ constructor(options: ProviderSettings) {
|
|
|
super()
|
|
super()
|
|
|
this.options = options
|
|
this.options = options
|
|
|
|
|
+ let region = this.options.awsRegion
|
|
|
|
|
|
|
|
- // Extract region from custom ARN if provided
|
|
|
|
|
- let region = this.options.awsRegion || "us-east-1"
|
|
|
|
|
-
|
|
|
|
|
- // If using custom ARN, extract region from the ARN
|
|
|
|
|
|
|
+ // process the various user input options, be opinionated about the intent of the options
|
|
|
|
|
+ // and determine the model to use during inference and for cost calculations
|
|
|
|
|
+ // There are variations on ARN strings that can be entered making the conditional logic
|
|
|
|
|
+ // more involved than the non-ARN branch of logic
|
|
|
if (this.options.awsCustomArn) {
|
|
if (this.options.awsCustomArn) {
|
|
|
- const validation = validateBedrockArn(this.options.awsCustomArn, region)
|
|
|
|
|
|
|
+ this.arnInfo = this.parseArn(this.options.awsCustomArn, region)
|
|
|
|
|
|
|
|
- if (validation.isValid && validation.arnRegion) {
|
|
|
|
|
- // If there's a region mismatch warning, log it and use the ARN region
|
|
|
|
|
- if (validation.errorMessage) {
|
|
|
|
|
- logger.info(
|
|
|
|
|
- `Region mismatch: Selected region is ${region}, but ARN region is ${validation.arnRegion}. Using ARN region.`,
|
|
|
|
|
- {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- selectedRegion: region,
|
|
|
|
|
- arnRegion: validation.arnRegion,
|
|
|
|
|
- },
|
|
|
|
|
- )
|
|
|
|
|
- region = validation.arnRegion
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (!this.arnInfo.isValid) {
|
|
|
|
|
+ logger.error("Invalid ARN format", {
|
|
|
|
|
+ ctx: "bedrock",
|
|
|
|
|
+ errorMessage: this.arnInfo.errorMessage,
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ // Throw a consistent error with a prefix that can be detected by callers
|
|
|
|
|
+ const errorMessage =
|
|
|
|
|
+ this.arnInfo.errorMessage ||
|
|
|
|
|
+ "Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name"
|
|
|
|
|
+ throw new Error("INVALID_ARN_FORMAT:" + errorMessage)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (this.arnInfo.region && this.arnInfo.region !== this.options.awsRegion) {
|
|
|
|
|
+ // Log if there's a region mismatch between the ARN and the region selected by the user
|
|
|
|
|
+ // We will use the ARN's region, so execution can continue, but log an info statement.
|
|
|
|
|
+ // Log a warning if there's a region mismatch between the ARN and the region selected by the user
|
|
|
|
|
+ // We will use the ARN's region, so execution can continue, but log an info statement.
|
|
|
|
|
+ logger.info(this.arnInfo.errorMessage, {
|
|
|
|
|
+ ctx: "bedrock",
|
|
|
|
|
+ selectedRegion: this.options.awsRegion,
|
|
|
|
|
+ arnRegion: this.arnInfo.region,
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ this.options.awsRegion = this.arnInfo.region
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ this.options.apiModelId = this.arnInfo.modelId
|
|
|
|
|
+ if (this.arnInfo.awsUseCrossRegionInference) this.options.awsUseCrossRegionInference = true
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE
|
|
|
|
|
+ this.costModelConfig = this.getModel()
|
|
|
|
|
+
|
|
|
const clientConfig: BedrockRuntimeClientConfig = {
|
|
const clientConfig: BedrockRuntimeClientConfig = {
|
|
|
- region: region,
|
|
|
|
|
|
|
+ region: this.options.awsRegion,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (this.options.awsUseProfile && this.options.awsProfile) {
|
|
if (this.options.awsUseProfile && this.options.awsProfile) {
|
|
@@ -167,98 +185,60 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
|
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
|
|
let modelConfig = this.getModel()
|
|
let modelConfig = this.getModel()
|
|
|
// Handle cross-region inference
|
|
// Handle cross-region inference
|
|
|
- let modelId: string
|
|
|
|
|
-
|
|
|
|
|
- // For custom ARNs, use the ARN directly without modification
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
-
|
|
|
|
|
- // Validate ARN format and check region match
|
|
|
|
|
- const clientRegion = this.client.config.region as string
|
|
|
|
|
- const validation = validateBedrockArn(modelId, clientRegion)
|
|
|
|
|
|
|
+ const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig))
|
|
|
|
|
+
|
|
|
|
|
+ // Generate a conversation ID from the first message (its role plus the first 20 characters of string content) to maintain cache consistency
|
|
|
|
|
+ const conversationId =
|
|
|
|
|
+ messages.length > 0
|
|
|
|
|
+ ? `conv_${messages[0].role}_${
|
|
|
|
|
+ typeof messages[0].content === "string"
|
|
|
|
|
+ ? messages[0].content.substring(0, 20)
|
|
|
|
|
+ : "complex_content"
|
|
|
|
|
+ }`
|
|
|
|
|
+ : "default_conversation"
|
|
|
|
|
+
|
|
|
|
|
+ // Convert messages to Bedrock format, passing the model info and conversation ID
|
|
|
|
|
+ const formatted = this.convertToBedrockConverseMessages(
|
|
|
|
|
+ messages,
|
|
|
|
|
+ systemPrompt,
|
|
|
|
|
+ usePromptCache,
|
|
|
|
|
+ modelConfig.info,
|
|
|
|
|
+ conversationId,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- if (!validation.isValid) {
|
|
|
|
|
- logger.error("Invalid ARN format", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- modelId,
|
|
|
|
|
- errorMessage: validation.errorMessage,
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error: ${validation.errorMessage}`,
|
|
|
|
|
- }
|
|
|
|
|
- yield { type: "usage", inputTokens: 0, outputTokens: 0 }
|
|
|
|
|
- throw new Error("Invalid ARN format")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Extract region from ARN
|
|
|
|
|
- const arnRegion = validation.arnRegion!
|
|
|
|
|
-
|
|
|
|
|
- // Log warning if there's a region mismatch
|
|
|
|
|
- if (validation.errorMessage) {
|
|
|
|
|
- logger.warn(validation.errorMessage, {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- arnRegion,
|
|
|
|
|
- clientRegion,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
- } else if (this.options.awsUseCrossRegionInference) {
|
|
|
|
|
- let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
|
|
|
|
- switch (regionPrefix) {
|
|
|
|
|
- case "us-":
|
|
|
|
|
- modelId = `us.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- case "eu-":
|
|
|
|
|
- modelId = `eu.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- case "ap-":
|
|
|
|
|
- modelId = `apac.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- default:
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
|
|
+ // Construct the payload
|
|
|
|
|
+ const inferenceConfig: BedrockInferenceConfig = {
|
|
|
|
|
+ maxTokens: modelConfig.info.maxTokens as number,
|
|
|
|
|
+ temperature: this.options.modelTemperature as number,
|
|
|
|
|
+ topP: 0.1,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Convert messages to Bedrock format
|
|
|
|
|
- const formattedMessages = convertToBedrockConverseMessages(messages)
|
|
|
|
|
-
|
|
|
|
|
- // Construct the payload
|
|
|
|
|
const payload = {
|
|
const payload = {
|
|
|
- modelId,
|
|
|
|
|
- messages: formattedMessages,
|
|
|
|
|
- system: [{ text: systemPrompt }],
|
|
|
|
|
- inferenceConfig: {
|
|
|
|
|
- maxTokens: modelConfig.info.maxTokens || 4096,
|
|
|
|
|
- temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
|
|
|
|
|
- topP: 0.1,
|
|
|
|
|
- ...(this.options.awsUsePromptCache
|
|
|
|
|
- ? {
|
|
|
|
|
- promptCache: {
|
|
|
|
|
- promptCacheId: this.options.awspromptCacheId || "",
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
- : {}),
|
|
|
|
|
- },
|
|
|
|
|
|
|
+ modelId: modelConfig.id,
|
|
|
|
|
+ messages: formatted.messages,
|
|
|
|
|
+ system: formatted.system,
|
|
|
|
|
+ inferenceConfig,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // Create AbortController with 10 minute timeout
|
|
|
|
|
+ const controller = new AbortController()
|
|
|
|
|
+ let timeoutId: NodeJS.Timeout | undefined
|
|
|
|
|
+
|
|
|
try {
|
|
try {
|
|
|
- // Log the payload for debugging custom ARN issues
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- logger.debug("Using custom ARN for Bedrock request", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- clientRegion: this.client.config.region,
|
|
|
|
|
- payload: JSON.stringify(payload, null, 2),
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ timeoutId = setTimeout(
|
|
|
|
|
+ () => {
|
|
|
|
|
+ controller.abort()
|
|
|
|
|
+ },
|
|
|
|
|
+ 10 * 60 * 1000,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
const command = new ConverseStreamCommand(payload)
|
|
const command = new ConverseStreamCommand(payload)
|
|
|
- const response = await this.client.send(command)
|
|
|
|
|
|
|
+ const response = await this.client.send(command, {
|
|
|
|
|
+ abortSignal: controller.signal,
|
|
|
|
|
+ })
|
|
|
|
|
|
|
|
if (!response.stream) {
|
|
if (!response.stream) {
|
|
|
|
|
+ clearTimeout(timeoutId)
|
|
|
throw new Error("No stream available in the response")
|
|
throw new Error("No stream available in the response")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -276,54 +256,63 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Handle metadata events first.
|
|
|
|
|
- if (streamEvent?.metadata?.usage) {
|
|
|
|
|
|
|
+ // Handle metadata events first
|
|
|
|
|
+ if (streamEvent.metadata?.usage) {
|
|
|
|
|
+ const usage = (streamEvent.metadata?.usage || {}) as UsageType
|
|
|
|
|
+
|
|
|
|
|
+ // Check both field naming conventions for cache tokens
|
|
|
|
|
+ const cacheReadTokens = usage.cacheReadInputTokens || usage.cacheReadInputTokenCount || 0
|
|
|
|
|
+ const cacheWriteTokens = usage.cacheWriteInputTokens || usage.cacheWriteInputTokenCount || 0
|
|
|
|
|
+
|
|
|
|
|
+ // Always include all available token information
|
|
|
yield {
|
|
yield {
|
|
|
type: "usage",
|
|
type: "usage",
|
|
|
- inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
|
|
|
|
- outputTokens: streamEvent.metadata.usage.outputTokens || 0,
|
|
|
|
|
|
|
+ inputTokens: usage.inputTokens || 0,
|
|
|
|
|
+ outputTokens: usage.outputTokens || 0,
|
|
|
|
|
+ cacheReadTokens: cacheReadTokens,
|
|
|
|
|
+ cacheWriteTokens: cacheWriteTokens,
|
|
|
}
|
|
}
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (streamEvent?.trace?.promptRouter?.invokedModelId) {
|
|
if (streamEvent?.trace?.promptRouter?.invokedModelId) {
|
|
|
try {
|
|
try {
|
|
|
- const invokedModelId = streamEvent.trace.promptRouter.invokedModelId
|
|
|
|
|
- // Create a platform-independent regex that doesn't use forward slash as both delimiter and matcher
|
|
|
|
|
- const modelMatch = invokedModelId.match(new RegExp("[\\/\\\\]([^\\/\\\\]+)(?::|$)"))
|
|
|
|
|
- if (modelMatch && modelMatch[1]) {
|
|
|
|
|
- let modelName = modelMatch[1]
|
|
|
|
|
-
|
|
|
|
|
- // Get a new modelConfig from getModel() using invokedModelId.. remove the region first
|
|
|
|
|
- let region = modelName.slice(0, 3)
|
|
|
|
|
-
|
|
|
|
|
- // Check for all region prefixes (us., eu., and apac prefix)
|
|
|
|
|
- if (region === "us." || region === "eu.") modelName = modelName.slice(3)
|
|
|
|
|
- else if (modelName.startsWith("apac.")) modelName = modelName.slice(5)
|
|
|
|
|
- this.costModelConfig = this.getModelByName(modelName)
|
|
|
|
|
-
|
|
|
|
|
- // Log successful model extraction to help with debugging
|
|
|
|
|
- logger.debug("Successfully extracted model from invokedModelId", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- invokedModelId,
|
|
|
|
|
- extractedModelName: modelName,
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ //update the in-use model info to be based on the invoked Model Id for the router
|
|
|
|
|
+ //so that pricing, context window, caching etc have values that can be used
|
|
|
|
|
+ //However, we want to keep the id of the model to be the ID for the router for
|
|
|
|
|
+ //subsequent requests so they are sent back through the router
|
|
|
|
|
+ let invokedArnInfo = this.parseArn(streamEvent.trace.promptRouter.invokedModelId)
|
|
|
|
|
+ let invokedModel = this.getModelById(invokedArnInfo.modelId as string, invokedArnInfo.modelType)
|
|
|
|
|
+ if (invokedModel) {
|
|
|
|
|
+ invokedModel.id = modelConfig.id
|
|
|
|
|
+ this.costModelConfig = invokedModel
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Handle metadata events for the promptRouter.
|
|
// Handle metadata events for the promptRouter.
|
|
|
if (streamEvent?.trace?.promptRouter?.usage) {
|
|
if (streamEvent?.trace?.promptRouter?.usage) {
|
|
|
|
|
+ const routerUsage = streamEvent.trace.promptRouter.usage
|
|
|
|
|
+
|
|
|
|
|
+ // Check both field naming conventions for cache tokens
|
|
|
|
|
+ const cacheReadTokens =
|
|
|
|
|
+ routerUsage.cacheReadTokens || routerUsage.cacheReadInputTokenCount || 0
|
|
|
|
|
+ const cacheWriteTokens =
|
|
|
|
|
+ routerUsage.cacheWriteTokens || routerUsage.cacheWriteInputTokenCount || 0
|
|
|
|
|
+
|
|
|
yield {
|
|
yield {
|
|
|
type: "usage",
|
|
type: "usage",
|
|
|
- inputTokens: streamEvent?.trace?.promptRouter?.usage?.inputTokens || 0,
|
|
|
|
|
- outputTokens: streamEvent?.trace?.promptRouter?.usage?.outputTokens || 0,
|
|
|
|
|
|
|
+ inputTokens: routerUsage.inputTokens || 0,
|
|
|
|
|
+ outputTokens: routerUsage.outputTokens || 0,
|
|
|
|
|
+ cacheReadTokens: cacheReadTokens,
|
|
|
|
|
+ cacheWriteTokens: cacheWriteTokens,
|
|
|
}
|
|
}
|
|
|
- continue
|
|
|
|
|
}
|
|
}
|
|
|
} catch (error) {
|
|
} catch (error) {
|
|
|
logger.error("Error handling Bedrock invokedModelId", {
|
|
logger.error("Error handling Bedrock invokedModelId", {
|
|
|
ctx: "bedrock",
|
|
ctx: "bedrock",
|
|
|
error: error instanceof Error ? error : String(error),
|
|
error: error instanceof Error ? error : String(error),
|
|
|
})
|
|
})
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ continue
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -354,391 +343,610 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ // Clear timeout after stream completes
|
|
|
|
|
+ clearTimeout(timeoutId)
|
|
|
} catch (error: unknown) {
|
|
} catch (error: unknown) {
|
|
|
- logger.error("Bedrock Runtime API Error", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- error: error instanceof Error ? error : String(error),
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ // Clear timeout on error
|
|
|
|
|
+ clearTimeout(timeoutId)
|
|
|
|
|
+
|
|
|
|
|
+ // Use the extracted error handling method for all errors
|
|
|
|
|
+ const errorChunks = this.handleBedrockError(error, true) // true for streaming context
|
|
|
|
|
+ // Yield each chunk individually to ensure type compatibility
|
|
|
|
|
+ for (const chunk of errorChunks) {
|
|
|
|
|
+ yield chunk as any // Cast to any to bypass type checking since we know the structure is correct
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Enhanced error handling for custom ARN issues
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- logger.error("Error occurred with custom ARN", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ // Re-throw the error
|
|
|
|
|
+ if (error instanceof Error) {
|
|
|
|
|
+ throw error
|
|
|
|
|
+ } else {
|
|
|
|
|
+ throw new Error("An unknown error occurred")
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Check for common ARN-related errors
|
|
|
|
|
- if (error instanceof Error) {
|
|
|
|
|
- const errorMessage = error.message.toLowerCase()
|
|
|
|
|
|
|
+ async completePrompt(prompt: string): Promise<string> {
|
|
|
|
|
+ try {
|
|
|
|
|
+ const modelConfig = this.getModel()
|
|
|
|
|
|
|
|
- // Access denied errors
|
|
|
|
|
- if (
|
|
|
|
|
- errorMessage.includes("access") &&
|
|
|
|
|
- (errorMessage.includes("model") || errorMessage.includes("denied"))
|
|
|
|
|
- ) {
|
|
|
|
|
- logger.error("Permissions issue with custom ARN", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- errorType: "access_denied",
|
|
|
|
|
- clientRegion: this.client.config.region,
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error: You don't have access to the model with the specified ARN. Please verify:
|
|
|
|
|
-
|
|
|
|
|
-1. The ARN is correct and points to a valid model
|
|
|
|
|
-2. Your AWS credentials have permission to access this model (check IAM policies)
|
|
|
|
|
-3. The region in the ARN (${this.client.config.region}) matches the region where the model is deployed
|
|
|
|
|
-4. If using a provisioned model, ensure it's active and not in a failed state
|
|
|
|
|
-5. If using a custom model, ensure your account has been granted access to it`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- // Model not found errors
|
|
|
|
|
- else if (errorMessage.includes("not found") || errorMessage.includes("does not exist")) {
|
|
|
|
|
- logger.error("Invalid ARN or non-existent model", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- errorType: "not_found",
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error: The specified ARN does not exist or is invalid. Please check:
|
|
|
|
|
|
|
+ const inferenceConfig: BedrockInferenceConfig = {
|
|
|
|
|
+ maxTokens: modelConfig.info.maxTokens as number,
|
|
|
|
|
+ temperature: this.options.modelTemperature as number,
|
|
|
|
|
+ topP: 0.1,
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
-1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name)
|
|
|
|
|
-2. The model exists in the specified region
|
|
|
|
|
-3. The account ID in the ARN is correct
|
|
|
|
|
-4. The resource type is one of: foundation-model, provisioned-model, or default-prompt-router`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- // Throttling errors
|
|
|
|
|
- else if (
|
|
|
|
|
- errorMessage.includes("throttl") ||
|
|
|
|
|
- errorMessage.includes("rate") ||
|
|
|
|
|
- errorMessage.includes("limit")
|
|
|
|
|
- ) {
|
|
|
|
|
- logger.error("Throttling or rate limit issue with Bedrock", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- errorType: "throttling",
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error: Request was throttled or rate limited. Please try:
|
|
|
|
|
|
|
+ // For completePrompt, derive the conversation ID from the first 20 characters of the prompt (prompts sharing that prefix share an ID)
|
|
|
|
|
+ const conversationId = `prompt_${prompt.substring(0, 20)}`
|
|
|
|
|
|
|
|
-1. Reducing the frequency of requests
|
|
|
|
|
-2. If using a provisioned model, check its throughput settings
|
|
|
|
|
-3. Contact AWS support to request a quota increase if needed`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- // Other errors
|
|
|
|
|
- else {
|
|
|
|
|
- logger.error("Unspecified error with custom ARN", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- errorStack: error.stack,
|
|
|
|
|
- errorMessage: error.message,
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error with custom ARN: ${error.message}
|
|
|
|
|
-
|
|
|
|
|
-Please check:
|
|
|
|
|
-1. Your AWS credentials are valid and have the necessary permissions
|
|
|
|
|
-2. The ARN format is correct
|
|
|
|
|
-3. The region in the ARN matches the region where you're making the request`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Unknown error occurred with custom ARN. Please check your AWS credentials and ARN format.`,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- // Standard error handling for non-ARN cases
|
|
|
|
|
- if (error instanceof Error) {
|
|
|
|
|
- logger.error("Standard Bedrock error", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- errorStack: error.stack,
|
|
|
|
|
- errorMessage: error.message,
|
|
|
|
|
- })
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: `Error: ${error.message}`,
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- logger.error("Unknown Bedrock error", {
|
|
|
|
|
|
|
+ const payload = {
|
|
|
|
|
+ modelId: modelConfig.id,
|
|
|
|
|
+ messages: this.convertToBedrockConverseMessages(
|
|
|
|
|
+ [
|
|
|
|
|
+ {
|
|
|
|
|
+ role: "user",
|
|
|
|
|
+ content: prompt,
|
|
|
|
|
+ },
|
|
|
|
|
+ ],
|
|
|
|
|
+ undefined,
|
|
|
|
|
+ false,
|
|
|
|
|
+ modelConfig.info,
|
|
|
|
|
+ conversationId,
|
|
|
|
|
+ ).messages,
|
|
|
|
|
+ inferenceConfig,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const command = new ConverseCommand(payload)
|
|
|
|
|
+ const response = await this.client.send(command)
|
|
|
|
|
+
|
|
|
|
|
+ if (
|
|
|
|
|
+ response?.output?.message?.content &&
|
|
|
|
|
+ response.output.message.content.length > 0 &&
|
|
|
|
|
+ response.output.message.content[0].text &&
|
|
|
|
|
+ response.output.message.content[0].text.trim().length > 0
|
|
|
|
|
+ ) {
|
|
|
|
|
+ try {
|
|
|
|
|
+ return response.output.message.content[0].text
|
|
|
|
|
+ } catch (parseError) {
|
|
|
|
|
+ logger.error("Failed to parse Bedrock response", {
|
|
|
ctx: "bedrock",
|
|
ctx: "bedrock",
|
|
|
- error: String(error),
|
|
|
|
|
|
|
+ error: parseError instanceof Error ? parseError : String(parseError),
|
|
|
})
|
|
})
|
|
|
- yield {
|
|
|
|
|
- type: "text",
|
|
|
|
|
- text: "An unknown error occurred",
|
|
|
|
|
- }
|
|
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ return ""
|
|
|
|
|
+ } catch (error) {
|
|
|
|
|
+ // Use the extracted error handling method for all errors
|
|
|
|
|
+ const errorResult = this.handleBedrockError(error, false) // false for non-streaming context
|
|
|
|
|
+ // Since we're in a non-streaming context, we know the result is a string
|
|
|
|
|
+ const errorMessage = errorResult as string
|
|
|
|
|
+ throw new Error(errorMessage)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Convert Anthropic messages to Bedrock Converse format
|
|
|
|
|
+ */
|
|
|
|
|
+ private convertToBedrockConverseMessages(
|
|
|
|
|
+ anthropicMessages: Anthropic.Messages.MessageParam[] | { role: string; content: string }[],
|
|
|
|
|
+ systemMessage?: string,
|
|
|
|
|
+ usePromptCache: boolean = false,
|
|
|
|
|
+ modelInfo?: any,
|
|
|
|
|
+ conversationId?: string, // Optional conversation ID to track cache points across messages
|
|
|
|
|
+ ): { system: SystemContentBlock[]; messages: Message[] } {
|
|
|
|
|
+ // Convert model info to expected format
|
|
|
|
|
+ const cacheModelInfo: CacheModelInfo = {
|
|
|
|
|
+ maxTokens: modelInfo?.maxTokens || 8192,
|
|
|
|
|
+ contextWindow: modelInfo?.contextWindow || 200_000,
|
|
|
|
|
+ supportsPromptCache: modelInfo?.supportsPromptCache || false,
|
|
|
|
|
+ maxCachePoints: modelInfo?.maxCachePoints || 0,
|
|
|
|
|
+ minTokensPerCachePoint: modelInfo?.minTokensPerCachePoint || 50,
|
|
|
|
|
+ cachableFields: modelInfo?.cachableFields || [],
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Always yield usage info
|
|
|
|
|
- yield {
|
|
|
|
|
- type: "usage",
|
|
|
|
|
- inputTokens: 0,
|
|
|
|
|
- outputTokens: 0,
|
|
|
|
|
|
|
+ // Clean messages by removing any existing cache points
|
|
|
|
|
+ const cleanedMessages = anthropicMessages.map((msg) => {
|
|
|
|
|
+ if (typeof msg.content === "string") {
|
|
|
|
|
+ return msg
|
|
|
|
|
+ }
|
|
|
|
|
+ const cleaned = {
|
|
|
|
|
+ ...msg,
|
|
|
|
|
+ content: this.removeCachePoints(msg.content),
|
|
|
}
|
|
}
|
|
|
|
|
+ return cleaned
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ // Get previous cache point placements for this conversation if available
|
|
|
|
|
+ const previousPlacements =
|
|
|
|
|
+ conversationId && this.previousCachePointPlacements[conversationId]
|
|
|
|
|
+ ? this.previousCachePointPlacements[conversationId]
|
|
|
|
|
+ : undefined
|
|
|
|
|
+
|
|
|
|
|
+ // Create config for cache strategy
|
|
|
|
|
+ const config = {
|
|
|
|
|
+ modelInfo: cacheModelInfo,
|
|
|
|
|
+ systemPrompt: systemMessage,
|
|
|
|
|
+ messages: cleanedMessages as Anthropic.Messages.MessageParam[],
|
|
|
|
|
+ usePromptCache,
|
|
|
|
|
+ previousCachePointPlacements: previousPlacements,
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Re-throw the error
|
|
|
|
|
- if (error instanceof Error) {
|
|
|
|
|
- throw error
|
|
|
|
|
- } else {
|
|
|
|
|
- throw new Error("An unknown error occurred")
|
|
|
|
|
|
|
+ // Determine optimal cache points
|
|
|
|
|
+ let strategy = new MultiPointStrategy(config)
|
|
|
|
|
+ const result = strategy.determineOptimalCachePoints()
|
|
|
|
|
+
|
|
|
|
|
+ // Store cache point placements for future use if conversation ID is provided
|
|
|
|
|
+ if (conversationId && result.messageCachePointPlacements) {
|
|
|
|
|
+ this.previousCachePointPlacements[conversationId] = result.messageCachePointPlacements
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return result
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * MODEL IDENTIFICATION
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
+
|
|
|
|
|
+ private costModelConfig: { id: BedrockModelId | string; info: SharedModelInfo } = {
|
|
|
|
|
+ id: "",
|
|
|
|
|
+ info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private parseArn(arn: string, region?: string) {
|
|
|
|
|
+ /*
|
|
|
|
|
+ * VIA Roo analysis: platform-independent Regex. It's designed to parse AWS Bedrock ARNs and doesn't rely on any platform-specific features
|
|
|
|
|
+ * like file path separators, line endings, or case sensitivity behaviors. The forward slashes in the regex are properly escaped and
|
|
|
|
|
+ * represent literal characters in the AWS ARN format, not filesystem paths. This regex will function consistently across Windows,
|
|
|
|
|
+ * macOS, Linux, and any other operating system where JavaScript runs.
|
|
|
|
|
+ *
|
|
|
|
|
+ * This matches ARNs like:
|
|
|
|
|
+ * - Foundation Model: arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-v2
|
|
|
|
|
+ * - Prompt Router: arn:aws:bedrock:us-west-2:123456789012:prompt-router/anthropic-claude
|
|
|
|
|
+ * - Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/anthropic.claude-v2
|
|
|
|
|
+ * - Cross Region Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0
|
|
|
|
|
+ * - Custom Model (Provisioned Throughput): arn:aws:bedrock:us-west-2:123456789012:provisioned-model/my-custom-model
|
|
|
|
|
+ * - Imported Model: arn:aws:bedrock:us-west-2:123456789012:imported-model/my-imported-model
|
|
|
|
|
+ *
|
|
|
|
|
+ * match[0] - The entire matched string
|
|
|
|
|
+ * match[1] - The region (e.g., "us-east-1")
|
|
|
|
|
+ * match[2] - The account ID (can be empty string for AWS-managed resources)
|
|
|
|
|
+ * match[3] - The resource type (e.g., "foundation-model")
|
|
|
|
|
+ * match[4] - The resource ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0")
|
|
|
|
|
+ */
|
|
|
|
|
+
|
|
|
|
|
+ const arnRegex = /^arn:aws:bedrock:([^:]+):([^:]*):(?:([^\/]+)\/(.+)|([^\/]+))$/
|
|
|
|
|
+ let match = arn.match(arnRegex)
|
|
|
|
|
+
|
|
|
|
|
+ if (match && match[1] && match[3] && match[4]) {
|
|
|
|
|
+ // Create the result object
|
|
|
|
|
+ const result: {
|
|
|
|
|
+ isValid: boolean
|
|
|
|
|
+ region?: string
|
|
|
|
|
+ modelType?: string
|
|
|
|
|
+ modelId?: string
|
|
|
|
|
+ errorMessage?: string
|
|
|
|
|
+ crossRegionInference: boolean
|
|
|
|
|
+ } = {
|
|
|
|
|
+ isValid: true,
|
|
|
|
|
+ crossRegionInference: false, // Default to false
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ result.modelType = match[3]
|
|
|
|
|
+ const originalModelId = match[4]
|
|
|
|
|
+ result.modelId = this.parseBaseModelId(originalModelId)
|
|
|
|
|
+
|
|
|
|
|
+ // Extract the region from the first capture group
|
|
|
|
|
+ const arnRegion = match[1]
|
|
|
|
|
+ result.region = arnRegion
|
|
|
|
|
+
|
|
|
|
|
+ // Check if the original model ID had a region prefix
|
|
|
|
|
+ if (originalModelId && result.modelId !== originalModelId) {
|
|
|
|
|
+ // If the model ID changed after parsing, it had a region prefix
|
|
|
|
|
+ let prefix = originalModelId.replace(result.modelId, "")
|
|
|
|
|
+ result.crossRegionInference = AwsBedrockHandler.prefixIsMultiRegion(prefix)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check if region in ARN matches provided region (if specified)
|
|
|
|
|
+ if (region && arnRegion !== region) {
|
|
|
|
|
+ result.errorMessage = `Region mismatch: The region in your ARN (${arnRegion}) does not match your selected region (${region}). This may cause access issues. The provider will use the region from the ARN.`
|
|
|
|
|
+ result.region = arnRegion
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return result
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // If we get here, the regex didn't match
|
|
|
|
|
+ return {
|
|
|
|
|
+ isValid: false,
|
|
|
|
|
+ region: undefined,
|
|
|
|
|
+ modelType: undefined,
|
|
|
|
|
+ modelId: undefined,
|
|
|
|
|
+ errorMessage: "Invalid ARN format. ARN should follow the AWS Bedrock ARN pattern.",
|
|
|
|
|
+ crossRegionInference: false,
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- //Prompt Router responses come back in a different sequence and the yield calls are not resulting in costs getting updated
|
|
|
|
|
- getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } {
|
|
|
|
|
|
|
+ //This strips any region prefix that used on cross-region model inference ARNs
|
|
|
|
|
+ private parseBaseModelId(modelId: string) {
|
|
|
|
|
+ if (!modelId) {
|
|
|
|
|
+ return modelId
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const knownRegionPrefixes = AwsBedrockHandler.getPrefixList()
|
|
|
|
|
+
|
|
|
|
|
+ // Find if the model ID starts with any known region prefix
|
|
|
|
|
+ const matchedPrefix = knownRegionPrefixes.find((prefix) => modelId.startsWith(prefix))
|
|
|
|
|
+
|
|
|
|
|
+ if (matchedPrefix) {
|
|
|
|
|
+ // Remove the region prefix from the model ID
|
|
|
|
|
+ return modelId.substring(matchedPrefix.length)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // If no known prefix was found, check for a generic pattern
|
|
|
|
|
+ // Look for a pattern where the first segment before a dot doesn't contain dots or colons
|
|
|
|
|
+ // and the remaining parts still contain at least one dot
|
|
|
|
|
+ const genericPrefixMatch = modelId.match(/^([^.:]+)\.(.+\..+)$/)
|
|
|
|
|
+ if (genericPrefixMatch) {
|
|
|
|
|
+ const genericPrefix = genericPrefixMatch[1] + "."
|
|
|
|
|
+ return genericPrefixMatch[2]
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return modelId
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //Prompt Router responses come back in a different sequence and the model used is in the response and must be fetched by name
|
|
|
|
|
+ getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: SharedModelInfo } {
|
|
|
// Try to find the model in bedrockModels
|
|
// Try to find the model in bedrockModels
|
|
|
- if (modelName in bedrockModels) {
|
|
|
|
|
- const id = modelName as BedrockModelId
|
|
|
|
|
|
|
+ const baseModelId = this.parseBaseModelId(modelId) as BedrockModelId
|
|
|
|
|
|
|
|
|
|
+ let model
|
|
|
|
|
+ if (baseModelId in bedrockModels) {
|
|
|
//Do a deep copy of the model info so that later in the code the model id and maxTokens can be set.
|
|
//Do a deep copy of the model info so that later in the code the model id and maxTokens can be set.
|
|
|
// The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value
|
|
// The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value
|
|
|
// in a prompt router response isn't possible on the constant.
|
|
// in a prompt router response isn't possible on the constant.
|
|
|
- let model = JSON.parse(JSON.stringify(bedrockModels[id]))
|
|
|
|
|
-
|
|
|
|
|
- // If modelMaxTokens is explicitly set in options, override the default
|
|
|
|
|
- if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
|
|
|
|
|
- model.maxTokens = this.options.modelMaxTokens
|
|
|
|
|
|
|
+ model = { id: baseModelId, info: JSON.parse(JSON.stringify(bedrockModels[baseModelId])) }
|
|
|
|
|
+ } else if (modelType && modelType.includes("router")) {
|
|
|
|
|
+ model = {
|
|
|
|
|
+ id: bedrockDefaultPromptRouterModelId,
|
|
|
|
|
+ info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ model = {
|
|
|
|
|
+ id: bedrockDefaultModelId,
|
|
|
|
|
+ info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- return { id, info: model }
|
|
|
|
|
|
|
+ // If modelMaxTokens is explicitly set in options, override the default
|
|
|
|
|
+ if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
|
|
|
|
|
+ model.info.maxTokens = this.options.modelMaxTokens
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
|
|
|
|
|
|
|
+ return model
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
|
|
|
|
- if (this.costModelConfig.id.trim().length > 0) {
|
|
|
|
|
|
|
+ override getModel(): { id: BedrockModelId | string; info: SharedModelInfo } {
|
|
|
|
|
+ if (this.costModelConfig?.id?.trim().length > 0) {
|
|
|
return this.costModelConfig
|
|
return this.costModelConfig
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ let modelConfig = undefined
|
|
|
|
|
+
|
|
|
// If custom ARN is provided, use it
|
|
// If custom ARN is provided, use it
|
|
|
if (this.options.awsCustomArn) {
|
|
if (this.options.awsCustomArn) {
|
|
|
- // Extract the model name from the ARN using platform-independent regex
|
|
|
|
|
- const arnMatch = this.options.awsCustomArn.match(
|
|
|
|
|
- new RegExp(
|
|
|
|
|
- "^arn:aws:bedrock:([^:]+):(\\d+):(inference-profile|foundation-model|provisioned-model|default-prompt-router|prompt-router)[/\\\\](.+)$",
|
|
|
|
|
- ),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ modelConfig = this.getModelById(this.arnInfo.modelId, this.arnInfo.modelType)
|
|
|
|
|
|
|
|
- let modelName = arnMatch ? arnMatch[4] : ""
|
|
|
|
|
- if (modelName) {
|
|
|
|
|
- let region = modelName.slice(0, 3)
|
|
|
|
|
- // Check for all region prefixes (us., eu., and apac prefix)
|
|
|
|
|
- if (region === "us." || region === "eu.") modelName = modelName.slice(3)
|
|
|
|
|
- else if (modelName.startsWith("apac.")) modelName = modelName.slice(5)
|
|
|
|
|
|
|
+ //If the user entered an ARN for a foundation-model they've done the same thing as picking from our list of options.
|
|
|
|
|
+ //We leave the model data matching the same as if a drop-down input method was used by not overwriting the model ID with the user input ARN
|
|
|
|
|
+ //Otherwise the ARN is not a foundation-model resource type that ARN should be used as the identifier in Bedrock interactions
|
|
|
|
|
+ if (this.arnInfo.modelType !== "foundation-model") modelConfig.id = this.options.awsCustomArn
|
|
|
|
|
+ } else {
|
|
|
|
|
+ //a model was selected from the drop down
|
|
|
|
|
+ modelConfig = this.getModelById(this.options.apiModelId as string)
|
|
|
|
|
|
|
|
- let modelData = this.getModelByName(modelName)
|
|
|
|
|
- modelData.id = this.options.awsCustomArn
|
|
|
|
|
|
|
+ if (this.options.awsUseCrossRegionInference) {
|
|
|
|
|
+ // Get the current region
|
|
|
|
|
+ const region = this.options.awsRegion || ""
|
|
|
|
|
+ // Use the helper method to get the appropriate prefix for this region
|
|
|
|
|
+ const prefix = AwsBedrockHandler.getPrefixForRegion(region)
|
|
|
|
|
|
|
|
- if (modelData) {
|
|
|
|
|
- return modelData
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Apply the prefix if one was found, otherwise use the model ID as is
|
|
|
|
|
+ modelConfig.id = prefix ? `${prefix}${modelConfig.id}` : modelConfig.id
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // An ARN was used, but no model info match found, use default values based on common patterns
|
|
|
|
|
- let model = this.getModelByName(bedrockDefaultPromptRouterModelId)
|
|
|
|
|
|
|
+ modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS
|
|
|
|
|
|
|
|
- // For custom ARNs, always return the specific values expected by tests
|
|
|
|
|
- return {
|
|
|
|
|
- id: this.options.awsCustomArn,
|
|
|
|
|
- info: model.info,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ return modelConfig as { id: BedrockModelId | string; info: SharedModelInfo }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * CACHE
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
+
|
|
|
|
|
+ // Store previous cache point placements for maintaining consistency across consecutive messages
|
|
|
|
|
+ private previousCachePointPlacements: { [conversationId: string]: any[] } = {}
|
|
|
|
|
+
|
|
|
|
|
+ private supportsAwsPromptCache(modelConfig: {
|
|
|
|
|
+ id: BedrockModelId | string
|
|
|
|
|
+ info: SharedModelInfo
|
|
|
|
|
+ }): boolean | undefined {
|
|
|
|
|
+ // Check if the model supports prompt cache
|
|
|
|
|
+ // The cachableFields property is not part of the ModelInfo type in schemas
|
|
|
|
|
+ // but it's used in the bedrockModels object in shared/api.ts
|
|
|
|
|
+ return (
|
|
|
|
|
+ modelConfig?.info?.supportsPromptCache &&
|
|
|
|
|
+ // Use optional chaining and type assertion to access cachableFields
|
|
|
|
|
+ (modelConfig?.info as any)?.cachableFields &&
|
|
|
|
|
+ (modelConfig?.info as any)?.cachableFields?.length > 0
|
|
|
|
|
+ )
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Removes any existing cachePoint nodes from content blocks
|
|
|
|
|
+ */
|
|
|
|
|
+ private removeCachePoints(content: any): any {
|
|
|
|
|
+ if (Array.isArray(content)) {
|
|
|
|
|
+ return content.map((block) => {
|
|
|
|
|
+ // Use destructuring to remove cachePoint property
|
|
|
|
|
+ const { cachePoint, ...rest } = block
|
|
|
|
|
+ return rest
|
|
|
|
|
+ })
|
|
|
}
|
|
}
|
|
|
|
|
+ return content
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if (this.options.apiModelId) {
|
|
|
|
|
- // Special case for custom ARN option
|
|
|
|
|
- if (this.options.apiModelId === "custom-arn") {
|
|
|
|
|
- // This should not happen as we should have awsCustomArn set
|
|
|
|
|
- // but just in case, return a default model
|
|
|
|
|
- return this.getModelByName(bedrockDefaultModelId)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ /************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * AWS REGIONS
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
|
|
|
- // For tests, allow any model ID (but not custom ARNs, which are handled above)
|
|
|
|
|
- if (process.env.NODE_ENV === "test") {
|
|
|
|
|
- return {
|
|
|
|
|
- id: this.options.apiModelId,
|
|
|
|
|
- info: {
|
|
|
|
|
- maxTokens: 5000,
|
|
|
|
|
- contextWindow: 128_000,
|
|
|
|
|
- supportsPromptCache: false,
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ private static getPrefixList(): string[] {
|
|
|
|
|
+ return Object.keys(AWS_BEDROCK_REGION_INFO)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private static getPrefixForRegion(region: string): string | undefined {
|
|
|
|
|
+ for (const [prefix, info] of Object.entries(AWS_BEDROCK_REGION_INFO)) {
|
|
|
|
|
+ if (info.pattern && region.startsWith(info.pattern)) {
|
|
|
|
|
+ return prefix
|
|
|
}
|
|
}
|
|
|
- // For production, validate against known models
|
|
|
|
|
- return this.getModelByName(this.options.apiModelId)
|
|
|
|
|
}
|
|
}
|
|
|
- return this.getModelByName(bedrockDefaultModelId)
|
|
|
|
|
|
|
+ return undefined
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- async completePrompt(prompt: string): Promise<string> {
|
|
|
|
|
- try {
|
|
|
|
|
- const modelConfig = this.getModel()
|
|
|
|
|
|
|
+ private static prefixIsMultiRegion(arnPrefix: string): boolean {
|
|
|
|
|
+ for (const [prefix, info] of Object.entries(AWS_BEDROCK_REGION_INFO)) {
|
|
|
|
|
+ if (arnPrefix === prefix) {
|
|
|
|
|
+ if (info?.multiRegion) return info.multiRegion
|
|
|
|
|
+ else return false
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Handle cross-region inference
|
|
|
|
|
- let modelId: string
|
|
|
|
|
|
|
+ /************************************************************************************
|
|
|
|
|
+ *
|
|
|
|
|
+ * ERROR HANDLING
|
|
|
|
|
+ *
|
|
|
|
|
+ *************************************************************************************/
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Error type definitions for Bedrock API errors
|
|
|
|
|
+ */
|
|
|
|
|
+ private static readonly ERROR_TYPES: Record<
|
|
|
|
|
+ string,
|
|
|
|
|
+ {
|
|
|
|
|
+ patterns: string[] // Strings to match in lowercase error message or name
|
|
|
|
|
+ messageTemplate: string // Template with placeholders like {region}, {modelId}, etc.
|
|
|
|
|
+ logLevel: "error" | "warn" | "info" // Log level for this error type
|
|
|
|
|
+ }
|
|
|
|
|
+ > = {
|
|
|
|
|
+ ACCESS_DENIED: {
|
|
|
|
|
+ patterns: ["access", "denied", "permission"],
|
|
|
|
|
+ messageTemplate: `You don't have access to the model specified.
|
|
|
|
|
+
|
|
|
|
|
+Please verify:
|
|
|
|
|
+1. Try cross-region inference if you're using a foundation model
|
|
|
|
|
+2. If using an ARN, verify the ARN is correct and points to a valid model
|
|
|
|
|
+3. Your AWS credentials have permission to access this model (check IAM policies)
|
|
|
|
|
+4. The region in the ARN matches the region where the model is deployed
|
|
|
|
|
+5. If using a provisioned model, ensure it's active and not in a failed state`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ NOT_FOUND: {
|
|
|
|
|
+ patterns: ["not found", "does not exist"],
|
|
|
|
|
+ messageTemplate: `The specified ARN does not exist or is invalid. Please check:
|
|
|
|
|
|
|
|
- // For custom ARNs, use the ARN directly without modification
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
|
|
+1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name)
|
|
|
|
|
+2. The model exists in the specified region
|
|
|
|
|
+3. The account ID in the ARN is correct`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ THROTTLING: {
|
|
|
|
|
+ patterns: ["throttl", "rate", "limit"],
|
|
|
|
|
+ messageTemplate: `Request was throttled or rate limited. Please try:
|
|
|
|
|
+1. Reducing the frequency of requests
|
|
|
|
|
+2. If using a provisioned model, check its throughput settings
|
|
|
|
|
+3. Contact AWS support to request a quota increase if needed
|
|
|
|
|
+
|
|
|
|
|
+{formattedErrorDetails}
|
|
|
|
|
+
|
|
|
|
|
+`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ TOO_MANY_TOKENS: {
|
|
|
|
|
+ patterns: ["too many tokens"],
|
|
|
|
|
+ messageTemplate: `"Too many tokens" error detected.
|
|
|
|
|
+Possible Causes:
|
|
|
|
|
+1. Input exceeds model's context window limit
|
|
|
|
|
+2. Rate limiting (too many tokens per minute)
|
|
|
|
|
+3. Quota exceeded for token usage
|
|
|
|
|
+4. Other token-related service limitations
|
|
|
|
|
+
|
|
|
|
|
+Suggestions:
|
|
|
|
|
+1. Reduce the size of your input
|
|
|
|
|
+2. Split your request into smaller chunks
|
|
|
|
|
+3. Use a model with a larger context window
|
|
|
|
|
+4. If rate limited, reduce request frequency
|
|
|
|
|
+5. Check your AWS Bedrock quotas and limits`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ ON_DEMAND_NOT_SUPPORTED: {
|
|
|
|
|
+ patterns: ["with on-demand throughput isn’t supported."],
|
|
|
|
|
+ messageTemplate: `
|
|
|
|
|
+1. Try enabling cross-region inference in settings.
|
|
|
|
|
+2. Or, create an inference profile and then leverage the "Use custom ARN..." option of the model selector in settings.`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ ABORT: {
|
|
|
|
|
+ patterns: ["aborterror"], // This will match error.name.toLowerCase() for AbortError
|
|
|
|
|
+ messageTemplate: `Request was aborted: The operation timed out or was manually cancelled. Please try again or check your network connection.`,
|
|
|
|
|
+ logLevel: "info",
|
|
|
|
|
+ },
|
|
|
|
|
+ INVALID_ARN_FORMAT: {
|
|
|
|
|
+ patterns: ["invalid_arn_format:", "invalid arn format"],
|
|
|
|
|
+ messageTemplate: `Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ // Default/generic error
|
|
|
|
|
+ GENERIC: {
|
|
|
|
|
+ patterns: [], // Empty patterns array means this is the default
|
|
|
|
|
+ messageTemplate: `Unknown Error`,
|
|
|
|
|
+ logLevel: "error",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Validate ARN format and check region match
|
|
|
|
|
- const clientRegion = this.client.config.region as string
|
|
|
|
|
- const validation = validateBedrockArn(modelId, clientRegion)
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Determines the error type based on the error message or name
|
|
|
|
|
+ */
|
|
|
|
|
+ private getErrorType(error: unknown): string {
|
|
|
|
|
+ if (!(error instanceof Error)) {
|
|
|
|
|
+ return "GENERIC"
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if (!validation.isValid) {
|
|
|
|
|
- logger.error("Invalid ARN format in completePrompt", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- modelId,
|
|
|
|
|
- errorMessage: validation.errorMessage,
|
|
|
|
|
- })
|
|
|
|
|
- throw new Error(
|
|
|
|
|
- validation.errorMessage ||
|
|
|
|
|
- "Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name",
|
|
|
|
|
- )
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ const errorMessage = error.message.toLowerCase()
|
|
|
|
|
+ const errorName = error.name.toLowerCase()
|
|
|
|
|
|
|
|
- // Extract region from ARN
|
|
|
|
|
- const arnRegion = validation.arnRegion!
|
|
|
|
|
|
|
+ // Check each error type's patterns
|
|
|
|
|
+ for (const [errorType, definition] of Object.entries(AwsBedrockHandler.ERROR_TYPES)) {
|
|
|
|
|
+ if (errorType === "GENERIC") continue // Skip the generic type
|
|
|
|
|
|
|
|
- // Log warning if there's a region mismatch
|
|
|
|
|
- if (validation.errorMessage) {
|
|
|
|
|
- logger.warn(validation.errorMessage, {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- arnRegion,
|
|
|
|
|
- clientRegion,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
- } else if (this.options.awsUseCrossRegionInference) {
|
|
|
|
|
- let regionPrefix = (this.options.awsRegion || "").slice(0, 3)
|
|
|
|
|
- switch (regionPrefix) {
|
|
|
|
|
- case "us-":
|
|
|
|
|
- modelId = `us.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- case "eu-":
|
|
|
|
|
- modelId = `eu.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- case "ap-":
|
|
|
|
|
- modelId = `apac.${modelConfig.id}`
|
|
|
|
|
- break
|
|
|
|
|
- default:
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- modelId = modelConfig.id
|
|
|
|
|
|
|
+ // If any pattern matches in either message or name, return this error type
|
|
|
|
|
+ if (definition.patterns.some((pattern) => errorMessage.includes(pattern) || errorName.includes(pattern))) {
|
|
|
|
|
+ return errorType
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- const payload = {
|
|
|
|
|
- modelId,
|
|
|
|
|
- messages: convertToBedrockConverseMessages([
|
|
|
|
|
- {
|
|
|
|
|
- role: "user",
|
|
|
|
|
- content: prompt,
|
|
|
|
|
- },
|
|
|
|
|
- ]),
|
|
|
|
|
- inferenceConfig: {
|
|
|
|
|
- maxTokens: modelConfig.info.maxTokens || 4096,
|
|
|
|
|
- temperature: this.options.modelTemperature ?? BEDROCK_DEFAULT_TEMPERATURE,
|
|
|
|
|
- topP: 0.1,
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Default to generic error
|
|
|
|
|
+ return "GENERIC"
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Log the payload for debugging custom ARN issues
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- logger.debug("Bedrock completePrompt request details", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- clientRegion: this.client.config.region,
|
|
|
|
|
- payload: JSON.stringify(payload, null, 2),
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Formats an error message based on the error type and context
|
|
|
|
|
+ */
|
|
|
|
|
+ private formatErrorMessage(error: unknown, errorType: string, isStreamContext: boolean): string {
|
|
|
|
|
+ const definition = AwsBedrockHandler.ERROR_TYPES[errorType] || AwsBedrockHandler.ERROR_TYPES.GENERIC
|
|
|
|
|
+ let template = definition.messageTemplate
|
|
|
|
|
|
|
|
- const command = new ConverseCommand(payload)
|
|
|
|
|
- const response = await this.client.send(command)
|
|
|
|
|
|
|
+ // Prepare template variables
|
|
|
|
|
+ const templateVars: Record<string, string> = {}
|
|
|
|
|
|
|
|
- if (
|
|
|
|
|
- response?.output?.message?.content &&
|
|
|
|
|
- response.output.message.content.length > 0 &&
|
|
|
|
|
- response.output.message.content[0].text &&
|
|
|
|
|
- response.output.message.content[0].text.trim().length > 0
|
|
|
|
|
- ) {
|
|
|
|
|
- try {
|
|
|
|
|
- return response.output.message.content[0].text
|
|
|
|
|
- } catch (parseError) {
|
|
|
|
|
- logger.error("Failed to parse Bedrock response", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- error: parseError instanceof Error ? parseError : String(parseError),
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ if (error instanceof Error) {
|
|
|
|
|
+ templateVars.errorMessage = error.message
|
|
|
|
|
+ templateVars.errorName = error.name
|
|
|
|
|
+
|
|
|
|
|
+ const modelConfig = this.getModel()
|
|
|
|
|
+ templateVars.modelId = modelConfig.id
|
|
|
|
|
+ templateVars.contextWindow = String(modelConfig.info.contextWindow || "unknown")
|
|
|
|
|
+
|
|
|
|
|
+ // Format error details
|
|
|
|
|
+ const errorDetails: Record<string, any> = {}
|
|
|
|
|
+ Object.getOwnPropertyNames(error).forEach((prop) => {
|
|
|
|
|
+ if (prop !== "stack") {
|
|
|
|
|
+ errorDetails[prop] = (error as any)[prop]
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
- return ""
|
|
|
|
|
- } catch (error) {
|
|
|
|
|
- // Enhanced error handling for custom ARN issues
|
|
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
|
|
- logger.error("Error occurred with custom ARN in completePrompt", {
|
|
|
|
|
- ctx: "bedrock",
|
|
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
|
|
- error: error instanceof Error ? error : String(error),
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ })
|
|
|
|
|
|
|
|
- if (error instanceof Error) {
|
|
|
|
|
- const errorMessage = error.message.toLowerCase()
|
|
|
|
|
-
|
|
|
|
|
- // Access denied errors
|
|
|
|
|
- if (
|
|
|
|
|
- errorMessage.includes("access") &&
|
|
|
|
|
- (errorMessage.includes("model") || errorMessage.includes("denied"))
|
|
|
|
|
- ) {
|
|
|
|
|
- throw new Error(
|
|
|
|
|
- `Bedrock custom ARN error: You don't have access to the model with the specified ARN. Please verify:
|
|
|
|
|
-1. The ARN is correct and points to a valid model
|
|
|
|
|
-2. Your AWS credentials have permission to access this model (check IAM policies)
|
|
|
|
|
-3. The region in the ARN matches the region where the model is deployed
|
|
|
|
|
-4. If using a provisioned model, ensure it's active and not in a failed state`,
|
|
|
|
|
- )
|
|
|
|
|
- }
|
|
|
|
|
- // Model not found errors
|
|
|
|
|
- else if (errorMessage.includes("not found") || errorMessage.includes("does not exist")) {
|
|
|
|
|
- throw new Error(
|
|
|
|
|
- `Bedrock custom ARN error: The specified ARN does not exist or is invalid. Please check:
|
|
|
|
|
-1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name)
|
|
|
|
|
-2. The model exists in the specified region
|
|
|
|
|
-3. The account ID in the ARN is correct
|
|
|
|
|
-4. The resource type is one of: foundation-model, provisioned-model, or default-prompt-router`,
|
|
|
|
|
- )
|
|
|
|
|
- }
|
|
|
|
|
- // Throttling errors
|
|
|
|
|
- else if (
|
|
|
|
|
- errorMessage.includes("throttl") ||
|
|
|
|
|
- errorMessage.includes("rate") ||
|
|
|
|
|
- errorMessage.includes("limit")
|
|
|
|
|
- ) {
|
|
|
|
|
- throw new Error(
|
|
|
|
|
- `Bedrock custom ARN error: Request was throttled or rate limited. Please try:
|
|
|
|
|
-1. Reducing the frequency of requests
|
|
|
|
|
-2. If using a provisioned model, check its throughput settings
|
|
|
|
|
-3. Contact AWS support to request a quota increase if needed`,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ // Safely stringify error details to avoid circular references
|
|
|
|
|
+ templateVars.formattedErrorDetails = Object.entries(errorDetails)
|
|
|
|
|
+ .map(([key, value]) => {
|
|
|
|
|
+ let valueStr
|
|
|
|
|
+ if (typeof value === "object" && value !== null) {
|
|
|
|
|
+ try {
|
|
|
|
|
+ // Use a replacer function to handle circular references
|
|
|
|
|
+ valueStr = JSON.stringify(value, (k, v) => {
|
|
|
|
|
+ if (k && typeof v === "object" && v !== null) {
|
|
|
|
|
+ return "[Object]"
|
|
|
|
|
+ }
|
|
|
|
|
+ return v
|
|
|
|
|
+ })
|
|
|
|
|
+ } catch (e) {
|
|
|
|
|
+ valueStr = "[Complex Object]"
|
|
|
|
|
+ }
|
|
|
} else {
|
|
} else {
|
|
|
- throw new Error(`Bedrock custom ARN error: ${error.message}`)
|
|
|
|
|
|
|
+ valueStr = String(value)
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ return `- ${key}: ${valueStr}`
|
|
|
|
|
+ })
|
|
|
|
|
+ .join("\n")
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Standard error handling
|
|
|
|
|
- if (error instanceof Error) {
|
|
|
|
|
- throw new Error(`Bedrock completion error: ${error.message}`)
|
|
|
|
|
- }
|
|
|
|
|
- throw error
|
|
|
|
|
|
|
+ // Add context-specific template variables
|
|
|
|
|
+ const region =
|
|
|
|
|
+ typeof this?.client?.config?.region === "function"
|
|
|
|
|
+ ? this?.client?.config?.region()
|
|
|
|
|
+ : this?.client?.config?.region
|
|
|
|
|
+ templateVars.regionInfo = `(${region})`
|
|
|
|
|
+
|
|
|
|
|
+ // Replace template variables
|
|
|
|
|
+ for (const [key, value] of Object.entries(templateVars)) {
|
|
|
|
|
+ template = template.replace(new RegExp(`{${key}}`, "g"), value || "")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return template
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Handles Bedrock API errors and generates appropriate error messages
|
|
|
|
|
+ * @param error The error that occurred
|
|
|
|
|
+ * @param isStreamContext Whether the error occurred in a streaming context (true) or not (false)
|
|
|
|
|
+ * @returns Error message string for non-streaming context or array of stream chunks for streaming context
|
|
|
|
|
+ */
|
|
|
|
|
+ private handleBedrockError(
|
|
|
|
|
+ error: unknown,
|
|
|
|
|
+ isStreamContext: boolean,
|
|
|
|
|
+ ): string | Array<{ type: string; text?: string; inputTokens?: number; outputTokens?: number }> {
|
|
|
|
|
+ // Determine error type
|
|
|
|
|
+ const errorType = this.getErrorType(error)
|
|
|
|
|
+
|
|
|
|
|
+ // Format error message
|
|
|
|
|
+ const errorMessage = this.formatErrorMessage(error, errorType, isStreamContext)
|
|
|
|
|
+
|
|
|
|
|
+ // Log the error
|
|
|
|
|
+ const definition = AwsBedrockHandler.ERROR_TYPES[errorType]
|
|
|
|
|
+ const logMethod = definition.logLevel
|
|
|
|
|
+ const contextName = isStreamContext ? "createMessage" : "completePrompt"
|
|
|
|
|
+ logger[logMethod](`${errorType} error in ${contextName}`, {
|
|
|
|
|
+ ctx: "bedrock",
|
|
|
|
|
+ customArn: this.options.awsCustomArn,
|
|
|
|
|
+ errorType,
|
|
|
|
|
+ errorMessage: error instanceof Error ? error.message : String(error),
|
|
|
|
|
+ ...(error instanceof Error && error.stack ? { errorStack: error.stack } : {}),
|
|
|
|
|
+ ...(this.client?.config?.region ? { clientRegion: this.client.config.region } : {}),
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ // Return appropriate response based on isStreamContext
|
|
|
|
|
+ if (isStreamContext) {
|
|
|
|
|
+ return [
|
|
|
|
|
+ { type: "text", text: `Error: ${errorMessage}` },
|
|
|
|
|
+ { type: "usage", inputTokens: 0, outputTokens: 0 },
|
|
|
|
|
+ ]
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // For non-streaming context, add the expected prefix
|
|
|
|
|
+ return `Bedrock completion error: ${errorMessage}`
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|