Browse Source

Fixes image support in Bedrock; regression from the prompt cache implementation (#2723)

Fixes image support in Bedrock. The regression was introduced during the prompt caching implementation.
Smartsheet-JB-Brown 9 months ago
parent
commit
b077267b4a
2 changed files with 175 additions and 33 deletions
  1. 140 14
      src/api/providers/__tests__/bedrock.test.ts
  2. 35 19
      src/api/providers/bedrock.ts

+ 140 - 14
src/api/providers/__tests__/bedrock.test.ts

@@ -7,9 +7,23 @@ jest.mock("@aws-sdk/credential-providers", () => {
 	return { fromIni: mockFromIni }
 })
 
+// Mock BedrockRuntimeClient and ConverseStreamCommand
+const mockConverseStreamCommand = jest.fn()
+const mockSend = jest.fn().mockResolvedValue({
+	stream: [],
+})
+
+jest.mock("@aws-sdk/client-bedrock-runtime", () => ({
+	BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
+		send: mockSend,
+	})),
+	ConverseStreamCommand: mockConverseStreamCommand,
+	ConverseCommand: jest.fn(),
+}))
+
 import { AwsBedrockHandler } from "../bedrock"
 import { MessageContent } from "../../../shared/api"
-import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
+import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
 import { Anthropic } from "@anthropic-ai/sdk"
 const { fromIni } = require("@aws-sdk/credential-providers")
 import { logger } from "../../../utils/logging"
@@ -57,7 +71,6 @@ describe("AwsBedrockHandler", () => {
 		})
 
 		it("should handle inference-profile ARN with apne3 region prefix", () => {
-			// Mock the parseArn method before creating the handler
 			const originalParseArn = AwsBedrockHandler.prototype["parseArn"]
 			const parseArnMock = jest.fn().mockImplementation(function (this: any, arn: string, region?: string) {
 				return originalParseArn.call(this, arn, region)
@@ -65,12 +78,11 @@ describe("AwsBedrockHandler", () => {
 			AwsBedrockHandler.prototype["parseArn"] = parseArnMock
 
 			try {
-				// Create a handler with a custom ARN that includes the apne3. region prefix
 				const customArnHandler = new AwsBedrockHandler({
 					apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
 					awsAccessKey: "test-access-key",
 					awsSecretKey: "test-secret-key",
-					awsRegion: "ap-northeast-3", // Osaka region
+					awsRegion: "ap-northeast-3",
 					awsCustomArn:
 						"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
 				})
@@ -79,23 +91,17 @@ describe("AwsBedrockHandler", () => {
 
 				expect(modelInfo.id).toBe(
 					"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
-				),
-					// Verify the model info is defined
-					expect(modelInfo.info).toBeDefined()
+				)
+				expect(modelInfo.info).toBeDefined()
 
-				// Verify parseArn was called with the correct ARN
 				expect(parseArnMock).toHaveBeenCalledWith(
 					"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
 					"ap-northeast-3",
 				)
 
-				// Verify the model ID was correctly extracted from the ARN (without the region prefix)
 				expect((customArnHandler as any).arnInfo.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
-
-				// Verify cross-region inference flag is false since apne3 is a prefix for a single region
 				expect((customArnHandler as any).arnInfo.crossRegionInference).toBe(false)
 			} finally {
-				// Restore the original method
 				AwsBedrockHandler.prototype["parseArn"] = originalParseArn
 			}
 		})
@@ -109,12 +115,132 @@ describe("AwsBedrockHandler", () => {
 				awsRegion: "us-east-1",
 			})
 			const modelInfo = customArnHandler.getModel()
-			// Should fall back to default prompt router model
 			expect(modelInfo.id).toBe(
 				"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
-			) // bedrockDefaultPromptRouterModelId
+			)
 			expect(modelInfo.info).toBeDefined()
 			expect(modelInfo.info.maxTokens).toBe(4096)
 		})
 	})
+
+	describe("image handling", () => {
+		const mockImageData = Buffer.from("test-image-data").toString("base64")
+
+		beforeEach(() => {
+			// Reset the mocks before each test
+			mockSend.mockReset()
+			mockConverseStreamCommand.mockReset()
+
+			mockSend.mockResolvedValue({
+				stream: [],
+			})
+		})
+
+		it("should properly convert image content to Bedrock format", async () => {
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								data: mockImageData,
+								media_type: "image/jpeg",
+							},
+						},
+						{
+							type: "text",
+							text: "What's in this image?",
+						},
+					],
+				},
+			]
+
+			const generator = handler.createMessage("", messages)
+			await generator.next() // Start the generator
+
+			// Verify the command was created with the right payload
+			expect(mockConverseStreamCommand).toHaveBeenCalled()
+			const commandArg = mockConverseStreamCommand.mock.calls[0][0]
+
+			// Verify the image was properly formatted
+			const imageBlock = commandArg.messages[0].content[0]
+			expect(imageBlock).toHaveProperty("image")
+			expect(imageBlock.image).toHaveProperty("format", "jpeg")
+			expect(imageBlock.image.source).toHaveProperty("bytes")
+			expect(imageBlock.image.source.bytes).toBeInstanceOf(Uint8Array)
+		})
+
+		it("should reject unsupported image formats", async () => {
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								data: mockImageData,
+								media_type: "image/tiff" as "image/jpeg", // Type assertion to bypass TS
+							},
+						},
+					],
+				},
+			]
+
+			const generator = handler.createMessage("", messages)
+			await expect(generator.next()).rejects.toThrow("Unsupported image format: tiff")
+		})
+
+		it("should handle multiple images in a single message", async () => {
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								data: mockImageData,
+								media_type: "image/jpeg",
+							},
+						},
+						{
+							type: "text",
+							text: "First image",
+						},
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								data: mockImageData,
+								media_type: "image/png",
+							},
+						},
+						{
+							type: "text",
+							text: "Second image",
+						},
+					],
+				},
+			]
+
+			const generator = handler.createMessage("", messages)
+			await generator.next() // Start the generator
+
+			// Verify the command was created with the right payload
+			expect(mockConverseStreamCommand).toHaveBeenCalled()
+			const commandArg = mockConverseStreamCommand.mock.calls[0][0]
+
+			// Verify both images were properly formatted
+			const firstImage = commandArg.messages[0].content[0]
+			const secondImage = commandArg.messages[0].content[2]
+
+			expect(firstImage).toHaveProperty("image")
+			expect(firstImage.image).toHaveProperty("format", "jpeg")
+			expect(secondImage).toHaveProperty("image")
+			expect(secondImage.image).toHaveProperty("format", "png")
+		})
+	})
 })

