|
|
@@ -7,9 +7,23 @@ jest.mock("@aws-sdk/credential-providers", () => {
|
|
|
return { fromIni: mockFromIni }
|
|
|
})
|
|
|
|
|
|
+// Mock BedrockRuntimeClient and ConverseStreamCommand
|
|
|
+const mockConverseStreamCommand = jest.fn()
|
|
|
+const mockSend = jest.fn().mockResolvedValue({
|
|
|
+ stream: [],
|
|
|
+})
|
|
|
+
|
|
|
+jest.mock("@aws-sdk/client-bedrock-runtime", () => ({
|
|
|
+ BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
|
|
|
+ send: mockSend,
|
|
|
+ })),
|
|
|
+ ConverseStreamCommand: mockConverseStreamCommand,
|
|
|
+ ConverseCommand: jest.fn(),
|
|
|
+}))
|
|
|
+
|
|
|
import { AwsBedrockHandler } from "../bedrock"
|
|
|
import { MessageContent } from "../../../shared/api"
|
|
|
-import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
|
|
|
+import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
|
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
|
const { fromIni } = require("@aws-sdk/credential-providers")
|
|
|
import { logger } from "../../../utils/logging"
|
|
|
@@ -57,7 +71,6 @@ describe("AwsBedrockHandler", () => {
|
|
|
})
|
|
|
|
|
|
it("should handle inference-profile ARN with apne3 region prefix", () => {
|
|
|
- // Mock the parseArn method before creating the handler
|
|
|
const originalParseArn = AwsBedrockHandler.prototype["parseArn"]
|
|
|
const parseArnMock = jest.fn().mockImplementation(function (this: any, arn: string, region?: string) {
|
|
|
return originalParseArn.call(this, arn, region)
|
|
|
@@ -65,12 +78,11 @@ describe("AwsBedrockHandler", () => {
|
|
|
AwsBedrockHandler.prototype["parseArn"] = parseArnMock
|
|
|
|
|
|
try {
|
|
|
- // Create a handler with a custom ARN that includes the apne3. region prefix
|
|
|
const customArnHandler = new AwsBedrockHandler({
|
|
|
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
|
awsAccessKey: "test-access-key",
|
|
|
awsSecretKey: "test-secret-key",
|
|
|
- awsRegion: "ap-northeast-3", // Osaka region
|
|
|
+ awsRegion: "ap-northeast-3",
|
|
|
awsCustomArn:
|
|
|
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
|
})
|
|
|
@@ -79,23 +91,17 @@ describe("AwsBedrockHandler", () => {
|
|
|
|
|
|
expect(modelInfo.id).toBe(
|
|
|
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
|
- ),
|
|
|
- // Verify the model info is defined
|
|
|
- expect(modelInfo.info).toBeDefined()
|
|
|
+ )
|
|
|
+ expect(modelInfo.info).toBeDefined()
|
|
|
|
|
|
- // Verify parseArn was called with the correct ARN
|
|
|
expect(parseArnMock).toHaveBeenCalledWith(
|
|
|
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
|
"ap-northeast-3",
|
|
|
)
|
|
|
|
|
|
- // Verify the model ID was correctly extracted from the ARN (without the region prefix)
|
|
|
expect((customArnHandler as any).arnInfo.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
|
|
|
-
|
|
|
- // Verify cross-region inference flag is false since apne3 is a prefix for a single region
|
|
|
expect((customArnHandler as any).arnInfo.crossRegionInference).toBe(false)
|
|
|
} finally {
|
|
|
- // Restore the original method
|
|
|
AwsBedrockHandler.prototype["parseArn"] = originalParseArn
|
|
|
}
|
|
|
})
|
|
|
@@ -109,12 +115,132 @@ describe("AwsBedrockHandler", () => {
|
|
|
awsRegion: "us-east-1",
|
|
|
})
|
|
|
const modelInfo = customArnHandler.getModel()
|
|
|
- // Should fall back to default prompt router model
|
|
|
expect(modelInfo.id).toBe(
|
|
|
"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
|
|
|
- ) // bedrockDefaultPromptRouterModelId
|
|
|
+ )
|
|
|
expect(modelInfo.info).toBeDefined()
|
|
|
expect(modelInfo.info.maxTokens).toBe(4096)
|
|
|
})
|
|
|
})
|
|
|
+
|
|
|
+ describe("image handling", () => {
|
|
|
+ const mockImageData = Buffer.from("test-image-data").toString("base64")
|
|
|
+
|
|
|
+ beforeEach(() => {
|
|
|
+ // Reset the mocks before each test
|
|
|
+ mockSend.mockReset()
|
|
|
+ mockConverseStreamCommand.mockReset()
|
|
|
+
|
|
|
+ mockSend.mockResolvedValue({
|
|
|
+ stream: [],
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should properly convert image content to Bedrock format", async () => {
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [
|
|
|
+ {
|
|
|
+ type: "image",
|
|
|
+ source: {
|
|
|
+ type: "base64",
|
|
|
+ data: mockImageData,
|
|
|
+ media_type: "image/jpeg",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "text",
|
|
|
+ text: "What's in this image?",
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ // Verify the command was created with the right payload
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0]
|
|
|
+
|
|
|
+ // Verify the image was properly formatted
|
|
|
+ const imageBlock = commandArg.messages[0].content[0]
|
|
|
+ expect(imageBlock).toHaveProperty("image")
|
|
|
+ expect(imageBlock.image).toHaveProperty("format", "jpeg")
|
|
|
+ expect(imageBlock.image.source).toHaveProperty("bytes")
|
|
|
+ expect(imageBlock.image.source.bytes).toBeInstanceOf(Uint8Array)
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should reject unsupported image formats", async () => {
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [
|
|
|
+ {
|
|
|
+ type: "image",
|
|
|
+ source: {
|
|
|
+ type: "base64",
|
|
|
+ data: mockImageData,
|
|
|
+ media_type: "image/tiff" as "image/jpeg", // Type assertion to bypass TS
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await expect(generator.next()).rejects.toThrow("Unsupported image format: tiff")
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle multiple images in a single message", async () => {
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [
|
|
|
+ {
|
|
|
+ type: "image",
|
|
|
+ source: {
|
|
|
+ type: "base64",
|
|
|
+ data: mockImageData,
|
|
|
+ media_type: "image/jpeg",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "text",
|
|
|
+ text: "First image",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "image",
|
|
|
+ source: {
|
|
|
+ type: "base64",
|
|
|
+ data: mockImageData,
|
|
|
+ media_type: "image/png",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ type: "text",
|
|
|
+ text: "Second image",
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const generator = handler.createMessage("", messages)
|
|
|
+ await generator.next() // Start the generator
|
|
|
+
|
|
|
+ // Verify the command was created with the right payload
|
|
|
+ expect(mockConverseStreamCommand).toHaveBeenCalled()
|
|
|
+ const commandArg = mockConverseStreamCommand.mock.calls[0][0]
|
|
|
+
|
|
|
+ // Verify both images were properly formatted
|
|
|
+ const firstImage = commandArg.messages[0].content[0]
|
|
|
+ const secondImage = commandArg.messages[0].content[2]
|
|
|
+
|
|
|
+ expect(firstImage).toHaveProperty("image")
|
|
|
+ expect(firstImage.image).toHaveProperty("format", "jpeg")
|
|
|
+ expect(secondImage).toHaveProperty("image")
|
|
|
+ expect(secondImage.image).toHaveProperty("format", "png")
|
|
|
+ })
|
|
|
+ })
|
|
|
})
|