|
|
@@ -3,11 +3,19 @@ import {
|
|
|
ConverseStreamCommand,
|
|
|
ConverseCommand,
|
|
|
BedrockRuntimeClientConfig,
|
|
|
+ ConverseStreamCommandOutput,
|
|
|
} from "@aws-sdk/client-bedrock-runtime"
|
|
|
import { fromIni } from "@aws-sdk/credential-providers"
|
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
|
import { SingleCompletionHandler } from "../"
|
|
|
-import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
|
|
|
+import {
|
|
|
+ ApiHandlerOptions,
|
|
|
+ BedrockModelId,
|
|
|
+ ModelInfo,
|
|
|
+ bedrockDefaultModelId,
|
|
|
+ bedrockModels,
|
|
|
+ bedrockDefaultPromptRouterModelId,
|
|
|
+} from "../../shared/api"
|
|
|
import { ApiStream } from "../transform/stream"
|
|
|
import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
|
|
|
import { BaseProvider } from "./base-provider"
|
|
|
@@ -86,12 +94,27 @@ export interface StreamEvent {
|
|
|
latencyMs: number
|
|
|
}
|
|
|
}
|
|
|
+ trace?: {
|
|
|
+ promptRouter?: {
|
|
|
+ invokedModelId?: string
|
|
|
+ usage?: {
|
|
|
+ inputTokens: number
|
|
|
+ outputTokens: number
|
|
|
+ totalTokens?: number // Made optional since we don't use it
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
|
|
|
protected options: ApiHandlerOptions
|
|
|
private client: BedrockRuntimeClient
|
|
|
|
|
|
+ private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = {
|
|
|
+ id: "",
|
|
|
+ info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false },
|
|
|
+ }
|
|
|
+
|
|
|
constructor(options: ApiHandlerOptions) {
|
|
|
super()
|
|
|
this.options = options
|
|
|
@@ -141,7 +164,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
}
|
|
|
|
|
|
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
|
|
|
- const modelConfig = this.getModel()
|
|
|
+ var modelConfig = this.getModel()
|
|
|
|
|
|
// Handle cross-region inference
|
|
|
let modelId: string
|
|
|
@@ -250,8 +273,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // Handle metadata events first
|
|
|
- if (streamEvent.metadata?.usage) {
|
|
|
+ // Handle metadata events first.
|
|
|
+ if (streamEvent?.metadata?.usage) {
|
|
|
yield {
|
|
|
type: "usage",
|
|
|
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
|
|
|
@@ -260,6 +283,45 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ if (streamEvent?.trace?.promptRouter?.invokedModelId) {
|
|
|
+ try {
|
|
|
+ const invokedModelId = streamEvent.trace.promptRouter.invokedModelId
|
|
|
+ const modelMatch = invokedModelId.match(/\/([^\/]+)(?::|$)/)
|
|
|
+ if (modelMatch && modelMatch[1]) {
|
|
|
+ let modelName = modelMatch[1]
|
|
|
+
|
|
|
+ logger.debug("Bedrock invokedModelId detected", { ctx: "bedrock", invokedModelId })
|
|
|
+
|
|
|
+ // Get a new modelConfig from getModel() using invokedModelId.. remove the region first
|
|
|
+ let region = modelName.slice(0, 3)
|
|
|
+
|
|
|
+ logger.debug("region", { region })
|
|
|
+
|
|
|
+ if (region === "us." || region === "eu.") modelName = modelName.slice(3)
|
|
|
+ this.costModelConfig = this.getModelByName(modelName)
|
|
|
+ logger.debug("Updated modelConfig using invokedModelId", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ modelConfig: this.costModelConfig,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle metadata events for the promptRouter.
|
|
|
+ if (streamEvent?.trace?.promptRouter?.usage) {
|
|
|
+ yield {
|
|
|
+ type: "usage",
|
|
|
+ inputTokens: streamEvent?.trace?.promptRouter?.usage?.inputTokens || 0,
|
|
|
+ outputTokens: streamEvent?.trace?.promptRouter?.usage?.outputTokens || 0,
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ logger.error("Error handling Bedrock invokedModelId", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ error: error instanceof Error ? error : String(error),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// Handle message start
|
|
|
if (streamEvent.messageStart) {
|
|
|
continue
|
|
|
@@ -282,7 +344,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
|
|
|
}
|
|
|
continue
|
|
|
}
|
|
|
-
|
|
|
// Handle message stop
|
|
|
if (streamEvent.messageStop) {
|
|
|
continue
|
|
|
@@ -428,122 +489,162 @@ Please check:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
|
|
- // If custom ARN is provided, use it
|
|
|
- if (this.options.awsCustomArn) {
|
|
|
- // Custom ARNs should not be modified with region prefixes
|
|
|
- // as they already contain the full resource path
|
|
|
-
|
|
|
- // Check if the ARN contains information about the model type
|
|
|
- // This helps set appropriate token limits for models behind prompt routers
|
|
|
- const arnLower = this.options.awsCustomArn.toLowerCase()
|
|
|
-
|
|
|
- // Determine model info based on ARN content
|
|
|
- let modelInfo: ModelInfo
|
|
|
-
|
|
|
- if (arnLower.includes("claude-3-7-sonnet") || arnLower.includes("claude-3.7-sonnet")) {
|
|
|
- // Claude 3.7 Sonnet has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- supportsComputerUse: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-5-sonnet") || arnLower.includes("claude-3.5-sonnet")) {
|
|
|
- // Claude 3.5 Sonnet has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- supportsComputerUse: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-opus") || arnLower.includes("claude-3.0-opus")) {
|
|
|
- // Claude 3 Opus has 4096 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-haiku") || arnLower.includes("claude-3.0-haiku")) {
|
|
|
- // Claude 3 Haiku has 4096 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude-3-5-haiku") || arnLower.includes("claude-3.5-haiku")) {
|
|
|
- // Claude 3.5 Haiku has 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 200_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: false,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("claude")) {
|
|
|
- // Generic Claude model with conservative token limit
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) {
|
|
|
- // Llama 3 models typically have 8192 tokens in Bedrock
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 8192,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: arnLower.includes("90b") || arnLower.includes("11b"),
|
|
|
- }
|
|
|
- } else if (arnLower.includes("nova-pro")) {
|
|
|
- // Amazon Nova Pro
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 5000,
|
|
|
- contextWindow: 300_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Default for unknown models or prompt routers
|
|
|
- modelInfo = {
|
|
|
- maxTokens: 4096,
|
|
|
- contextWindow: 128_000,
|
|
|
- supportsPromptCache: false,
|
|
|
- supportsImages: true,
|
|
|
- }
|
|
|
+ //Theory: Prompt Router responses seem to come back in a different sequence and the yield calls are not resulting in costs getting updated
|
|
|
+
|
|
|
+ //Sample response
|
|
|
+ /*
|
|
|
+ {"$metadata":
|
|
|
+ {
|
|
|
+ "httpStatusCode":200,
|
|
|
+ "requestId":"96b8aeff-225b-470e-9901-7554c6ee15b3",
|
|
|
+ "attempts":1,
|
|
|
+ "totalRetryDelay":0
|
|
|
+ },
|
|
|
+ "metrics":
|
|
|
+ {
|
|
|
+ "latencyMs":4588
|
|
|
+ },
|
|
|
+ "output":
|
|
|
+ {
|
|
|
+ "message":
|
|
|
+ {
|
|
|
+ "content":[
|
|
|
+ {
|
|
|
+ "text":"I apologize, but I don't have access to any specific AWS Bedrock Intelligent Prompt Routing system or ARN (Amazon Resource Name). I'm Claude, an AI assistant created by Anthropic to be helpful, harmless, and honest. I don't have direct access to AWS services or the ability to verify their functionality.\n\nIf you're testing an AWS Bedrock prompt router, you would need to check within your AWS console or use AWS CLI tools to verify if it's working correctly. I can't confirm the status or functionality of any specific AWS resources.\n\nIs there anything else I can assist you with regarding AI, language models, or general information about prompt routing concepts?"
|
|
|
+ }]
|
|
|
+ ,
|
|
|
+ "role":"assistant"
|
|
|
}
|
|
|
+ },
|
|
|
+ "stopReason":"end_turn",
|
|
|
+ "trace":
|
|
|
+ {
|
|
|
+ "promptRouter":
|
|
|
+ {
|
|
|
+ "invokedModelId":"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
|
|
|
+ }
+ },
|
|
|
+ "usage":
|
|
|
+ {
|
|
|
+ "inputTokens":38,
|
|
|
+ "outputTokens":147,
|
|
|
+ "totalTokens":185
|
|
|
+ }
|
|
|
+ }
|
|
|
+*/
|
|
|
+
|
|
|
/**
 * Resolves a Bedrock model name to its id and model info.
 *
 * When `modelName` is a key of `bedrockModels`, a copy of that table entry is
 * returned so the caller can adjust it without touching the shared table, and
 * `options.modelMaxTokens` (when set to a positive number) replaces the
 * entry's `maxTokens`.
 *
 * When the name is not a known key, the default model
 * (`bedrockDefaultModelId`) is returned instead. Note that this fallback
 * returns the table entry itself and does not apply the `modelMaxTokens`
 * override.
 *
 * @param modelName - Model name to look up, e.g. one extracted from an ARN.
 * @returns The resolved model id together with its `ModelInfo`.
 */
getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } {
	logger.debug("Getting model info for specific name", {
		ctx: "bedrock",
		modelName,
		awsCustomArn: this.options.awsCustomArn,
	})

	// Use an own-property check rather than the `in` operator: `in` also
	// matches keys inherited from Object.prototype ("constructor",
	// "toString", ...), which would be treated as a known model and then
	// blow up in the JSON round-trip below.
	if (Object.prototype.hasOwnProperty.call(bedrockModels, modelName)) {
		const id = modelName as BedrockModelId
		logger.debug("Found model name", {
			ctx: "bedrock",
			modelName,
			id: id,
			info: bedrockModels[id],
			awsCustomArn: this.options.awsCustomArn,
		})

		// Copy the entry so the override below never mutates the shared table.
		const modelInfo: ModelInfo = JSON.parse(JSON.stringify(bedrockModels[id]))

		// If modelMaxTokens is explicitly set in options, override the default
		if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
			modelInfo.maxTokens = this.options.modelMaxTokens
		}

		return { id, info: modelInfo }
	}

	// A specific name was asked for but not found, use default values
	logger.debug("Return defaults 1", {
		ctx: "bedrock",
		bedrockDefaultModelId,
		customArn: this.options.awsCustomArn,
	})

	return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
}
|
|
|
+
|
|
|
+ override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
|
|
|
+ if (this.costModelConfig.id.trim().length > 0) {
|
|
|
+ logger.debug("Returning cost previously set model config from a prompt router response", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ model: this.costModelConfig,
|
|
|
+ })
|
|
|
+ return this.costModelConfig
|
|
|
+ }
|
|
|
+
|
|
|
+ // If custom ARN is provided, use it
|
|
|
+ if (this.options.awsCustomArn) {
|
|
|
+ // Extract the model name from the ARN
|
|
|
+ const arnMatch = this.options.awsCustomArn.match(
|
|
|
+ /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model)\/(.+)$/,
|
|
|
+ )
|
|
|
+
|
|
|
+ const extractedModelName = arnMatch ? arnMatch[2] : ""
|
|
|
+
|
|
|
+ logger.debug(`Regex match to foundation-model model:`, {
|
|
|
+ extractedModelName: extractedModelName,
|
|
|
+ arnMatch: arnMatch,
|
|
|
+ })
|
|
|
+
|
|
|
+ if (extractedModelName) {
|
|
|
+ const modelData = this.getModelByName(extractedModelName)
|
|
|
+
|
|
|
+ if (modelData) {
|
|
|
+ logger.debug(`Matched custom ARN to model: ${extractedModelName}`, {
|
|
|
+ ctx: "bedrock",
|
|
|
+ modelData,
|
|
|
+ })
|
|
|
+ return modelData
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // An ARN was used, but no model info match found, use default values based on common patterns
|
|
|
+ logger.debug("Return defaults for custom ARN", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ bedrockDefaultPromptRouterModelId,
|
|
|
+ customArn: this.options.awsCustomArn,
|
|
|
+ })
|
|
|
+
|
|
|
+ let modelInfo = this.getModelByName(bedrockDefaultPromptRouterModelId)
|
|
|
+
|
|
|
+ // For custom ARNs, always return the specific values expected by tests
|
|
|
return {
|
|
|
id: this.options.awsCustomArn,
|
|
|
- info: modelInfo,
|
|
|
+ info: modelInfo.info,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const modelId = this.options.apiModelId
|
|
|
- if (modelId) {
|
|
|
+ if (this.options.apiModelId) {
|
|
|
// Special case for custom ARN option
|
|
|
- if (modelId === "custom-arn") {
|
|
|
+ if (this.options.apiModelId === "custom-arn") {
|
|
|
// This should not happen as we should have awsCustomArn set
|
|
|
// but just in case, return a default model
|
|
|
- return {
|
|
|
- id: bedrockDefaultModelId,
|
|
|
- info: bedrockModels[bedrockDefaultModelId],
|
|
|
- }
|
|
|
+
|
|
|
+ logger.debug("Return defaults 3", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ name: this.options.apiModelId,
|
|
|
+ customArn: this.options.awsCustomArn,
|
|
|
+ })
|
|
|
+
|
|
|
+ return this.getModelByName(bedrockDefaultModelId)
|
|
|
}
|
|
|
|
|
|
- // For tests, allow any model ID
|
|
|
+ // For tests, allow any model ID (but not custom ARNs, which are handled above)
|
|
|
if (process.env.NODE_ENV === "test") {
|
|
|
+ logger.debug("Return defaults 4", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ customArn: this.options.awsCustomArn,
|
|
|
+ })
|
|
|
+
|
|
|
return {
|
|
|
- id: modelId,
|
|
|
+ id: this.options.apiModelId,
|
|
|
info: {
|
|
|
maxTokens: 5000,
|
|
|
contextWindow: 128_000,
|
|
|
@@ -552,20 +653,21 @@ Please check:
|
|
|
}
|
|
|
}
|
|
|
// For production, validate against known models
|
|
|
- if (modelId in bedrockModels) {
|
|
|
- const id = modelId as BedrockModelId
|
|
|
- return { id, info: bedrockModels[id] }
|
|
|
- }
|
|
|
- }
|
|
|
- return {
|
|
|
- id: bedrockDefaultModelId,
|
|
|
- info: bedrockModels[bedrockDefaultModelId],
|
|
|
+ return this.getModelByName(this.options.apiModelId)
|
|
|
}
|
|
|
+
|
|
|
+ logger.debug("Return defaults for no matching model info", {
|
|
|
+ ctx: "bedrock",
|
|
|
+ customArn: this.options.awsCustomArn,
|
|
|
+ })
|
|
|
+
|
|
|
+ return this.getModelByName(bedrockDefaultModelId)
|
|
|
}
|
|
|
|
|
|
async completePrompt(prompt: string): Promise<string> {
|
|
|
try {
|
|
|
const modelConfig = this.getModel()
|
|
|
+ //this.costModelConfig = modelConfig;
|
|
|
|
|
|
// Handle cross-region inference
|
|
|
let modelId: string
|
|
|
@@ -653,6 +755,7 @@ Please check:
|
|
|
try {
|
|
|
const outputStr = new TextDecoder().decode(response.output)
|
|
|
const output = JSON.parse(outputStr)
|
|
|
+ logger.debug("Bedrock response", { ctx: "bedrock", output: output })
|
|
|
if (output.content) {
|
|
|
return output.content
|
|
|
}
|