Browse Source

Update src/api/providers/bedrock.ts

agree, sorry old Javascript habits

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Smartsheet-JB-Brown 11 months ago
parent
commit
8d30b6f44a

+ 0 - 151
src/api/providers/__tests__/bedrock-createMessage.test.ts

@@ -1,151 +0,0 @@
-// Mock AWS SDK credential providers
-jest.mock("@aws-sdk/credential-providers", () => ({
-	fromIni: jest.fn().mockReturnValue({
-		accessKeyId: "profile-access-key",
-		secretAccessKey: "profile-secret-key",
-	}),
-}))
-
-import { AwsBedrockHandler, StreamEvent } from "../bedrock"
-import { ApiHandlerOptions } from "../../../shared/api"
-import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
-import { logger } from "../../../utils/logging"
-
-describe("AwsBedrockHandler createMessage", () => {
-	let mockSend: jest.SpyInstance
-
-	beforeEach(() => {
-		// Mock the BedrockRuntimeClient.prototype.send method
-		mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => {
-			return {
-				stream: createMockStream([]),
-			}
-		})
-	})
-
-	afterEach(() => {
-		mockSend.mockRestore()
-	})
-
-	// Helper function to create a mock async iterable stream
-	function createMockStream(events: StreamEvent[]) {
-		return {
-			[Symbol.asyncIterator]: async function* () {
-				for (const event of events) {
-					yield event
-				}
-				// Always yield a metadata event at the end
-				yield {
-					metadata: {
-						usage: {
-							inputTokens: 100,
-							outputTokens: 200,
-						},
-					},
-				}
-			},
-		}
-	}
-
-	it("should log debug information during createMessage with custom ARN", async () => {
-		// Create a handler with a custom ARN
-		const mockOptions: ApiHandlerOptions = {
-			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
-			awsAccessKey: "test-access-key",
-			awsSecretKey: "test-secret-key",
-			awsRegion: "us-east-1",
-			awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model",
-		}
-
-		const handler = new AwsBedrockHandler(mockOptions)
-
-		// Mock the stream to include various events that trigger debug logs
-		mockSend.mockImplementationOnce(async () => {
-			return {
-				stream: createMockStream([
-					// Event with invokedModelId
-					{
-						trace: {
-							promptRouter: {
-								invokedModelId:
-									"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
-							},
-						},
-					},
-					// Content events
-					{
-						contentBlockStart: {
-							start: {
-								text: "Hello",
-							},
-							contentBlockIndex: 0,
-						},
-					},
-					{
-						contentBlockDelta: {
-							delta: {
-								text: ", world!",
-							},
-							contentBlockIndex: 0,
-						},
-					},
-				]),
-			}
-		})
-
-		// Create a message generator
-		const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
-
-		// Collect all yielded events
-		const events = []
-		for await (const event of messageGenerator) {
-			events.push(event)
-		}
-
-		// Verify that events were yielded
-		expect(events.length).toBeGreaterThan(0)
-
-		// Verify that debug logs were called
-		expect(logger.debug).toHaveBeenCalledWith(
-			"Using custom ARN for Bedrock request",
-			expect.objectContaining({
-				ctx: "bedrock",
-				customArn: mockOptions.awsCustomArn,
-			}),
-		)
-
-		expect(logger.debug).toHaveBeenCalledWith(
-			"Bedrock invokedModelId detected",
-			expect.objectContaining({
-				ctx: "bedrock",
-				invokedModelId:
-					"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
-			}),
-		)
-	})
-
-	it("should log debug information during createMessage with cross-region inference", async () => {
-		// Create a handler with cross-region inference
-		const mockOptions: ApiHandlerOptions = {
-			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
-			awsAccessKey: "test-access-key",
-			awsSecretKey: "test-secret-key",
-			awsRegion: "us-east-1",
-			awsUseCrossRegionInference: true,
-		}
-
-		const handler = new AwsBedrockHandler(mockOptions)
-
-		// Create a message generator
-		const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
-
-		// Collect all yielded events
-		const events = []
-		for await (const event of messageGenerator) {
-			events.push(event)
-		}
-
-		// Verify that events were yielded
-		expect(events.length).toBeGreaterThan(0)
-	})
-})

+ 27 - 1
src/api/providers/__tests__/bedrock.test.ts

@@ -326,11 +326,37 @@ describe("AwsBedrockHandler", () => {
 			})
 			const modelInfo = customArnHandler.getModel()
 			expect(modelInfo.id).toBe("arn:aws:bedrock:us-east-1::foundation-model/custom-model")