+ 35 - 19
src/api/providers/bedrock.ts

@@ -3,6 +3,7 @@ import {
 	ConverseStreamCommand,
 	ConverseCommand,
 	BedrockRuntimeClientConfig,
+	ContentBlock,
 } from "@aws-sdk/client-bedrock-runtime"
 import { fromIni } from "@aws-sdk/credential-providers"
 import { Anthropic } from "@anthropic-ai/sdk"
@@ -23,6 +24,7 @@ import { Message, SystemContentBlock } from "@aws-sdk/client-bedrock-runtime"
 import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
 import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types"
 import { AMAZON_BEDROCK_REGION_INFO } from "../../shared/aws_regions"
+import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format"
 
 const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 const BEDROCK_MAX_TOKENS = 4096
@@ -434,7 +436,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		modelInfo?: any,
 		conversationId?: string, // Optional conversation ID to track cache points across messages
 	): { system: SystemContentBlock[]; messages: Message[] } {
-		// Convert model info to expected format
+		// First convert messages using shared converter for proper image handling
+		const convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[])
+
+		// If prompt caching is disabled, return the converted messages directly
+		if (!usePromptCache) {
+			return {
+				system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
+				messages: convertedMessages,
+			}
+		}
+
+		// Convert model info to expected format for cache strategy
 		const cacheModelInfo: CacheModelInfo = {
 			maxTokens: modelInfo?.maxTokens || 8192,
 			contextWindow: modelInfo?.contextWindow || 200_000,
@@ -444,18 +457,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 			cachableFields: modelInfo?.cachableFields || [],
 		}
 
-		// Clean messages by removing any existing cache points
-		const cleanedMessages = anthropicMessages.map((msg) => {
-			if (typeof msg.content === "string") {
-				return msg
-			}
-			const cleaned = {
-				...msg,
-				content: this.removeCachePoints(msg.content),
-			}
-			return cleaned
-		})
-
 		// Get previous cache point placements for this conversation if available
 		const previousPlacements =
 			conversationId && this.previousCachePointPlacements[conversationId]
@@ -466,21 +467,36 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		const config = {
 			modelInfo: cacheModelInfo,
 			systemPrompt: systemMessage,
-			messages: cleanedMessages as Anthropic.Messages.MessageParam[],
+			messages: anthropicMessages as Anthropic.Messages.MessageParam[],
 			usePromptCache,
 			previousCachePointPlacements: previousPlacements,
 		}
 
-		// Determine optimal cache points
+		// Get cache point placements
 		let strategy = new MultiPointStrategy(config)
-		const result = strategy.determineOptimalCachePoints()
+		const cacheResult = strategy.determineOptimalCachePoints()
 
 		// Store cache point placements for future use if conversation ID is provided
-		if (conversationId && result.messageCachePointPlacements) {
-			this.previousCachePointPlacements[conversationId] = result.messageCachePointPlacements
+		if (conversationId && cacheResult.messageCachePointPlacements) {
+			this.previousCachePointPlacements[conversationId] = cacheResult.messageCachePointPlacements
 		}
 
-		return result
+		// Apply cache points to the properly converted messages
+		const messagesWithCache = convertedMessages.map((msg, index) => {
+			const placement = cacheResult.messageCachePointPlacements?.find((p) => p.index === index)
+			if (placement) {
+				return {
+					...msg,
+					content: [...(msg.content || []), { cachePoint: { type: "default" } } as ContentBlock],
+				}
+			}
+			return msg
+		})
+
+		return {
+			system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
+			messages: messagesWithCache,
+		}
 	}
 
 	/************************************************************************************