
Add non-streaming completePrompt to all providers

Matt Rubens 11 months ago
parent
commit
4027e1c10c
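
The diffs below add tests for a non-streaming completePrompt() on every provider handler. As orientation (not part of this commit's diff), here is a minimal sketch of the shape those tests exercise for the OpenAI-compatible handlers; the base URL, apiKey and error prefix follow the Ollama test expectations below, while the class itself and its field names are assumptions, not the repository's real handler:

import OpenAI from 'openai';

// Sketch only: a non-streaming completePrompt() matching what the tests assert.
class ExampleOllamaStyleHandler {
    private client: OpenAI;

    constructor(baseURL = 'http://localhost:11434/v1', private modelId = 'llama2') {
        this.client = new OpenAI({ baseURL, apiKey: 'ollama' });
    }

    async completePrompt(prompt: string): Promise<string> {
        try {
            // Single user message, deterministic temperature, streaming disabled.
            const response = await this.client.chat.completions.create({
                model: this.modelId,
                messages: [{ role: 'user', content: prompt }],
                temperature: 0,
                stream: false
            });
            // Empty or missing content falls back to '' (the empty-response tests).
            return response.choices[0]?.message?.content || '';
        } catch (error) {
            // Tests expect provider-prefixed errors, e.g. "Ollama completion error: API Error".
            if (error instanceof Error) {
                throw new Error(`Ollama completion error: ${error.message}`);
            }
            throw error;
        }
    }
}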

+ 72 - 1
src/api/providers/__tests__/anthropic.test.ts

@@ -46,7 +46,42 @@ jest.mock('@anthropic-ai/sdk', () => {
                 }
             },
             messages: {
-                create: mockCreate
+                create: mockCreate.mockImplementation(async (options) => {
+                    if (!options.stream) {
+                        return {
+                            id: 'test-completion',
+                            content: [
+                                { type: 'text', text: 'Test response' }
+                            ],
+                            role: 'assistant',
+                            model: options.model,
+                            usage: {
+                                input_tokens: 10,
+                                output_tokens: 5
+                            }
+                        }
+                    }
+                    return {
+                        async *[Symbol.asyncIterator]() {
+                            yield {
+                                type: 'message_start',
+                                message: {
+                                    usage: {
+                                        input_tokens: 10,
+                                        output_tokens: 5
+                                    }
+                                }
+                            }
+                            yield {
+                                type: 'content_block_start',
+                                content_block: {
+                                    type: 'text',
+                                    text: 'Test response'
+                                }
+                            }
+                        }
+                    }
+                })
             }
         }))
     };
@@ -144,6 +179,42 @@ describe('AnthropicHandler', () => {
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: mockOptions.apiModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                max_tokens: 8192,
+                temperature: 0,
+                stream: false
+            });
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Anthropic completion error: API Error');
+        });
+
+        it('should handle non-text content', async () => {
+            mockCreate.mockImplementationOnce(async () => ({
+                content: [{ type: 'image' }]
+            }));
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockImplementationOnce(async () => ({
+                content: [{ type: 'text', text: '' }]
+            }));
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
     describe('getModel', () => {
         it('should return default model if no model ID is provided', () => {
             const handlerWithoutModel = new AnthropicHandler({

+ 102 - 0
src/api/providers/__tests__/bedrock.test.ts

@@ -119,6 +119,108 @@ describe('AwsBedrockHandler', () => {
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const mockResponse = {
+                output: new TextEncoder().encode(JSON.stringify({
+                    content: 'Test response'
+                }))
+            };
+
+            const mockSend = jest.fn().mockResolvedValue(mockResponse);
+            handler['client'] = {
+                send: mockSend
+            } as unknown as BedrockRuntimeClient;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
+                input: expect.objectContaining({
+                    modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
+                    messages: expect.arrayContaining([
+                        expect.objectContaining({
+                            role: 'user',
+                            content: [{ text: 'Test prompt' }]
+                        })
+                    ]),
+                    inferenceConfig: expect.objectContaining({
+                        maxTokens: 5000,
+                        temperature: 0.3,
+                        topP: 0.1
+                    })
+                })
+            }));
+        });
+
+        it('should handle API errors', async () => {
+            const mockError = new Error('AWS Bedrock error');
+            const mockSend = jest.fn().mockRejectedValue(mockError);
+            handler['client'] = {
+                send: mockSend
+            } as unknown as BedrockRuntimeClient;
+
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Bedrock completion error: AWS Bedrock error');
+        });
+
+        it('should handle invalid response format', async () => {
+            const mockResponse = {
+                output: new TextEncoder().encode('invalid json')
+            };
+
+            const mockSend = jest.fn().mockResolvedValue(mockResponse);
+            handler['client'] = {
+                send: mockSend
+            } as unknown as BedrockRuntimeClient;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+
+        it('should handle empty response', async () => {
+            const mockResponse = {
+                output: new TextEncoder().encode(JSON.stringify({}))
+            };
+
+            const mockSend = jest.fn().mockResolvedValue(mockResponse);
+            handler['client'] = {
+                send: mockSend
+            } as unknown as BedrockRuntimeClient;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+
+        it('should handle cross-region inference', async () => {
+            handler = new AwsBedrockHandler({
+                apiModelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0',
+                awsAccessKey: 'test-access-key',
+                awsSecretKey: 'test-secret-key',
+                awsRegion: 'us-east-1',
+                awsUseCrossRegionInference: true
+            });
+
+            const mockResponse = {
+                output: new TextEncoder().encode(JSON.stringify({
+                    content: 'Test response'
+                }))
+            };
+
+            const mockSend = jest.fn().mockResolvedValue(mockResponse);
+            handler['client'] = {
+                send: mockSend
+            } as unknown as BedrockRuntimeClient;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockSend).toHaveBeenCalledWith(expect.objectContaining({
+                input: expect.objectContaining({
+                    modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0'
+                })
+            }));
+        });
+    });
+
     describe('getModel', () => {
         it('should return correct model info in test environment', () => {
             const modelInfo = handler.getModel();

+ 59 - 1
src/api/providers/__tests__/gemini.test.ts

@@ -6,7 +6,12 @@ import { GoogleGenerativeAI } from '@google/generative-ai';
 jest.mock('@google/generative-ai', () => ({
     GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
         getGenerativeModel: jest.fn().mockReturnValue({
-            generateContentStream: jest.fn()
+            generateContentStream: jest.fn(),
+            generateContent: jest.fn().mockResolvedValue({
+                response: {
+                    text: () => 'Test response'
+                }
+            })
         })
     }))
 }));
@@ -133,6 +138,59 @@ describe('GeminiHandler', () => {
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const mockGenerateContent = jest.fn().mockResolvedValue({
+                response: {
+                    text: () => 'Test response'
+                }
+            });
+            const mockGetGenerativeModel = jest.fn().mockReturnValue({
+                generateContent: mockGenerateContent
+            });
+            (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockGetGenerativeModel).toHaveBeenCalledWith({
+                model: 'gemini-2.0-flash-thinking-exp-1219'
+            });
+            expect(mockGenerateContent).toHaveBeenCalledWith({
+                contents: [{ role: 'user', parts: [{ text: 'Test prompt' }] }],
+                generationConfig: {
+                    temperature: 0
+                }
+            });
+        });
+
+        it('should handle API errors', async () => {
+            const mockError = new Error('Gemini API error');
+            const mockGenerateContent = jest.fn().mockRejectedValue(mockError);
+            const mockGetGenerativeModel = jest.fn().mockReturnValue({
+                generateContent: mockGenerateContent
+            });
+            (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Gemini completion error: Gemini API error');
+        });
+
+        it('should handle empty response', async () => {
+            const mockGenerateContent = jest.fn().mockResolvedValue({
+                response: {
+                    text: () => ''
+                }
+            });
+            const mockGetGenerativeModel = jest.fn().mockReturnValue({
+                generateContent: mockGenerateContent
+            });
+            (handler['client'] as any).getGenerativeModel = mockGetGenerativeModel;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
     describe('getModel', () => {
         it('should return correct model info', () => {
             const modelInfo = handler.getModel();

+ 226 - 0
src/api/providers/__tests__/glama.test.ts

@@ -0,0 +1,226 @@
+import { GlamaHandler } from '../glama';
+import { ApiHandlerOptions } from '../../../shared/api';
+import OpenAI from 'openai';
+import { Anthropic } from '@anthropic-ai/sdk';
+import axios from 'axios';
+
+// Mock OpenAI client
+const mockCreate = jest.fn();
+const mockWithResponse = jest.fn();
+
+jest.mock('openai', () => {
+    return {
+        __esModule: true,
+        default: jest.fn().mockImplementation(() => ({
+            chat: {
+                completions: {
+                    create: (...args: any[]) => {
+                        const stream = {
+                            [Symbol.asyncIterator]: async function* () {
+                                yield {
+                                    choices: [{
+                                        delta: { content: 'Test response' },
+                                        index: 0
+                                    }],
+                                    usage: null
+                                };
+                                yield {
+                                    choices: [{
+                                        delta: {},
+                                        index: 0
+                                    }],
+                                    usage: {
+                                        prompt_tokens: 10,
+                                        completion_tokens: 5,
+                                        total_tokens: 15
+                                    }
+                                };
+                            }
+                        };
+
+                        const result = mockCreate(...args);
+                        if (args[0].stream) {
+                            mockWithResponse.mockReturnValue(Promise.resolve({
+                                data: stream,
+                                response: {
+                                    headers: {
+                                        get: (name: string) => name === 'x-completion-request-id' ? 'test-request-id' : null
+                                    }
+                                }
+                            }));
+                            result.withResponse = mockWithResponse;
+                        }
+                        return result;
+                    }
+                }
+            }
+        }))
+    };
+});
+
+describe('GlamaHandler', () => {
+    let handler: GlamaHandler;
+    let mockOptions: ApiHandlerOptions;
+
+    beforeEach(() => {
+        mockOptions = {
+            apiModelId: 'anthropic/claude-3-5-sonnet',
+            glamaModelId: 'anthropic/claude-3-5-sonnet',
+            glamaApiKey: 'test-api-key'
+        };
+        handler = new GlamaHandler(mockOptions);
+        mockCreate.mockClear();
+        mockWithResponse.mockClear();
+
+        // Default mock implementation for non-streaming responses
+        mockCreate.mockResolvedValue({
+            id: 'test-completion',
+            choices: [{
+                message: { role: 'assistant', content: 'Test response' },
+                finish_reason: 'stop',
+                index: 0
+            }],
+            usage: {
+                prompt_tokens: 10,
+                completion_tokens: 5,
+                total_tokens: 15
+            }
+        });
+    });
+
+    describe('constructor', () => {
+        it('should initialize with provided options', () => {
+            expect(handler).toBeInstanceOf(GlamaHandler);
+            expect(handler.getModel().id).toBe(mockOptions.apiModelId);
+        });
+    });
+
+    describe('createMessage', () => {
+        const systemPrompt = 'You are a helpful assistant.';
+        const messages: Anthropic.Messages.MessageParam[] = [
+            {
+                role: 'user',
+                content: 'Hello!'
+            }
+        ];
+
+        it('should handle streaming responses', async () => {
+            // Mock axios for token usage request
+            const mockAxios = jest.spyOn(axios, 'get').mockResolvedValueOnce({
+                data: {
+                    tokenUsage: {
+                        promptTokens: 10,
+                        completionTokens: 5,
+                        cacheCreationInputTokens: 0,
+                        cacheReadInputTokens: 0
+                    },
+                    totalCostUsd: "0.00"
+                }
+            });
+
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks: any[] = [];
+            for await (const chunk of stream) {
+                chunks.push(chunk);
+            }
+
+            expect(chunks.length).toBe(2); // Text chunk and usage chunk
+            expect(chunks[0]).toEqual({
+                type: 'text',
+                text: 'Test response'
+            });
+            expect(chunks[1]).toEqual({
+                type: 'usage',
+                inputTokens: 10,
+                outputTokens: 5,
+                cacheWriteTokens: 0,
+                cacheReadTokens: 0,
+                totalCost: 0
+            });
+
+            mockAxios.mockRestore();
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockImplementationOnce(() => {
+                throw new Error('API Error');
+            });
+
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks = [];
+
+            try {
+                for await (const chunk of stream) {
+                    chunks.push(chunk);
+                }
+                fail('Expected error to be thrown');
+            } catch (error) {
+                expect(error).toBeInstanceOf(Error);
+                expect(error.message).toBe('API Error');
+            }
+        });
+    });
+
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+                model: mockOptions.apiModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0,
+                max_tokens: 8192
+            }));
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Glama completion error: API Error');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: '' } }]
+            });
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+
+        it('should not set max_tokens for non-Anthropic models', async () => {
+            // Reset mock to clear any previous calls
+            mockCreate.mockClear();
+            
+            const nonAnthropicOptions = {
+                apiModelId: 'openai/gpt-4',
+                glamaModelId: 'openai/gpt-4',
+                glamaApiKey: 'test-key',
+                glamaModelInfo: {
+                    maxTokens: 4096,
+                    contextWindow: 8192,
+                    supportsImages: true,
+                    supportsPromptCache: false
+                }
+            };
+            const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions);
+
+            await nonAnthropicHandler.completePrompt('Test prompt');
+            expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+                model: 'openai/gpt-4',
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0
+            }));
+            expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_tokens');
+        });
+    });
+
+    describe('getModel', () => {
+        it('should return model info', () => {
+            const modelInfo = handler.getModel();
+            expect(modelInfo.id).toBe(mockOptions.apiModelId);
+            expect(modelInfo.info).toBeDefined();
+            expect(modelInfo.info.maxTokens).toBe(8192);
+            expect(modelInfo.info.contextWindow).toBe(200_000);
+        });
+    });
+});
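
The last Glama test above asserts that max_tokens is sent only for Anthropic-hosted models. One way the request-building could satisfy both completePrompt expectations is sketched here; the 'anthropic/' prefix check is an assumption, while the 8192 cap and the omission of max_tokens for other models come straight from the tests:

import OpenAI from 'openai';

// Sketch only: builds the non-streaming params the Glama completePrompt tests assert on.
function buildGlamaCompletionParams(
    modelId: string,
    prompt: string
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
    const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
        model: modelId,
        messages: [{ role: 'user', content: prompt }],
        temperature: 0
    };
    if (modelId.startsWith('anthropic/')) {
        // Anthropic-hosted models get an explicit cap; others omit max_tokens entirely.
        params.max_tokens = 8192;
    }
    return params;
}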

+ 112 - 100
src/api/providers/__tests__/lmstudio.test.ts

@@ -1,148 +1,160 @@
 import { LmStudioHandler } from '../lmstudio';
-import { Anthropic } from '@anthropic-ai/sdk';
+import { ApiHandlerOptions } from '../../../shared/api';
 import OpenAI from 'openai';
+import { Anthropic } from '@anthropic-ai/sdk';
 
-// Mock OpenAI SDK
-jest.mock('openai', () => ({
-    __esModule: true,
-    default: jest.fn().mockImplementation(() => ({
-        chat: {
-            completions: {
-                create: jest.fn()
+// Mock OpenAI client
+const mockCreate = jest.fn();
+jest.mock('openai', () => {
+    return {
+        __esModule: true,
+        default: jest.fn().mockImplementation(() => ({
+            chat: {
+                completions: {
+                    create: mockCreate.mockImplementation(async (options) => {
+                        if (!options.stream) {
+                            return {
+                                id: 'test-completion',
+                                choices: [{
+                                    message: { role: 'assistant', content: 'Test response' },
+                                    finish_reason: 'stop',
+                                    index: 0
+                                }],
+                                usage: {
+                                    prompt_tokens: 10,
+                                    completion_tokens: 5,
+                                    total_tokens: 15
+                                }
+                            };
+                        }
+                        
+                        return {
+                            [Symbol.asyncIterator]: async function* () {
+                                yield {
+                                    choices: [{
+                                        delta: { content: 'Test response' },
+                                        index: 0
+                                    }],
+                                    usage: null
+                                };
+                                yield {
+                                    choices: [{
+                                        delta: {},
+                                        index: 0
+                                    }],
+                                    usage: {
+                                        prompt_tokens: 10,
+                                        completion_tokens: 5,
+                                        total_tokens: 15
+                                    }
+                                };
+                            }
+                        };
+                    })
+                }
             }
-        }
-    }))
-}));
+        }))
+    };
+});
 
 describe('LmStudioHandler', () => {
     let handler: LmStudioHandler;
+    let mockOptions: ApiHandlerOptions;
 
     beforeEach(() => {
-        handler = new LmStudioHandler({
-            lmStudioModelId: 'mistral-7b',
-            lmStudioBaseUrl: 'http://localhost:1234'
-        });
+        mockOptions = {
+            apiModelId: 'local-model',
+            lmStudioModelId: 'local-model',
+            lmStudioBaseUrl: 'http://localhost:1234/v1'
+        };
+        handler = new LmStudioHandler(mockOptions);
+        mockCreate.mockClear();
     });
 
     describe('constructor', () => {
-        it('should initialize with provided config', () => {
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:1234/v1',
-                apiKey: 'noop'
-            });
+        it('should initialize with provided options', () => {
+            expect(handler).toBeInstanceOf(LmStudioHandler);
+            expect(handler.getModel().id).toBe(mockOptions.lmStudioModelId);
         });
 
         it('should use default base URL if not provided', () => {
-            const defaultHandler = new LmStudioHandler({
-                lmStudioModelId: 'mistral-7b'
-            });
-
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:1234/v1',
-                apiKey: 'noop'
+            const handlerWithoutUrl = new LmStudioHandler({
+                apiModelId: 'local-model',
+                lmStudioModelId: 'local-model'
             });
+            expect(handlerWithoutUrl).toBeInstanceOf(LmStudioHandler);
         });
     });
 
     describe('createMessage', () => {
-        const mockMessages: Anthropic.Messages.MessageParam[] = [
+        const systemPrompt = 'You are a helpful assistant.';
+        const messages: Anthropic.Messages.MessageParam[] = [
             {
                 role: 'user',
-                content: 'Hello'
-            },
-            {
-                role: 'assistant',
-                content: 'Hi there!'
+                content: 'Hello!'
             }
         ];
 
-        const systemPrompt = 'You are a helpful assistant';
-
-        it('should handle streaming responses correctly', async () => {
-            const mockStream = [
-                {
-                    choices: [{
-                        delta: { content: 'Hello' }
-                    }]
-                },
-                {
-                    choices: [{
-                        delta: { content: ' world!' }
-                    }]
-                }
-            ];
-
-            // Setup async iterator for mock stream
-            const asyncIterator = {
-                async *[Symbol.asyncIterator]() {
-                    for (const chunk of mockStream) {
-                        yield chunk;
-                    }
-                }
-            };
-
-            const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
-            (handler['client'].chat.completions as any).create = mockCreate;
-
-            const stream = handler.createMessage(systemPrompt, mockMessages);
-            const chunks = [];
-            
+        it('should handle streaming responses', async () => {
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks: any[] = [];
             for await (const chunk of stream) {
                 chunks.push(chunk);
             }
 
-            expect(chunks.length).toBe(2);
-            expect(chunks[0]).toEqual({
-                type: 'text',
-                text: 'Hello'
-            });
-            expect(chunks[1]).toEqual({
-                type: 'text',
-                text: ' world!'
-            });
-
-            expect(mockCreate).toHaveBeenCalledWith({
-                model: 'mistral-7b',
-                messages: expect.arrayContaining([
-                    {
-                        role: 'system',
-                        content: systemPrompt
-                    }
-                ]),
-                temperature: 0,
-                stream: true
-            });
+            expect(chunks.length).toBeGreaterThan(0);
+            const textChunks = chunks.filter(chunk => chunk.type === 'text');
+            expect(textChunks).toHaveLength(1);
+            expect(textChunks[0].text).toBe('Test response');
         });
 
-        it('should handle API errors with custom message', async () => {
-            const mockError = new Error('LM Studio API error');
-            const mockCreate = jest.fn().mockRejectedValue(mockError);
-            (handler['client'].chat.completions as any).create = mockCreate;
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
 
-            const stream = handler.createMessage(systemPrompt, mockMessages);
+            const stream = handler.createMessage(systemPrompt, messages);
 
             await expect(async () => {
                 for await (const chunk of stream) {
-                    // Should throw before yielding any chunks
+                    // Should not reach here
                 }
             }).rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: mockOptions.lmStudioModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0,
+                stream: false
+            });
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Please check the LM Studio developer logs to debug what went wrong');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: '' } }]
+            });
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
     describe('getModel', () => {
-        it('should return model info with sane defaults', () => {
+        it('should return model info', () => {
             const modelInfo = handler.getModel();
-            expect(modelInfo.id).toBe('mistral-7b');
+            expect(modelInfo.id).toBe(mockOptions.lmStudioModelId);
             expect(modelInfo.info).toBeDefined();
             expect(modelInfo.info.maxTokens).toBe(-1);
             expect(modelInfo.info.contextWindow).toBe(128_000);
         });
-
-        it('should return empty string as model ID if not provided', () => {
-            const noModelHandler = new LmStudioHandler({});
-            const modelInfo = noModelHandler.getModel();
-            expect(modelInfo.id).toBe('');
-            expect(modelInfo.info).toBeDefined();
-        });
     });
 });

+ 110 - 98
src/api/providers/__tests__/ollama.test.ts

@@ -1,148 +1,160 @@
 import { OllamaHandler } from '../ollama';
-import { Anthropic } from '@anthropic-ai/sdk';
+import { ApiHandlerOptions } from '../../../shared/api';
 import OpenAI from 'openai';
+import { Anthropic } from '@anthropic-ai/sdk';
 
-// Mock OpenAI SDK
-jest.mock('openai', () => ({
-    __esModule: true,
-    default: jest.fn().mockImplementation(() => ({
-        chat: {
-            completions: {
-                create: jest.fn()
+// Mock OpenAI client
+const mockCreate = jest.fn();
+jest.mock('openai', () => {
+    return {
+        __esModule: true,
+        default: jest.fn().mockImplementation(() => ({
+            chat: {
+                completions: {
+                    create: mockCreate.mockImplementation(async (options) => {
+                        if (!options.stream) {
+                            return {
+                                id: 'test-completion',
+                                choices: [{
+                                    message: { role: 'assistant', content: 'Test response' },
+                                    finish_reason: 'stop',
+                                    index: 0
+                                }],
+                                usage: {
+                                    prompt_tokens: 10,
+                                    completion_tokens: 5,
+                                    total_tokens: 15
+                                }
+                            };
+                        }
+                        
+                        return {
+                            [Symbol.asyncIterator]: async function* () {
+                                yield {
+                                    choices: [{
+                                        delta: { content: 'Test response' },
+                                        index: 0
+                                    }],
+                                    usage: null
+                                };
+                                yield {
+                                    choices: [{
+                                        delta: {},
+                                        index: 0
+                                    }],
+                                    usage: {
+                                        prompt_tokens: 10,
+                                        completion_tokens: 5,
+                                        total_tokens: 15
+                                    }
+                                };
+                            }
+                        };
+                    })
+                }
             }
-        }
-    }))
-}));
+        }))
+    };
+});
 
 describe('OllamaHandler', () => {
     let handler: OllamaHandler;
+    let mockOptions: ApiHandlerOptions;
 
     beforeEach(() => {
-        handler = new OllamaHandler({
+        mockOptions = {
+            apiModelId: 'llama2',
             ollamaModelId: 'llama2',
-            ollamaBaseUrl: 'http://localhost:11434'
-        });
+            ollamaBaseUrl: 'http://localhost:11434/v1'
+        };
+        handler = new OllamaHandler(mockOptions);
+        mockCreate.mockClear();
     });
 
     describe('constructor', () => {
-        it('should initialize with provided config', () => {
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:11434/v1',
-                apiKey: 'ollama'
-            });
+        it('should initialize with provided options', () => {
+            expect(handler).toBeInstanceOf(OllamaHandler);
+            expect(handler.getModel().id).toBe(mockOptions.ollamaModelId);
         });
 
         it('should use default base URL if not provided', () => {
-            const defaultHandler = new OllamaHandler({
+            const handlerWithoutUrl = new OllamaHandler({
+                apiModelId: 'llama2',
                 ollamaModelId: 'llama2'
             });
-
-            expect(OpenAI).toHaveBeenCalledWith({
-                baseURL: 'http://localhost:11434/v1',
-                apiKey: 'ollama'
-            });
+            expect(handlerWithoutUrl).toBeInstanceOf(OllamaHandler);
         });
     });
 
     describe('createMessage', () => {
-        const mockMessages: Anthropic.Messages.MessageParam[] = [
+        const systemPrompt = 'You are a helpful assistant.';
+        const messages: Anthropic.Messages.MessageParam[] = [
             {
                 role: 'user',
-                content: 'Hello'
-            },
-            {
-                role: 'assistant',
-                content: 'Hi there!'
+                content: 'Hello!'
             }
         ];
 
-        const systemPrompt = 'You are a helpful assistant';
-
-        it('should handle streaming responses correctly', async () => {
-            const mockStream = [
-                {
-                    choices: [{
-                        delta: { content: 'Hello' }
-                    }]
-                },
-                {
-                    choices: [{
-                        delta: { content: ' world!' }
-                    }]
-                }
-            ];
-
-            // Setup async iterator for mock stream
-            const asyncIterator = {
-                async *[Symbol.asyncIterator]() {
-                    for (const chunk of mockStream) {
-                        yield chunk;
-                    }
-                }
-            };
-
-            const mockCreate = jest.fn().mockResolvedValue(asyncIterator);
-            (handler['client'].chat.completions as any).create = mockCreate;
-
-            const stream = handler.createMessage(systemPrompt, mockMessages);
-            const chunks = [];
-            
+        it('should handle streaming responses', async () => {
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks: any[] = [];
             for await (const chunk of stream) {
                 chunks.push(chunk);
             }
 
-            expect(chunks.length).toBe(2);
-            expect(chunks[0]).toEqual({
-                type: 'text',
-                text: 'Hello'
-            });
-            expect(chunks[1]).toEqual({
-                type: 'text',
-                text: ' world!'
-            });
-
-            expect(mockCreate).toHaveBeenCalledWith({
-                model: 'llama2',
-                messages: expect.arrayContaining([
-                    {
-                        role: 'system',
-                        content: systemPrompt
-                    }
-                ]),
-                temperature: 0,
-                stream: true
-            });
+            expect(chunks.length).toBeGreaterThan(0);
+            const textChunks = chunks.filter(chunk => chunk.type === 'text');
+            expect(textChunks).toHaveLength(1);
+            expect(textChunks[0].text).toBe('Test response');
         });
 
         it('should handle API errors', async () => {
-            const mockError = new Error('Ollama API error');
-            const mockCreate = jest.fn().mockRejectedValue(mockError);
-            (handler['client'].chat.completions as any).create = mockCreate;
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
 
-            const stream = handler.createMessage(systemPrompt, mockMessages);
+            const stream = handler.createMessage(systemPrompt, messages);
 
             await expect(async () => {
                 for await (const chunk of stream) {
-                    // Should throw before yielding any chunks
+                    // Should not reach here
                 }
-            }).rejects.toThrow('Ollama API error');
+            }).rejects.toThrow('API Error');
+        });
+    });
+
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: mockOptions.ollamaModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0,
+                stream: false
+            });
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Ollama completion error: API Error');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: '' } }]
+            });
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
         });
     });
 
     describe('getModel', () => {
-        it('should return model info with sane defaults', () => {
+        it('should return model info', () => {
             const modelInfo = handler.getModel();
-            expect(modelInfo.id).toBe('llama2');
+            expect(modelInfo.id).toBe(mockOptions.ollamaModelId);
             expect(modelInfo.info).toBeDefined();
             expect(modelInfo.info.maxTokens).toBe(-1);
             expect(modelInfo.info.contextWindow).toBe(128_000);
         });
-
-        it('should return empty string as model ID if not provided', () => {
-            const noModelHandler = new OllamaHandler({});
-            const modelInfo = noModelHandler.getModel();
-            expect(modelInfo.id).toBe('');
-            expect(modelInfo.info).toBeDefined();
-        });
     });
 });

+ 200 - 221
src/api/providers/__tests__/openai-native.test.ts

@@ -1,230 +1,209 @@
-import { OpenAiNativeHandler } from "../openai-native"
-import OpenAI from "openai"
-import { ApiHandlerOptions, openAiNativeDefaultModelId } from "../../../shared/api"
-import { Anthropic } from "@anthropic-ai/sdk"
-
-// Mock OpenAI
-jest.mock("openai")
-
-describe("OpenAiNativeHandler", () => {
-    let handler: OpenAiNativeHandler
-    let mockOptions: ApiHandlerOptions
-    let mockOpenAIClient: jest.Mocked<OpenAI>
-    let mockCreate: jest.Mock
-
-    beforeEach(() => {
-        // Reset mocks
-        jest.clearAllMocks()
-
-        // Setup mock options
-        mockOptions = {
-            openAiNativeApiKey: "test-api-key",
-            apiModelId: "gpt-4o", // Use the correct model ID from shared/api.ts
-        }
-
-        // Setup mock create function
-        mockCreate = jest.fn()
-
-        // Setup mock OpenAI client
-        mockOpenAIClient = {
+import { OpenAiNativeHandler } from '../openai-native';
+import { ApiHandlerOptions } from '../../../shared/api';
+import OpenAI from 'openai';
+import { Anthropic } from '@anthropic-ai/sdk';
+
+// Mock OpenAI client
+const mockCreate = jest.fn();
+jest.mock('openai', () => {
+    return {
+        __esModule: true,
+        default: jest.fn().mockImplementation(() => ({
             chat: {
                 completions: {
-                    create: mockCreate,
-                },
-            },
-        } as unknown as jest.Mocked<OpenAI>
-
-        // Mock OpenAI constructor
-        ;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIClient)
-
-        // Create handler instance
-        handler = new OpenAiNativeHandler(mockOptions)
-    })
-
-    describe("constructor", () => {
-        it("should initialize with provided options", () => {
-            expect(OpenAI).toHaveBeenCalledWith({
-                apiKey: mockOptions.openAiNativeApiKey,
-            })
-        })
-    })
-
-    describe("getModel", () => {
-        it("should return specified model when valid", () => {
-            const result = handler.getModel()
-            expect(result.id).toBe("gpt-4o") // Use the correct model ID
-        })
-
-        it("should return default model when model ID is invalid", () => {
-            handler = new OpenAiNativeHandler({
-                ...mockOptions,
-                apiModelId: "invalid-model" as any,
-            })
-            const result = handler.getModel()
-            expect(result.id).toBe(openAiNativeDefaultModelId)
-        })
-
-        it("should return default model when model ID is not provided", () => {
-            handler = new OpenAiNativeHandler({
-                ...mockOptions,
-                apiModelId: undefined,
-            })
-            const result = handler.getModel()
-            expect(result.id).toBe(openAiNativeDefaultModelId)
-        })
-    })
-
-    describe("createMessage", () => {
-        const systemPrompt = "You are a helpful assistant"
-        const messages: Anthropic.Messages.MessageParam[] = [
-            { role: "user", content: "Hello" },
-        ]
-
-        describe("o1 models", () => {
-            beforeEach(() => {
-                handler = new OpenAiNativeHandler({
-                    ...mockOptions,
-                    apiModelId: "o1-preview",
-                })
-            })
-
-            it("should handle non-streaming response for o1 models", async () => {
-                const mockResponse = {
-                    choices: [{ message: { content: "Hello there!" } }],
-                    usage: {
-                        prompt_tokens: 10,
-                        completion_tokens: 5,
-                    },
-                }
-
-                mockCreate.mockResolvedValueOnce(mockResponse)
-
-                const generator = handler.createMessage(systemPrompt, messages)
-                const results = []
-                for await (const result of generator) {
-                    results.push(result)
-                }
-
-                expect(results).toEqual([
-                    { type: "text", text: "Hello there!" },
-                    { type: "usage", inputTokens: 10, outputTokens: 5 },
-                ])
-
-                expect(mockCreate).toHaveBeenCalledWith({
-                    model: "o1-preview",
-                    messages: [
-                        { role: "user", content: systemPrompt },
-                        { role: "user", content: "Hello" },
-                    ],
-                })
-            })
-
-            it("should handle missing content in response", async () => {
-                const mockResponse = {
-                    choices: [{ message: { content: null } }],
-                    usage: null,
-                }
-
-                mockCreate.mockResolvedValueOnce(mockResponse)
-
-                const generator = handler.createMessage(systemPrompt, messages)
-                const results = []
-                for await (const result of generator) {
-                    results.push(result)
-                }
-
-                expect(results).toEqual([
-                    { type: "text", text: "" },
-                    { type: "usage", inputTokens: 0, outputTokens: 0 },
-                ])
-            })
-        })
-
-        describe("streaming models", () => {
-            beforeEach(() => {
-                handler = new OpenAiNativeHandler({
-                    ...mockOptions,
-                    apiModelId: "gpt-4o",
-                })
-            })
-
-            it("should handle streaming response", async () => {
-                const mockStream = [
-                    { choices: [{ delta: { content: "Hello" } }], usage: null },
-                    { choices: [{ delta: { content: " there" } }], usage: null },
-                    { choices: [{ delta: { content: "!" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
-                ]
-
-                mockCreate.mockResolvedValueOnce(
-                    (async function* () {
-                        for (const chunk of mockStream) {
-                            yield chunk
-                        }
-                    })()
-                )
-
-                const generator = handler.createMessage(systemPrompt, messages)
-                const results = []
-                for await (const result of generator) {
-                    results.push(result)
-                }
-
-                expect(results).toEqual([
-                    { type: "text", text: "Hello" },
-                    { type: "text", text: " there" },
-                    { type: "text", text: "!" },
-                    { type: "usage", inputTokens: 10, outputTokens: 5 },
-                ])
-
-                expect(mockCreate).toHaveBeenCalledWith({
-                    model: "gpt-4o",
-                    temperature: 0,
-                    messages: [
-                        { role: "system", content: systemPrompt },
-                        { role: "user", content: "Hello" },
-                    ],
-                    stream: true,
-                    stream_options: { include_usage: true },
-                })
-            })
-
-            it("should handle empty delta content", async () => {
-                const mockStream = [
-                    { choices: [{ delta: {} }], usage: null },
-                    { choices: [{ delta: { content: null } }], usage: null },
-                    { choices: [{ delta: { content: "Hello" } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
-                ]
-
-                mockCreate.mockResolvedValueOnce(
-                    (async function* () {
-                        for (const chunk of mockStream) {
-                            yield chunk
+                    create: mockCreate.mockImplementation(async (options) => {
+                        if (!options.stream) {
+                            return {
+                                id: 'test-completion',
+                                choices: [{
+                                    message: { role: 'assistant', content: 'Test response' },
+                                    finish_reason: 'stop',
+                                    index: 0
+                                }],
+                                usage: {
+                                    prompt_tokens: 10,
+                                    completion_tokens: 5,
+                                    total_tokens: 15
+                                }
+                            };
                         }
-                    })()
-                )
-
-                const generator = handler.createMessage(systemPrompt, messages)
-                const results = []
-                for await (const result of generator) {
-                    results.push(result)
+                        
+                        return {
+                            [Symbol.asyncIterator]: async function* () {
+                                yield {
+                                    choices: [{
+                                        delta: { content: 'Test response' },
+                                        index: 0
+                                    }],
+                                    usage: null
+                                };
+                                yield {
+                                    choices: [{
+                                        delta: {},
+                                        index: 0
+                                    }],
+                                    usage: {
+                                        prompt_tokens: 10,
+                                        completion_tokens: 5,
+                                        total_tokens: 15
+                                    }
+                                };
+                            }
+                        };
+                    })
                 }
+            }
+        }))
+    };
+});
 
-                expect(results).toEqual([
-                    { type: "text", text: "Hello" },
-                    { type: "usage", inputTokens: 10, outputTokens: 5 },
-                ])
-            })
-        })
+describe('OpenAiNativeHandler', () => {
+    let handler: OpenAiNativeHandler;
+    let mockOptions: ApiHandlerOptions;
 
-        it("should handle API errors", async () => {
-            mockCreate.mockRejectedValueOnce(new Error("API Error"))
+    beforeEach(() => {
+        mockOptions = {
+            apiModelId: 'gpt-4o',
+            openAiNativeApiKey: 'test-api-key'
+        };
+        handler = new OpenAiNativeHandler(mockOptions);
+        mockCreate.mockClear();
+    });
+
+    describe('constructor', () => {
+        it('should initialize with provided options', () => {
+            expect(handler).toBeInstanceOf(OpenAiNativeHandler);
+            expect(handler.getModel().id).toBe(mockOptions.apiModelId);
+        });
+
+        it('should initialize with empty API key', () => {
+            const handlerWithoutKey = new OpenAiNativeHandler({
+                apiModelId: 'gpt-4o',
+                openAiNativeApiKey: ''
+            });
+            expect(handlerWithoutKey).toBeInstanceOf(OpenAiNativeHandler);
+        });
+    });
+
+    describe('createMessage', () => {
+        const systemPrompt = 'You are a helpful assistant.';
+        const messages: Anthropic.Messages.MessageParam[] = [
+            {
+                role: 'user',
+                content: 'Hello!'
+            }
+        ];
+
+        it('should handle streaming responses', async () => {
+            const stream = handler.createMessage(systemPrompt, messages);
+            const chunks: any[] = [];
+            for await (const chunk of stream) {
+                chunks.push(chunk);
+            }
+
+            expect(chunks.length).toBeGreaterThan(0);
+            const textChunks = chunks.filter(chunk => chunk.type === 'text');
+            expect(textChunks).toHaveLength(1);
+            expect(textChunks[0].text).toBe('Test response');
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+
+            const stream = handler.createMessage(systemPrompt, messages);
 
-            const generator = handler.createMessage(systemPrompt, messages)
             await expect(async () => {
-                for await (const _ of generator) {
-                    // consume generator
+                for await (const chunk of stream) {
+                    // Should not reach here
                 }
-            }).rejects.toThrow("API Error")
-        })
-    })
-})
+            }).rejects.toThrow('API Error');
+        });
+    });
+
+    describe('completePrompt', () => {
+        it('should complete prompt successfully with gpt-4o model', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'gpt-4o',
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0
+            });
+        });
+
+        it('should complete prompt successfully with o1 model', async () => {
+            handler = new OpenAiNativeHandler({
+                apiModelId: 'o1',
+                openAiNativeApiKey: 'test-api-key'
+            });
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'o1',
+                messages: [{ role: 'user', content: 'Test prompt' }]
+            });
+        });
+
+        it('should complete prompt successfully with o1-preview model', async () => {
+            handler = new OpenAiNativeHandler({
+                apiModelId: 'o1-preview',
+                openAiNativeApiKey: 'test-api-key'
+            });
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'o1-preview',
+                messages: [{ role: 'user', content: 'Test prompt' }]
+            });
+        });
+
+        it('should complete prompt successfully with o1-mini model', async () => {
+            handler = new OpenAiNativeHandler({
+                apiModelId: 'o1-mini',
+                openAiNativeApiKey: 'test-api-key'
+            });
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'o1-mini',
+                messages: [{ role: 'user', content: 'Test prompt' }]
+            });
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('OpenAI Native completion error: API Error');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: '' } }]
+            });
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
+    describe('getModel', () => {
+        it('should return model info', () => {
+            const modelInfo = handler.getModel();
+            expect(modelInfo.id).toBe(mockOptions.apiModelId);
+            expect(modelInfo.info).toBeDefined();
+            expect(modelInfo.info.maxTokens).toBe(4096);
+            expect(modelInfo.info.contextWindow).toBe(128_000);
+        });
+
+        it('should handle undefined model ID', () => {
+            const handlerWithoutModel = new OpenAiNativeHandler({
+                openAiNativeApiKey: 'test-api-key'
+            });
+            const modelInfo = handlerWithoutModel.getModel();
+            expect(modelInfo.id).toBe('gpt-4o'); // Default model
+            expect(modelInfo.info).toBeDefined();
+        });
+    });
+});
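
The openai-native tests above expect temperature: 0 for gpt-4o but no sampling parameters at all for the o1 family, which rejects them. A minimal sketch of that branching follows (not part of this commit's diff); the startsWith('o1') check is an assumption, only the asserted request shapes come from the tests:

import OpenAI from 'openai';

// Sketch only: request-shape branching the openai-native completePrompt tests assert on.
function buildNativeCompletionParams(
    modelId: string,
    prompt: string
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
    const base = {
        model: modelId,
        messages: [{ role: 'user' as const, content: prompt }]
    };
    // o1-family requests carry only model + messages; others add temperature: 0.
    return modelId.startsWith('o1') ? base : { ...base, temperature: 0 };
}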

+ 26 - 0
src/api/providers/__tests__/openai.test.ts

@@ -176,6 +176,32 @@ describe('OpenAiHandler', () => {
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: mockOptions.openAiModelId,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                temperature: 0
+            });
+        });
+
+        it('should handle API errors', async () => {
+            mockCreate.mockRejectedValueOnce(new Error('API Error'));
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('OpenAI completion error: API Error');
+        });
+
+        it('should handle empty response', async () => {
+            mockCreate.mockImplementationOnce(() => ({
+                choices: [{ message: { content: '' } }]
+            }));
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
     describe('getModel', () => {
         it('should return model info with sane defaults', () => {
             const model = handler.getModel();

+ 79 - 1
src/api/providers/__tests__/vertex.test.ts

@@ -6,7 +6,42 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
 jest.mock('@anthropic-ai/vertex-sdk', () => ({
     AnthropicVertex: jest.fn().mockImplementation(() => ({
         messages: {
-            create: jest.fn()
+            create: jest.fn().mockImplementation(async (options) => {
+                if (!options.stream) {
+                    return {
+                        id: 'test-completion',
+                        content: [
+                            { type: 'text', text: 'Test response' }
+                        ],
+                        role: 'assistant',
+                        model: options.model,
+                        usage: {
+                            input_tokens: 10,
+                            output_tokens: 5
+                        }
+                    }
+                }
+                return {
+                    async *[Symbol.asyncIterator]() {
+                        yield {
+                            type: 'message_start',
+                            message: {
+                                usage: {
+                                    input_tokens: 10,
+                                    output_tokens: 5
+                                }
+                            }
+                        }
+                        yield {
+                            type: 'content_block_start',
+                            content_block: {
+                                type: 'text',
+                                text: 'Test response'
+                            }
+                        }
+                    }
+                }
+            })
         }
     }))
 }));
@@ -196,6 +231,49 @@ describe('VertexHandler', () => {
         });
     });
 
+    describe('completePrompt', () => {
+        it('should complete prompt successfully', async () => {
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('Test response');
+            expect(handler['client'].messages.create).toHaveBeenCalledWith({
+                model: 'claude-3-5-sonnet-v2@20241022',
+                max_tokens: 8192,
+                temperature: 0,
+                messages: [{ role: 'user', content: 'Test prompt' }],
+                stream: false
+            });
+        });
+
+        it('should handle API errors', async () => {
+            const mockError = new Error('Vertex API error');
+            const mockCreate = jest.fn().mockRejectedValue(mockError);
+            (handler['client'].messages as any).create = mockCreate;
+
+            await expect(handler.completePrompt('Test prompt'))
+                .rejects.toThrow('Vertex completion error: Vertex API error');
+        });
+
+        it('should handle non-text content', async () => {
+            const mockCreate = jest.fn().mockResolvedValue({
+                content: [{ type: 'image' }]
+            });
+            (handler['client'].messages as any).create = mockCreate;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+
+        it('should handle empty response', async () => {
+            const mockCreate = jest.fn().mockResolvedValue({
+                content: [{ type: 'text', text: '' }]
+            });
+            (handler['client'].messages as any).create = mockCreate;
+
+            const result = await handler.completePrompt('Test prompt');
+            expect(result).toBe('');
+        });
+    });
+
     describe('getModel', () => {
         it('should return correct model info', () => {
             const modelInfo = handler.getModel();

+ 25 - 2
src/api/providers/anthropic.ts

@@ -7,10 +7,10 @@ import {
 	ApiHandlerOptions,
 	ModelInfo,
 } from "../../shared/api"
-import { ApiHandler } from "../index"
+import { ApiHandler, SingleCompletionHandler } from "../index"
 import { ApiStream } from "../transform/stream"
 
-export class AnthropicHandler implements ApiHandler {
+export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: Anthropic
 
@@ -173,4 +173,27 @@ export class AnthropicHandler implements ApiHandler {
 		}
 		return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const response = await this.client.messages.create({
+				model: this.getModel().id,
+				max_tokens: this.getModel().info.maxTokens || 8192,
+				temperature: 0,
+				messages: [{ role: "user", content: prompt }],
+				stream: false
+			})
+
+			const content = response.content[0]
+			if (content.type === 'text') {
+				return content.text
+			}
+			return ''
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Anthropic completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }
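
The SingleCompletionHandler interface imported above is declared elsewhere in src/api and is not shown in this section of the diff. Judging from the implementations in this commit, it presumably only requires the non-streaming completePrompt method; a minimal sketch under that assumption:

// Illustrative sketch only; the real declaration is not part of this section of the diff.
export interface SingleCompletionHandler {
	// Resolves with the full completion text for a single prompt, with no streaming involved.
	completePrompt(prompt: string): Promise<string>
}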

+ 66 - 7
src/api/providers/bedrock.ts

@@ -1,6 +1,6 @@
-import { BedrockRuntimeClient, ConverseStreamCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
+import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig } from "@aws-sdk/client-bedrock-runtime"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
@@ -38,7 +38,7 @@ export interface StreamEvent {
     };
 }
 
-export class AwsBedrockHandler implements ApiHandler {
+export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
     private options: ApiHandlerOptions
     private client: BedrockRuntimeClient
 
@@ -199,7 +199,7 @@ export class AwsBedrockHandler implements ApiHandler {
         if (modelId) {
             // For tests, allow any model ID
             if (process.env.NODE_ENV === 'test') {
-                return { 
+                return {
                     id: modelId,
                     info: {
                         maxTokens: 5000,
@@ -214,9 +214,68 @@ export class AwsBedrockHandler implements ApiHandler {
                 return { id, info: bedrockModels[id] }
             }
         }
-        return { 
-            id: bedrockDefaultModelId, 
-            info: bedrockModels[bedrockDefaultModelId] 
+        return {
+            id: bedrockDefaultModelId,
+            info: bedrockModels[bedrockDefaultModelId]
+        }
+    }
+
+    async completePrompt(prompt: string): Promise<string> {
+        try {
+            const modelConfig = this.getModel()
+            
+            // Handle cross-region inference
+            let modelId: string
+            if (this.options.awsUseCrossRegionInference) {
+                const regionPrefix = (this.options.awsRegion || "").slice(0, 3)
+                switch (regionPrefix) {
+                    case "us-":
+                        modelId = `us.${modelConfig.id}`
+                        break
+                    case "eu-":
+                        modelId = `eu.${modelConfig.id}`
+                        break
+                    default:
+                        modelId = modelConfig.id
+                        break
+                }
+            } else {
+                modelId = modelConfig.id
+            }
+
+            const payload = {
+                modelId,
+                messages: convertToBedrockConverseMessages([{
+                    role: "user",
+                    content: prompt
+                }]),
+                inferenceConfig: {
+                    maxTokens: modelConfig.info.maxTokens || 5000,
+                    temperature: 0.3,
+                    topP: 0.1
+                }
+            }
+
+            const command = new ConverseCommand(payload)
+            const response = await this.client.send(command)
+
+            // Defensive parsing: only handle the case where the Converse output arrives as raw bytes
+            if (response.output && response.output instanceof Uint8Array) {
+                try {
+                    const outputStr = new TextDecoder().decode(response.output)
+                    const output = JSON.parse(outputStr)
+                    if (output.content) {
+                        return output.content
+                    }
+                } catch (parseError) {
+                    console.error('Failed to parse Bedrock response:', parseError)
+                }
+            }
+            return ''
+        } catch (error) {
+            if (error instanceof Error) {
+                throw new Error(`Bedrock completion error: ${error.message}`)
+            }
+            throw error
         }
     }
 }
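
Note that the Uint8Array branch in completePrompt above is defensive; in the AWS SDK v3, a ConverseCommand response normally exposes the reply as structured content rather than raw bytes. A hedged sketch of reading that documented shape (assuming output.message.content is populated; the helper name is illustrative):

// Sketch only, not part of the commit; assumes the documented ConverseCommandOutput shape.
import { BedrockRuntimeClient, ConverseCommand } from "@aws-sdk/client-bedrock-runtime"

async function readConverseText(client: BedrockRuntimeClient, command: ConverseCommand): Promise<string> {
	const response = await client.send(command)
	// The documented ConverseOutput keeps the assistant reply under output.message.content
	const blocks = response.output?.message?.content ?? []
	return blocks.map((block) => block.text ?? "").join("")
}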

+ 24 - 2
src/api/providers/gemini.ts

@@ -1,11 +1,11 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { GoogleGenerativeAI } from "@google/generative-ai"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
 import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
 import { ApiStream } from "../transform/stream"
 
-export class GeminiHandler implements ApiHandler {
+export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: GoogleGenerativeAI
 
@@ -53,4 +53,26 @@ export class GeminiHandler implements ApiHandler {
 		}
 		return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const model = this.client.getGenerativeModel({
+				model: this.getModel().id,
+			})
+
+			const result = await model.generateContent({
+				contents: [{ role: "user", parts: [{ text: prompt }] }],
+				generationConfig: {
+					temperature: 0,
+				},
+			})
+
+			return result.response.text()
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Gemini completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }
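
Test coverage for GeminiHandler.completePrompt is not shown in this section. A minimal Jest sketch mirroring the other providers' tests could look like the following; the geminiApiKey option name and the mock shape are assumptions inferred from the handler code above:

// Illustrative test sketch, not part of the commit.
import { GeminiHandler } from "../gemini"

jest.mock("@google/generative-ai", () => ({
	GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
		getGenerativeModel: jest.fn().mockReturnValue({
			// generateContent resolves with an object whose response.text() yields the completion
			generateContent: jest.fn().mockResolvedValue({ response: { text: () => "Test response" } }),
		}),
	})),
}))

describe("GeminiHandler completePrompt", () => {
	it("returns the generated text", async () => {
		const handler = new GeminiHandler({ geminiApiKey: "test-api-key" })
		await expect(handler.completePrompt("Test prompt")).resolves.toBe("Test response")
	})
})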

+ 24 - 2
src/api/providers/glama.ts

@@ -1,13 +1,13 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import axios from "axios"
 import OpenAI from "openai"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import delay from "delay"
 
-export class GlamaHandler implements ApiHandler {
+export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -129,4 +129,26 @@ export class GlamaHandler implements ApiHandler {
 		
 		return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: this.getModel().id,
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+			}
+
+			if (this.getModel().id.startsWith("anthropic/")) {
+				requestOptions.max_tokens = 8192
+			}
+			
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Glama completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }

+ 18 - 2
src/api/providers/lmstudio.ts

@@ -1,11 +1,11 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
-export class LmStudioHandler implements ApiHandler {
+export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -53,4 +53,20 @@ export class LmStudioHandler implements ApiHandler {
 			info: openAiModelInfoSaneDefaults,
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const response = await this.client.chat.completions.create({
+				model: this.getModel().id,
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+				stream: false
+			})
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			throw new Error(
+				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts.",
+			)
+		}
+	}
 }

+ 19 - 2
src/api/providers/ollama.ts

@@ -1,11 +1,11 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
-export class OllamaHandler implements ApiHandler {
+export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -46,4 +46,21 @@ export class OllamaHandler implements ApiHandler {
 			info: openAiModelInfoSaneDefaults,
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const response = await this.client.chat.completions.create({
+				model: this.getModel().id,
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+				stream: false
+			})
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Ollama completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }

+ 35 - 2
src/api/providers/openai-native.ts

@@ -1,6 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import {
 	ApiHandlerOptions,
 	ModelInfo,
@@ -11,7 +11,7 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
-export class OpenAiNativeHandler implements ApiHandler {
+export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -83,4 +83,37 @@ export class OpenAiNativeHandler implements ApiHandler {
 		}
 		return { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const modelId = this.getModel().id
+			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
+
+			switch (modelId) {
+				case "o1":
+				case "o1-preview":
+				case "o1-mini":
+					// o1 models only support the default temperature (1) and no system prompt
+					requestOptions = {
+						model: modelId,
+						messages: [{ role: "user", content: prompt }]
+					}
+					break
+				default:
+					requestOptions = {
+						model: modelId,
+						messages: [{ role: "user", content: prompt }],
+						temperature: 0
+					}
+			}
+
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`OpenAI Native completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }

+ 20 - 2
src/api/providers/openai.ts

@@ -6,11 +6,11 @@ import {
 	ModelInfo,
 	openAiModelInfoSaneDefaults,
 } from "../../shared/api"
-import { ApiHandler } from "../index"
+import { ApiHandler, SingleCompletionHandler } from "../index"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
-export class OpenAiHandler implements ApiHandler {
+export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -100,4 +100,22 @@ export class OpenAiHandler implements ApiHandler {
 			info: openAiModelInfoSaneDefaults,
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: this.getModel().id,
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+			}
+			
+			const response = await this.client.chat.completions.create(requestOptions)
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`OpenAI completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }

+ 25 - 2
src/api/providers/vertex.ts

@@ -1,11 +1,11 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
-import { ApiHandler } from "../"
+import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
-export class VertexHandler implements ApiHandler {
+export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: AnthropicVertex
 
@@ -83,4 +83,27 @@ export class VertexHandler implements ApiHandler {
 		}
 		return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const response = await this.client.messages.create({
+				model: this.getModel().id,
+				max_tokens: this.getModel().info.maxTokens || 8192,
+				temperature: 0,
+				messages: [{ role: "user", content: prompt }],
+				stream: false
+			})
+
+			const content = response.content[0]
+			if (content.type === 'text') {
+				return content.text
+			}
+			return ''
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Vertex completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }
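
With every provider implementing both interfaces, callers can feature-detect the new non-streaming path before falling back to streaming. A usage sketch (the function names are illustrative and not part of the commit):

// Usage sketch only; feature-detects completePrompt on any ApiHandler.
import { ApiHandler, SingleCompletionHandler } from "../index"

function supportsSingleCompletion(handler: ApiHandler): handler is ApiHandler & SingleCompletionHandler {
	return typeof (handler as Partial<SingleCompletionHandler>).completePrompt === "function"
}

async function completeOnce(handler: ApiHandler, prompt: string): Promise<string> {
	if (supportsSingleCompletion(handler)) {
		// Non-streaming path added by this commit: one request, one plain-text result
		return handler.completePrompt(prompt)
	}
	throw new Error("The selected provider does not implement single-prompt completions")
}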