-			expect(modelInfo.info.maxTokens).toBe(8192)
+			expect(modelInfo.info.maxTokens).toBe(4096)
 			expect(modelInfo.info.contextWindow).toBe(200_000)
 			expect(modelInfo.info.supportsPromptCache).toBe(false)
 		})
 
+		it("should correctly identify model info from inference profile ARN", () => {
+			// This test intentionally uses a model whose maxTokens, contextWindow, and other values differ from the fallback defaults in the code.
+			const customArnHandler = new AwsBedrockHandler({
+				apiModelId: "meta.llama3-8b-instruct-v1:0", // This will be ignored when awsCustomArn is provided
+				awsAccessKey: "test-access-key",
+				awsSecretKey: "test-secret-key",
+				awsRegion: "us-west-2",
+				awsCustomArn:
+					"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0",
+			})
+			const modelInfo = customArnHandler.getModel()
+
+			// Verify the ARN is used as the model ID
+			expect(modelInfo.id).toBe(
+				"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0",
+			)
+
+			// The model info must come from the model named in the ARN, not from the fallback defaults
+			expect(modelInfo.info.maxTokens).toBe(2048)
+			expect(modelInfo.info.contextWindow).toBe(4_000)
+			expect(modelInfo.info.supportsImages).toBe(false)
+			expect(modelInfo.info.supportsPromptCache).toBe(false)
+
+			// This test guards the regex in getModel, which must match inference-profile ARNs as well as foundation-model and provisioned-model ARNs
+		})
+
 		it("should use default model when custom-arn is selected but no ARN is provided", () => {
 			const customArnHandler = new AwsBedrockHandler({
 				apiModelId: "custom-arn",

+ 19 - 111
src/api/providers/bedrock.ts

@@ -29,7 +29,8 @@ import { logger } from "../../utils/logging"
  */
 function validateBedrockArn(arn: string, region?: string) {
 	// Validate ARN format
-	const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/
+	const arnRegex =
+		/^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router|prompt-router)\/(.+)$/
 	const match = arn.match(arnRegex)
 
 	if (!match) {
@@ -164,8 +165,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		var modelConfig = this.getModel()
-
+		let modelConfig = this.getModel()
 		// Handle cross-region inference
 		let modelId: string
 
@@ -290,16 +290,12 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 						if (modelMatch && modelMatch[1]) {
 							let modelName = modelMatch[1]
 
-							logger.debug("Bedrock invokedModelId detected", { ctx: "bedrock", invokedModelId })
-
 							// Get a new modelConfig from getModel() using invokedModelId.. remove the region first
 							let region = modelName.slice(0, 3)
 
-							logger.debug("region", { region })
-
 							if (region === "us." || region === "eu.") modelName = modelName.slice(3)
 							this.costModelConfig = this.getModelByName(modelName)
-							logger.debug("Updated modelConfig using invokedModelId", {
+							logger.debug("Updated modelConfig using invokedModelId from a prompt router response", {
 								ctx: "bedrock",
 								modelConfig: this.costModelConfig,
 							})
@@ -489,93 +485,30 @@ Please check:
 		}
 	}
 
-	//Theory: Prompt Router responses seem to come back in a different sequence and the yield calls are not resulting in costs getting updated
-
-	//Sample response
-	/*
-	{"$metadata":
-		{	
-			"httpStatusCode":200,
-			"requestId":"96b8aeff-225b-470e-9901-7554c6ee15b3",
-			"attempts":1,
-			"totalRetryDelay":0
-		},
-		"metrics":
-		{
-			"latencyMs":4588
-		},
-		"output":
-		{
-			"message":
-			{
-				"content":[
-					{
-						"text":"I apologize, but I don't have access to any specific AWS Bedrock Intelligent Prompt Routing system or ARN (Amazon Resource Name). I'm Claude, an AI assistant created by Anthropic to be helpful, harmless, and honest. I don't have direct access to AWS services or the ability to verify their functionality.\n\nIf you're testing an AWS Bedrock prompt router, you would need to check within your AWS console or use AWS CLI tools to verify if it's working correctly. I can't confirm the status or functionality of any specific AWS resources.\n\nIs there anything else I can assist you with regarding AI, language models, or general information about prompt routing concepts?"
-					}]
-					,
-				"role":"assistant"
-			}
-		},
-		"stopReason":"end_turn",
-		"trace":
-		{
-			"promptRouter":
-			{
-				"invokedModelId":"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
-			},
-			"usage":
-			{
-				"inputTokens":38,
-				"outputTokens":147,
-				"totalTokens":185
-			}
-		}
-*/
-
+	// Prompt Router responses come back in a different sequence, and the yield calls do not result in costs being updated.
 	getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } {
-		logger.debug("Getting model info for specific name", {
-			ctx: "bedrock",
-			modelName,
-			awsCustomArn: this.options.awsCustomArn,
-		})
-
 		// Try to find the model in bedrockModels
 		if (modelName in bedrockModels) {
 			const id = modelName as BedrockModelId
-			logger.debug("Found model name", {
-				ctx: "bedrock",
-				modelName,
-				id: id,
-				info: bedrockModels[id],
-				awsCustomArn: this.options.awsCustomArn,
-			})
 
-			let modelInfo = JSON.parse(JSON.stringify(bedrockModels[id]))
+			// Deep-copy the model info so that the model id and maxTokens can be set later in the code.
+			// bedrockModels is a constant object keyed by model id, so updating the model id from the invokedModelId value
+			// returned in a prompt router response must be done on a copy rather than on the constant itself.
+			let model = JSON.parse(JSON.stringify(bedrockModels[id]))
 
 			// If modelMaxTokens is explicitly set in options, override the default
 			if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
-				modelInfo.maxTokens = this.options.modelMaxTokens
+				model.maxTokens = this.options.modelMaxTokens
 			}
 
-			return { id, info: modelInfo }
+			return { id, info: model }
 		}
 
-		// A specific name was asked for but not found, use default values
-		logger.debug("Return defaults 1", {
-			ctx: "bedrock",
-			bedrockDefaultModelId,
-			customArn: this.options.awsCustomArn,
-		})
-
 		return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
 	}
 
 	override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
 		if (this.costModelConfig.id.trim().length > 0) {
-			logger.debug("Returning cost previously set model config from a prompt router response", {
-				ctx: "bedrock",
-				model: this.costModelConfig,
-			})
 			return this.costModelConfig
 		}
 
@@ -583,21 +516,19 @@ Please check:
 		if (this.options.awsCustomArn) {
 			// Extract the model name from the ARN
 			const arnMatch = this.options.awsCustomArn.match(
-				/^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model)\/(.+)$/,
+				/^arn:aws:bedrock:([^:]+):(\d+):(inference-profile|foundation-model|provisioned-model)\/(.+)$/,
 			)
 
-			const extractedModelName = arnMatch ? arnMatch[2] : ""
+			let modelName = arnMatch ? arnMatch[4] : ""
+			if (modelName) {
+				let region = modelName.slice(0, 3)
+				if (region === "us." || region === "eu.") modelName = modelName.slice(3)
 
-			logger.debug(`Regex match to foundation-model model:`, {
-				extractedModelName: extractedModelName,
-				arnMatch: arnMatch,
-			})
-
-			if (extractedModelName) {
-				const modelData = this.getModelByName(extractedModelName)
+				let modelData = this.getModelByName(modelName)
+				modelData.id = this.options.awsCustomArn
 
 				if (modelData) {
-					logger.debug(`Matched custom ARN to model: ${extractedModelName}`, {
+					logger.debug(`Matched custom ARN to model: ${modelName}`, {
 						ctx: "bedrock",
 						modelData,
 					})
@@ -606,12 +537,6 @@ Please check:
 			}
 
 			// An ARN was used, but no model info match found, use default values based on common patterns
-			logger.debug("Return defaults for custom ARN", {
-				ctx: "bedrock",
-				bedrockDefaultPromptRouterModelId,
-				customArn: this.options.awsCustomArn,
-			})
-
 			let modelInfo = this.getModelByName(bedrockDefaultPromptRouterModelId)
 
 			// For custom ARNs, always return the specific values expected by tests
@@ -626,13 +551,6 @@ Please check:
 			if (this.options.apiModelId === "custom-arn") {
 				// This should not happen as we should have awsCustomArn set
 				// but just in case, return a default model
-
-				logger.debug("Return defaults 3", {
-					ctx: "bedrock",
-					name: this.options.apiModelId,
-					customArn: this.options.awsCustomArn,
-				})
-
 				return this.getModelByName(bedrockDefaultModelId)
 			}
 
@@ -655,12 +573,6 @@ Please check:
 			// For production, validate against known models
 			return this.getModelByName(this.options.apiModelId)
 		}
-
-		logger.debug("Return defaults for no matching model info", {
-			ctx: "bedrock",
-			customArn: this.options.awsCustomArn,
-		})
-
 		return this.getModelByName(bedrockDefaultModelId)
 	}
 
@@ -674,10 +586,6 @@ Please check:
 			// For custom ARNs, use the ARN directly without modification
 			if (this.options.awsCustomArn) {
 				modelId = modelConfig.id
-				logger.debug("Using custom ARN in completePrompt", {
-					ctx: "bedrock",
-					customArn: this.options.awsCustomArn,
-				})
 
 				// Validate ARN format and check region match
 				const clientRegion = this.client.config.region as string