Browse Source

test(bedrock): add comprehensive test coverage for Bedrock integration

- Add tests for AWS Bedrock handler (stream handling, config, errors)
- Add tests for message format conversion (text, images, tools)
- Add tests for stream event parsing and transformation
- Add tests for cross-region inference and prompt cache
- Add tests for metadata and message lifecycle events
Cline 1 year ago
parent
commit
ca41c54cb5

+ 243 - 0
src/api/providers/__tests__/bedrock.test.ts

@@ -0,0 +1,243 @@
+import { AwsBedrockHandler } from '../bedrock';
+import { 
+    BedrockRuntimeClient, 
+    ConverseStreamCommand,
+    ConverseStreamCommandOutput
+} from '@aws-sdk/client-bedrock-runtime';
+import { ApiHandlerOptions } from '../../../shared/api';
+import { jest } from '@jest/globals';
+import { Readable } from 'stream';
+
+// Mock the BedrockRuntimeClient
+jest.mock('@aws-sdk/client-bedrock-runtime', () => ({
+    BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
+        send: jest.fn()
+    })),
+    ConverseStreamCommand: jest.fn()
+}));
+
+describe('AwsBedrockHandler', () => {
+    let handler: AwsBedrockHandler;
+    let mockClient: jest.Mocked<BedrockRuntimeClient>;
+
+    beforeEach(() => {
+        // Clear all mocks
+        jest.clearAllMocks();
+
+        // Create mock client with properly typed send method
+        mockClient = {
+            send: jest.fn().mockImplementation(() => Promise.resolve({
+                $metadata: {},
+                stream: new Readable({
+                    read() {
+                        this.push(null);
+                    }
+                })
+            }))
+        } as unknown as jest.Mocked<BedrockRuntimeClient>;
+
+        // Create handler with test options
+        const options: ApiHandlerOptions = {
+            awsRegion: 'us-west-2',
+            awsAccessKey: 'test-access-key',
+            awsSecretKey: 'test-secret-key',
+            apiModelId: 'test-model'
+        };
+        handler = new AwsBedrockHandler(options);
+        (handler as any).client = mockClient;
+    });
+
+    test('createMessage sends a streaming request correctly', async () => {
+        const mockStream = new Readable({
+            read() {
+                this.push(JSON.stringify({
+                    messageStart: { role: 'assistant' }
+                }));
+                this.push(JSON.stringify({
+                    contentBlockStart: {
+                        start: { text: 'Hello' }
+                    }
+                }));
+                this.push(JSON.stringify({
+                    contentBlockDelta: {
+                        delta: { text: ' world' }
+                    }
+                }));
+                this.push(JSON.stringify({
+                    messageStop: { stopReason: 'end_turn' }
+                }));
+                this.push(null);
+            }
+        });
+
+        mockClient.send.mockImplementation(() => 
+            Promise.resolve({
+                $metadata: {},
+                stream: mockStream
+            } as ConverseStreamCommandOutput)
+        );
+
+        const systemPrompt = 'Test system prompt';
+        const messages = [{ role: 'user' as const, content: 'Test message' }];
+
+        const stream = handler.createMessage(systemPrompt, messages);
+
+        // Collect all chunks
+        const chunks = [];
+        for await (const chunk of stream) {
+            chunks.push(chunk);
+        }
+
+        // Verify the command was sent correctly
+        expect(mockClient.send).toHaveBeenCalledWith(
+            expect.any(ConverseStreamCommand)
+        );
+
+        // Verify the stream chunks
+        expect(chunks).toEqual([
+            { type: 'text', text: 'Hello' },
+            { type: 'text', text: ' world' }
+        ]);
+    });
+
+    test('createMessage handles metadata events correctly', async () => {
+        const mockStream = new Readable({
+            read() {
+                this.push(JSON.stringify({
+                    metadata: {
+                        usage: {
+                            inputTokens: 10,
+                            outputTokens: 20,
+                            totalTokens: 30
+                        }
+                    }
+                }));
+                this.push(null);
+            }
+        });
+
+        mockClient.send.mockImplementation(() => 
+            Promise.resolve({
+                $metadata: {},
+                stream: mockStream
+            } as ConverseStreamCommandOutput)
+        );
+
+        const systemPrompt = 'Test system prompt';
+        const messages = [{ role: 'user' as const, content: 'Test message' }];
+
+        const stream = handler.createMessage(systemPrompt, messages);
+
+        const chunks = [];
+        for await (const chunk of stream) {
+            chunks.push(chunk);
+        }
+
+        expect(chunks).toEqual([
+            {
+                type: 'usage',
+                inputTokens: 10,
+                outputTokens: 20
+            }
+        ]);
+    });
+
+    test('createMessage handles errors during streaming', async () => {
+        mockClient.send.mockImplementation(() => 
+            Promise.reject(new Error('Test error'))
+        );
+
+        const systemPrompt = 'Test system prompt';
+        const messages = [{ role: 'user' as const, content: 'Test message' }];
+
+        await expect(handler.createMessage(systemPrompt, messages)).rejects.toThrow('Test error');
+    });
+
+    test('getModel returns correct model info', () => {
+        const modelInfo = handler.getModel();
+        expect(modelInfo).toEqual({
+            id: 'test-model',
+            info: expect.any(Object)
+        });
+    });
+
+    test('createMessage handles cross-region inference', async () => {
+        const options: ApiHandlerOptions = {
+            awsRegion: 'us-west-2',
+            awsAccessKey: 'test-access-key',
+            awsSecretKey: 'test-secret-key',
+            apiModelId: 'test-model',
+            awsUseCrossRegionInference: true
+        };
+        
+        handler = new AwsBedrockHandler(options);
+        (handler as any).client = mockClient;
+
+        const mockStream = new Readable({
+            read() {
+                this.push(JSON.stringify({
+                    contentBlockStart: {
+                        start: { text: 'Hello' }
+                    }
+                }));
+                this.push(null);
+            }
+        });
+
+        mockClient.send.mockImplementation(() => 
+            Promise.resolve({
+                $metadata: {},
+                stream: mockStream
+            } as ConverseStreamCommandOutput)
+        );
+
+        const systemPrompt = 'Test system prompt';
+        const messages = [{ role: 'user' as const, content: 'Test message' }];
+
+        await handler.createMessage(systemPrompt, messages);
+
+        expect(mockClient.send).toHaveBeenCalledWith(
+            expect.objectContaining({
+                input: expect.stringContaining('us.test-model')
+            })
+        );
+    });
+
+    test('createMessage includes prompt cache configuration when enabled', async () => {
+        const options: ApiHandlerOptions = {
+            awsRegion: 'us-west-2',
+            awsAccessKey: 'test-access-key',
+            awsSecretKey: 'test-secret-key',
+            apiModelId: 'test-model',
+            awsUsePromptCache: true,
+            awspromptCacheId: 'test-cache-id'
+        };
+        
+        handler = new AwsBedrockHandler(options);
+        (handler as any).client = mockClient;
+
+        const mockStream = new Readable({
+            read() {
+                this.push(null);
+            }
+        });
+
+        mockClient.send.mockImplementation(() => 
+            Promise.resolve({
+                $metadata: {},
+                stream: mockStream
+            } as ConverseStreamCommandOutput)
+        );
+
+        const systemPrompt = 'Test system prompt';
+        const messages = [{ role: 'user' as const, content: 'Test message' }];
+
+        await handler.createMessage(systemPrompt, messages);
+
+        expect(mockClient.send).toHaveBeenCalledWith(
+            expect.objectContaining({
+                input: expect.stringContaining('promptCacheId')
+            })
+        );
+    });
+});

