|
|
@@ -3,11 +3,19 @@ import {
|
|
|
ConverseStreamCommand,
|
|
|
ConverseCommand,
|
|
|
BedrockRuntimeClientConfig,
|
|
|
+ ConverseStreamCommandOutput,
|
|
|
} from "@aws-sdk/client-bedrock-runtime"
|
|
|
import { fromIni } from "@aws-sdk/credential-providers"
|
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
|
import { SingleCompletionHandler } from "../"
|
|
|
-import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
|
|
+import {
|
|
|
+ ApiHandlerOptions,
|
|
|
+ BedrockModelId,
|
|
|
+ ModelInfo,
|
|
|
+ bedrockDefaultModelId,
|
|
|
+ bedrockModels,
|
|
|
+ bedrockDefaultPromptRouterModelId,
|
|
|
+} from "../../shared/api"
|
|
|
import { ApiStream } from "../transform/stream"
|
|
|
import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
|
|
|
import { BaseProvider } from "./base-provider"
|
|
|
@@ -21,7 +29,8 @@ import { logger } from "../../utils/logging"
|
|
|
*/
|
|
|
function validateBedrockArn(arn: string, region?: string) {
|
|
|
// Validate ARN format
|
|
|
- const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/
|
|
|
+ const arnRegex =
|
|
|
+ /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router|prompt-router)\/(.+)$/
|
|
|
const match = arn.match(arnRegex)
|
|
|
|
|
|
if (!match) {
|
|
|
@@ -86,12 +95,27 @@ export interface StreamEvent {
|
|
|
latencyMs: number
|
|
|
}
|
|
|
}
|
|
|
+ trace?: {
|
|
|
+ promptRouter?: {
|
|
|
+ invokedModelId?: string
|
|
|
+ usage?: {
|
|
|
+ inputTokens: number
|
|
|
+ outputTokens: number
|
|
|
+ totalTokens?: number // Made optional since we don't use it
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
|
|
|
protected options: ApiHandlerOptions
|
|
|
private client: BedrockRuntimeClient
|
|
|
|
|
|
+ private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = {
|
|
|
+ id: "",
|
|
|
+ info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false },
|
|
|
+ }
|
|
|
+
|
|
|
constructor(options: ApiHandlerOptions) {
|
|
|
super()
|
|
|
this.options = options
|
|
|
@@ -141,8 +165,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
}
|
|
|
|
|
|
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
|
|
- const modelConfig = this.getModel()
|
|
|
-
|
|
|
+ let modelConfig = this.getModel()
|
|
|
// Handle cross-region inference
|
|
|
let modelId: string
|
|
|
|
|
|
@@ -250,8 +273,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // Handle metadata events first
|
|
|
- if (streamEvent.metadata?.usage) {
|
|
|
+ // Handle metadata events first.
|
|
|
+ if (streamEvent?.metadata?.usage) {
|
|
|
yield {
|
|
|
type: "usage",
|
|
|
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
|
|
@@ -260,6 +283,37 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ if (streamEvent?.trace?.promptRouter?.invokedModelId) {
|
|
|
+ try {
|
|
|
+ const invokedModelId = streamEvent.trace.promptRouter.invokedModelId
|
|
|
+ const modelMatch = invokedModelId.match(/\/([^\/]+)(?::|$)/)
|
|
|
+ if (modelMatch && modelMatch[1]) {
|
|
|
+ let modelName = modelMatch[1]
|
|
|
+
|
|
|
+ // Look up the config for the invoked model via getModelByName(); strip the cross-region prefix ("us." / "eu.") from the name first
|
|
|
+ let region = modelName.slice(0, 3)
|
|
|
+
|
|
|
+ if (region === "us." || region === "eu.") modelName = modelName.slice(3)
|
|
|
+ this.costModelConfig = this.getModelByName(modelName)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle metadata events for the promptRouter.
|
|
|
+ if (streamEvent?.trace?.promptRouter?.usage) {
|
|
|
+ yield {
|
|
|
+ type: "usage",
|
|
|
+ inputTokens: streamEvent?.trace?.promptRouter?.usage?.inputTokens || 0,
|
|
|
+ outputTokens: streamEvent?.trace?.promptRouter?.usage?.outputTokens || 0,
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ logger.error("Error handling Bedrock invokedModelId", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ error: error instanceof Error ? error : String(error),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// Handle message start
|
|
|
if (streamEvent.messageStart) {
|
|
|
continue
|
|
|
@@ -282,7 +336,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
}
|
|
|
continue
|
|
|
}
|
|
|
-
|
|
|
// Handle message stop
|
|
|
if (streamEvent.messageStop) {
|
|
|
continue
|
|
|
@@ -428,122 +481,75 @@ Please check:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /**
+  * Look up a Bedrock model by name and return its id with a private copy of its info.
+  *
+  * Prompt Router responses come back in a different sequence and the yielded usage
+  * events do not result in costs getting updated, so the stream handling code calls
+  * this with the invoked model's name to obtain the config it stores in costModelConfig.
+  *
+  * Names that are not own keys of bedrockModels resolve to bedrockDefaultModelId.
+  * The returned info is always a deep copy, so callers may overwrite fields on it
+  * without mutating the shared bedrockModels constant.
+  *
+  * @param modelName Model name to look up in bedrockModels
+  * @returns The resolved model id and a mutable copy of its ModelInfo, with maxTokens
+  * replaced by options.modelMaxTokens when that option is a positive number
+  */
+ getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } {
+ 	// Use an own-property check: the `in` operator also matches inherited keys such as
+ 	// "constructor" or "toString", whose values are functions and cannot be JSON-copied.
+ 	const isKnownModel = Object.prototype.hasOwnProperty.call(bedrockModels, modelName)
+ 	const id: BedrockModelId = isKnownModel ? (modelName as BedrockModelId) : bedrockDefaultModelId
+
+ 	// Deep copy so the caller can change the id or maxTokens of the result; bedrockModels
+ 	// is a shared constant and must not be modified. The fallback entry is copied as well.
+ 	const info: ModelInfo = JSON.parse(JSON.stringify(bedrockModels[id]))
+
+ 	// If modelMaxTokens is explicitly set in options, override the default
+ 	if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
+ 		info.maxTokens = this.options.modelMaxTokens
+ 	}
+
+ 	return { id, info }
+ }
|
|
|
+
|
|
|
override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
|
|
+ if (this.costModelConfig.id.trim().length > 0) {
|
|
|
+ return this.costModelConfig
|
|
|
+ }
|
|
|
+
|
|
|
// If custom ARN is provided, use it
|
|
|
if (this.options.awsCustomArn) {
|
|
|
- // Custom ARNs should not be modified with region prefixes
|
|
|
- // as they already contain the full resource path
|
|
|
-
|
|
|
- // Check if the ARN contains information about the model type
|
|
|
- // This helps set appropriate token limits for models behind prompt routers
|
|
|
- const arnLower = this.options.awsCustomArn.toLowerCase()
|
|
|
-
|
|
|
- // Determine model info based on ARN content
|
|
|
- let modelInfo: ModelInfo
|
|
|
-
|
|
|
- if (arnLower.includes("claude-3-7-sonnet") || arnLower.includes("claude-3.7-sonnet")) {
|
|
|
- // Claude 3.7 Sonnet has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- supportsComputerUse: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-5-sonnet") || arnLower.includes("claude-3.5-sonnet")) {
|
|
|
- // Claude 3.5 Sonnet has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- supportsComputerUse: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-opus") || arnLower.includes("claude-3.0-opus")) {
|
|
|
- // Claude 3 Opus has 4096 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-haiku") || arnLower.includes("claude-3.0-haiku")) {
|
|
|
- // Claude 3 Haiku has 4096 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-5-haiku") || arnLower.includes("claude-3.5-haiku")) {
|
|
|
- // Claude 3.5 Haiku has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: false,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude")) {
|
|
|
- // Generic Claude model with conservative token limit
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) {
|
|
|
- // Llama 3 models typically have 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: arnLower.includes("90b") || arnLower.includes("11b"),
|
|
|
- }
|
|
|
- } else if (arnLower.includes("nova-pro")) {
|
|
|
- // Amazon Nova Pro
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 5000,
|
|
|
- contextWindow: 300_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Default for unknown models or prompt routers
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
+ // Extract the model name from the ARN
|
|
|
+ const arnMatch = this.options.awsCustomArn.match(
|
|
|
+ /^arn:aws:bedrock:([^:]+):(\d+):(inference-profile|foundation-model|provisioned-model)\/(.+)$/,
|
|
|
+ )
|
|
|
+
|
|
|
+ let modelName = arnMatch ? arnMatch[4] : ""
|
|
|
+ if (modelName) {
|
|
|
+ let region = modelName.slice(0, 3)
|
|
|
+ if (region === "us." || region === "eu.") modelName = modelName.slice(3)
|
|
|
+
|
|
|
+ let modelData = this.getModelByName(modelName)
|
|
|
+ modelData.id = this.options.awsCustomArn
|
|
|
+
|
|
|
+ if (modelData) {
|
|
|
+ return modelData
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // If modelMaxTokens is explicitly set in options, override the default
|
|
|
- if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
|
|
|
- modelInfo.maxTokens = this.options.modelMaxTokens
|
|
|
- }
|
|
|
+ // An ARN was supplied but no matching model info was found; fall back to the default prompt router model's info
|
|
|
+ let model = this.getModelByName(bedrockDefaultPromptRouterModelId)
|
|
|
|
|
|
+ // For custom ARNs, return the ARN itself as the id, paired with the fallback model info
|
|
|
return {
|
|
|
id: this.options.awsCustomArn,
|
|
|
- info: modelInfo,
|
|
|
+ info: model.info,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const modelId = this.options.apiModelId
|
|
|
- if (modelId) {
|
|
|
+ if (this.options.apiModelId) {
|
|
|
// Special case for custom ARN option
|
|
|
- if (modelId === "custom-arn") {
|
|
|
+ if (this.options.apiModelId === "custom-arn") {
|
|
|
// This should not happen as we should have awsCustomArn set
|
|
|
// but just in case, return a default model
|
|
|
- return {
|
|
|
- id: bedrockDefaultModelId,
|
|
|
- info: bedrockModels[bedrockDefaultModelId],
|
|
|
- }
|
|
|
+ return this.getModelByName(bedrockDefaultModelId)
|
|
|
}
|
|
|
|
|
|
- // For tests, allow any model ID
|
|
|
+ // For tests, allow any model ID (but not custom ARNs, which are handled above)
|
|
|
if (process.env.NODE_ENV === "test") {
|
|
|
return {
|
|
|
- id: modelId,
|
|
|
+ id: this.options.apiModelId,
|
|
|
info: {
|
|
|
maxTokens: 5000,
|
|
|
contextWindow: 128_000,
|
|
|
@@ -552,15 +558,9 @@ Please check:
|
|
|
}
|
|
|
}
|
|
|
// For production, validate against known models
|
|
|
- if (modelId in bedrockModels) {
|
|
|
- const id = modelId as BedrockModelId
|
|
|
- return { id, info: bedrockModels[id] }
|
|
|
- }
|
|
|
- }
|
|
|
- return {
|
|
|
- id: bedrockDefaultModelId,
|
|
|
- info: bedrockModels[bedrockDefaultModelId],
|
|
|
+ return this.getModelByName(this.options.apiModelId)
|
|
|
}
|
|
|
+ return this.getModelByName(bedrockDefaultModelId)
|
|
|
}
|
|
|
|
|
|
async completePrompt(prompt: string): Promise<string> {
|
|
|
@@ -573,10 +573,6 @@ Please check:
|
|
|
// For custom ARNs, use the ARN directly without modification
|
|
|
if (this.options.awsCustomArn) {
|
|
|
modelId = modelConfig.id
|
|
|
- logger.debug("Using custom ARN in completePrompt", {
|
|
|
- ctx: "bedrock",
|
|
|
- customArn: this.options.awsCustomArn,
|
|
|
- })
|
|
|
|
|
|
// Validate ARN format and check region match
|
|
|
const clientRegion = this.client.config.region as string
|