
Merge pull request #320 from RooVetGit/o1_developer_role

Update openai package and use developer role message for o1
Matt Rubens · 11 months ago · commit c3c2fd9061
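
In short: for the o1 model, the system prompt is now sent as a "developer" message instead of being folded into a "user" message; o1-preview and o1-mini keep the old user-role fallback. A minimal sketch of the resulting request shape, assuming an openai v4 client configured like the one the handler wraps (the API key handling here is illustrative):

import OpenAI from "openai"

async function main() {
	const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

	// o1 accepts a "developer" message for system-style instructions; it still
	// rejects streaming and non-default temperature, so neither option is set.
	const response = await client.chat.completions.create({
		model: "o1",
		messages: [
			{ role: "developer", content: "You are a helpful assistant." },
			{ role: "user", content: "Hello!" },
		],
	})

	console.log(response.choices[0]?.message?.content ?? "")
}

main()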

+ 4 - 4
package-lock.json

@@ -31,7 +31,7 @@
         "isbinaryfile": "^5.0.2",
         "mammoth": "^1.8.0",
         "monaco-vscode-textmate-theme-converter": "^0.1.7",
-        "openai": "^4.73.1",
+        "openai": "^4.78.1",
         "os-name": "^6.0.0",
         "p-wait-for": "^5.0.2",
         "pdf-parse": "^1.1.1",
@@ -12546,9 +12546,9 @@
       }
     },
     "node_modules/openai": {
-      "version": "4.76.0",
-      "resolved": "https://registry.npmjs.org/openai/-/openai-4.76.0.tgz",
-      "integrity": "sha512-QBGIetjX1C9xDp5XGa/3mPnfKI9BgAe2xHQX6PmO98wuW9qQaurBaumcYptQWc9LHZZq7cH/Y1Rjnsr6uUDdVw==",
+      "version": "4.78.1",
+      "resolved": "https://registry.npmjs.org/openai/-/openai-4.78.1.tgz",
+      "integrity": "sha512-drt0lHZBd2lMyORckOXFPQTmnGLWSLt8VK0W9BhOKWpMFBEoHMoz5gxMPmVq5icp+sOrsbMnsmZTVHUlKvD1Ow==",
       "dependencies": {
         "@types/node": "^18.11.18",
         "@types/node-fetch": "^2.6.4",

+ 1 - 1
package.json

@@ -227,7 +227,7 @@
     "isbinaryfile": "^5.0.2",
     "mammoth": "^1.8.0",
     "monaco-vscode-textmate-theme-converter": "^0.1.7",
-    "openai": "^4.73.1",
+    "openai": "^4.78.1",
     "os-name": "^6.0.0",
     "p-wait-for": "^5.0.2",
     "pdf-parse": "^1.1.1",

+ 121 - 11
src/api/providers/__tests__/openai-native.test.ts

@@ -60,6 +60,13 @@ jest.mock('openai', () => {
 describe('OpenAiNativeHandler', () => {
     let handler: OpenAiNativeHandler;
     let mockOptions: ApiHandlerOptions;
+    const systemPrompt = 'You are a helpful assistant.';
+    const messages: Anthropic.Messages.MessageParam[] = [
+        {
+            role: 'user',
+            content: 'Hello!'
+        }
+    ];
 
     beforeEach(() => {
         mockOptions = {
@@ -86,14 +93,6 @@ describe('OpenAiNativeHandler', () => {
     });
 
     describe('createMessage', () => {
-        const systemPrompt = 'You are a helpful assistant.';
-        const messages: Anthropic.Messages.MessageParam[] = [
-            {
-                role: 'user',
-                content: 'Hello!'
-            }
-        ];
-
         it('should handle streaming responses', async () => {
             const stream = handler.createMessage(systemPrompt, messages);
             const chunks: any[] = [];
@@ -109,15 +108,126 @@ describe('OpenAiNativeHandler', () => {
 
         it('should handle API errors', async () => {
             mockCreate.mockRejectedValueOnce(new Error('API Error'));
-
             const stream = handler.createMessage(systemPrompt, messages);
-
             await expect(async () => {
                 for await (const chunk of stream) {
                     // Should not reach here
                 }
             }).rejects.toThrow('API Error');
         });
+
+        it('should handle missing content in response for o1 model', async () => {
+            // Use o1 model which supports developer role
+            handler = new OpenAiNativeHandler({
+                ...mockOptions,
+                apiModelId: 'o1'
+            });
+
+            mockCreate.mockResolvedValueOnce({
+                choices: [{ message: { content: null } }],
+                usage: {
+                    prompt_tokens: 0,
+                    completion_tokens: 0,
+                    total_tokens: 0
+                }
+            });
+
+            const generator = handler.createMessage(systemPrompt, messages);
+            const results = [];
+            for await (const result of generator) {
+                results.push(result);
+            }
+
+            expect(results).toEqual([
+                { type: 'text', text: '' },
+                { type: 'usage', inputTokens: 0, outputTokens: 0 }
+            ]);
+
+            // Verify developer role is used for system prompt with o1 model
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'o1',
+                messages: [
+                    { role: 'developer', content: systemPrompt },
+                    { role: 'user', content: 'Hello!' }
+                ]
+            });
+        });
+    });
+
+    describe('streaming models', () => {
+        beforeEach(() => {
+            handler = new OpenAiNativeHandler({
+                ...mockOptions,
+                apiModelId: 'gpt-4o',
+            });
+        });
+
+        it('should handle streaming response', async () => {
+            const mockStream = [
+                { choices: [{ delta: { content: 'Hello' } }], usage: null },
+                { choices: [{ delta: { content: ' there' } }], usage: null },
+                { choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
+            ];
+
+            mockCreate.mockResolvedValueOnce(
+                (async function* () {
+                    for (const chunk of mockStream) {
+                        yield chunk;
+                    }
+                })()
+            );
+
+            const generator = handler.createMessage(systemPrompt, messages);
+            const results = [];
+            for await (const result of generator) {
+                results.push(result);
+            }
+
+            expect(results).toEqual([
+                { type: 'text', text: 'Hello' },
+                { type: 'text', text: ' there' },
+                { type: 'text', text: '!' },
+                { type: 'usage', inputTokens: 10, outputTokens: 5 },
+            ]);
+
+            expect(mockCreate).toHaveBeenCalledWith({
+                model: 'gpt-4o',
+                temperature: 0,
+                messages: [
+                    { role: 'system', content: systemPrompt },
+                    { role: 'user', content: 'Hello!' },
+                ],
+                stream: true,
+                stream_options: { include_usage: true },
+            });
+        });
+
+        it('should handle empty delta content', async () => {
+            const mockStream = [
+                { choices: [{ delta: {} }], usage: null },
+                { choices: [{ delta: { content: null } }], usage: null },
+                { choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
+            ];
+
+            mockCreate.mockResolvedValueOnce(
+                (async function* () {
+                    for (const chunk of mockStream) {
+                        yield chunk;
+                    }
+                })()
+            );
+
+            const generator = handler.createMessage(systemPrompt, messages);
+            const results = [];
+            for await (const result of generator) {
+                results.push(result);
+            }
+
+            expect(results).toEqual([
+                { type: 'text', text: 'Hello' },
+                { type: 'usage', inputTokens: 10, outputTokens: 5 },
+            ]);
+        });
     });
 
     describe('completePrompt', () => {
@@ -206,4 +316,4 @@ describe('OpenAiNativeHandler', () => {
             expect(modelInfo.info).toBeDefined();
         });
     });
-});
+});
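
Aside on the streaming tests above: they mock the SDK's async-iterable response with an inline async generator. The same pattern, pulled out as a standalone sketch (the chunk shape mirrors the deltas used in the tests; mockCreate is the jest mock from the test file):

// Chunk shape mirroring the OpenAI SDK streaming deltas used in the tests.
type MockChunk = {
	choices: { delta: { content?: string | null } }[]
	usage: { prompt_tokens: number; completion_tokens: number } | null
}

// An async generator stands in for the SDK's streaming response, so the
// handler's `for await` loop consumes it exactly like a real stream.
async function* mockStream(chunks: MockChunk[]) {
	for (const chunk of chunks) {
		yield chunk
	}
}

// Usage inside a test:
// mockCreate.mockResolvedValueOnce(mockStream([{ choices: [{ delta: { content: "Hi" } }], usage: null }]))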

+ 7 - 5
src/api/providers/openai-native.ts

@@ -23,14 +23,16 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		switch (this.getModel().id) {
+		const modelId = this.getModel().id
+		switch (modelId) {
 			case "o1":
 			case "o1-preview":
 			case "o1-mini": {
-				// o1 doesnt support streaming, non-1 temp, or system prompt
+				// o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
+				// o1 doesn't support streaming or non-1 temp, but it does support a developer prompt
 				const response = await this.client.chat.completions.create({
-					model: this.getModel().id,
-					messages: [{ role: "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+					model: modelId,
+					messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 				})
 				yield {
 					type: "text",
@@ -93,7 +95,7 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 				case "o1":
 				case "o1-preview":
 				case "o1-mini":
-					// o1 doesn't support non-1 temp or system prompt
+					// o1 doesn't support non-1 temp
 					requestOptions = {
 						model: modelId,
 						messages: [{ role: "user", content: prompt }]
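
For completeness, a hedged consumer-side sketch of the non-streaming o1 path, pieced together from the test expectations above (import path follows the file layout shown; credential fields in ApiHandlerOptions are omitted):

// Sketch only, not the repo's actual entry point.
import { OpenAiNativeHandler } from "./openai-native"

async function run() {
	const handler = new OpenAiNativeHandler({ apiModelId: "o1" })

	// For o1 the handler makes one non-streaming request, then yields a single
	// text chunk followed by a usage chunk (see the o1 test above).
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello!" },
	])) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			console.log(`\nin: ${chunk.inputTokens}, out: ${chunk.outputTokens}`)
		}
	}
}

run()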