+ 252 - 0
src/api/transform/__tests__/bedrock-converse-format.test.ts

@@ -0,0 +1,252 @@
+import { convertToBedrockConverseMessages, convertToAnthropicMessage } from '../bedrock-converse-format'
+import { Anthropic } from '@anthropic-ai/sdk'
+import { ContentBlock, ToolResultContentBlock } from '@aws-sdk/client-bedrock-runtime'
+import { StreamEvent } from '../../providers/bedrock'
+
+describe('bedrock-converse-format', () => {
+    describe('convertToBedrockConverseMessages', () => {
+        test('converts simple text messages correctly', () => {
+            const messages: Anthropic.Messages.MessageParam[] = [
+                { role: 'user', content: 'Hello' },
+                { role: 'assistant', content: 'Hi there' }
+            ]
+
+            const result = convertToBedrockConverseMessages(messages)
+
+            expect(result).toEqual([
+                {
+                    role: 'user',
+                    content: [{ text: 'Hello' }]
+                },
+                {
+                    role: 'assistant',
+                    content: [{ text: 'Hi there' }]
+                }
+            ])
+        })
+
+        test('converts messages with images correctly', () => {
+            const messages: Anthropic.Messages.MessageParam[] = [
+                {
+                    role: 'user',
+                    content: [
+                        {
+                            type: 'text',
+                            text: 'Look at this image:'
+                        },
+                        {
+                            type: 'image',
+                            source: {
+                                type: 'base64',
+                                data: 'SGVsbG8=', // "Hello" in base64
+                                media_type: 'image/jpeg' as const
+                            }
+                        }
+                    ]
+                }
+            ]
+
+            const result = convertToBedrockConverseMessages(messages)
+
+            if (!result[0] || !result[0].content) {
+                fail('Expected result to have content')
+                return
+            }
+
+            expect(result[0].role).toBe('user')
+            expect(result[0].content).toHaveLength(2)
+            expect(result[0].content[0]).toEqual({ text: 'Look at this image:' })
+            
+            const imageBlock = result[0].content[1] as ContentBlock
+            if ('image' in imageBlock && imageBlock.image && imageBlock.image.source) {
+                expect(imageBlock.image.format).toBe('jpeg')
+                expect(imageBlock.image.source).toBeDefined()
+                expect(imageBlock.image.source.bytes).toBeDefined()
+            } else {
+                fail('Expected image block not found')
+            }
+        })
+
+        test('converts tool use messages correctly', () => {
+            const messages: Anthropic.Messages.MessageParam[] = [
+                {
+                    role: 'assistant',
+                    content: [
+                        {
+                            type: 'tool_use',
+                            id: 'test-id',
+                            name: 'read_file',
+                            input: {
+                                path: 'test.txt'
+                            }
+                        }
+                    ]
+                }
+            ]
+
+            const result = convertToBedrockConverseMessages(messages)
+
+            if (!result[0] || !result[0].content) {
+                fail('Expected result to have content')
+                return
+            }
+
+            expect(result[0].role).toBe('assistant')
+            const toolBlock = result[0].content[0] as ContentBlock
+            if ('toolUse' in toolBlock && toolBlock.toolUse) {
+                expect(toolBlock.toolUse).toEqual({
+                    toolUseId: 'test-id',
+                    name: 'read_file',
+                    input: '<read_file>\n<path>\ntest.txt\n</path>\n</read_file>'
+                })
+            } else {
+                fail('Expected tool use block not found')
+            }
+        })
+
+        test('converts tool result messages correctly', () => {
+            const messages: Anthropic.Messages.MessageParam[] = [
+                {
+                    role: 'assistant',
+                    content: [
+                        {
+                            type: 'tool_result',
+                            tool_use_id: 'test-id',
+                            content: [{ type: 'text', text: 'File contents here' }]
+                        }
+                    ]
+                }
+            ]
+
+            const result = convertToBedrockConverseMessages(messages)
+
+            if (!result[0] || !result[0].content) {
+                fail('Expected result to have content')
+                return
+            }
+
+            expect(result[0].role).toBe('assistant')
+            const resultBlock = result[0].content[0] as ContentBlock
+            if ('toolResult' in resultBlock && resultBlock.toolResult) {
+                const expectedContent: ToolResultContentBlock[] = [
+                    { text: 'File contents here' }
+                ]
+                expect(resultBlock.toolResult).toEqual({
+                    toolUseId: 'test-id',
+                    content: expectedContent,
+                    status: 'success'
+                })
+            } else {
+                fail('Expected tool result block not found')
+            }
+        })
+
+        test('handles text content correctly', () => {
+            const messages: Anthropic.Messages.MessageParam[] = [
+                {
+                    role: 'user',
+                    content: [
+                        {
+                            type: 'text',
+                            text: 'Hello world'
+                        }
+                    ]
+                }
+            ]
+
+            const result = convertToBedrockConverseMessages(messages)
+
+            if (!result[0] || !result[0].content) {
+                fail('Expected result to have content')
+                return
+            }
+
+            expect(result[0].role).toBe('user')
+            expect(result[0].content).toHaveLength(1)
+            const textBlock = result[0].content[0] as ContentBlock
+            expect(textBlock).toEqual({ text: 'Hello world' })
+        })
+    })
+
+    describe('convertToAnthropicMessage', () => {
+        test('converts metadata events correctly', () => {
+            const event: StreamEvent = {
+                metadata: {
+                    usage: {
+                        inputTokens: 10,
+                        outputTokens: 20
+                    }
+                }
+            }
+
+            const result = convertToAnthropicMessage(event, 'test-model')
+
+            expect(result).toEqual({
+                id: '',
+                type: 'message',
+                role: 'assistant',
+                model: 'test-model',
+                usage: {
+                    input_tokens: 10,
+                    output_tokens: 20
+                }
+            })
+        })
+
+        test('converts content block start events correctly', () => {
+            const event: StreamEvent = {
+                contentBlockStart: {
+                    start: {
+                        text: 'Hello'
+                    }
+                }
+            }
+
+            const result = convertToAnthropicMessage(event, 'test-model')
+
+            expect(result).toEqual({
+                type: 'message',
+                role: 'assistant',
+                content: [{ type: 'text', text: 'Hello' }],
+                model: 'test-model'
+            })
+        })
+
+        test('converts content block delta events correctly', () => {
+            const event: StreamEvent = {
+                contentBlockDelta: {
+                    delta: {
+                        text: ' world'
+                    }
+                }
+            }
+
+            const result = convertToAnthropicMessage(event, 'test-model')
+
+            expect(result).toEqual({
+                type: 'message',
+                role: 'assistant',
+                content: [{ type: 'text', text: ' world' }],
+                model: 'test-model'
+            })
+        })
+
+        test('converts message stop events correctly', () => {
+            const event: StreamEvent = {
+                messageStop: {
+                    stopReason: 'end_turn' as const
+                }
+            }
+
+            const result = convertToAnthropicMessage(event, 'test-model')
+
+            expect(result).toEqual({
+                type: 'message',
+                role: 'assistant',
+                stop_reason: 'end_turn',
+                stop_sequence: null,
+                model: 'test-model'
+            })
+        })
+    })
+})