View Source

Refactor openai-native, prepend string to developer instructions so that o1/o3 will use md

Nissa Seru 11 months ago
Parent
Commit
df932935cf
2 changed files with 163 additions and 63 deletions
  1. src/api/providers/__tests__/openai-native.test.ts (+25 -1)
  2. src/api/providers/openai-native.ts (+138 -62)
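The gist of the change: OpenAI's reasoning models suppress markdown in their output unless the developer message opts back in, which per OpenAI's guidance is done by placing "Formatting re-enabled" on the first line of that message. A minimal sketch of the message construction this commit introduces (buildSystemMessage is a hypothetical helper; the real code inlines this logic in the handlers below):

// Hypothetical helper distilling the prompt construction added in this commit.
// o1 and o3-mini accept a "developer" message and get the formatting prefix;
// o1-preview and o1-mini only accept "user" messages, so the prompt passes through unchanged.
type SystemMessage = { role: "developer" | "user"; content: string }

function buildSystemMessage(modelId: string, systemPrompt: string): SystemMessage {
	const supportsDeveloperRole = modelId === "o1" || modelId.startsWith("o3-mini")
	return supportsDeveloperRole
		? { role: "developer", content: `Formatting re-enabled\n${systemPrompt}` }
		: { role: "user", content: systemPrompt }
}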

+ 25 - 1
src/api/providers/__tests__/openai-native.test.ts

@@ -153,11 +153,35 @@ describe("OpenAiNativeHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: "o1",
 				messages: [
-					{ role: "developer", content: systemPrompt },
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
 					{ role: "user", content: "Hello!" },
 					{ role: "user", content: "Hello!" },
 				],
 				],
 			})
 			})
 		})
 		})
+
+		it("should handle o3-mini model family correctly", async () => {
+			handler = new OpenAiNativeHandler({
+				...mockOptions,
+				apiModelId: "o3-mini",
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "o3-mini",
+				messages: [
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
+					{ role: "user", content: "Hello!" },
+				],
+				stream: true,
+				stream_options: { include_usage: true },
+				reasoning_effort: "medium",
+			})
+		})
 	})
 
 	describe("streaming models", () => {

+ 138 - 62
src/api/providers/openai-native.ts

@@ -24,57 +24,111 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const modelId = this.getModel().id
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini": {
-				// o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
-				// o1 doesnt support streaming or non-1 temp but does support a developer prompt
-				const response = await this.client.chat.completions.create({
-					model: modelId,
-					messages: [
-						{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
-						...convertToOpenAiMessages(messages),
-					],
-				})
+
+		if (modelId.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		if (modelId.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+	}
+
+	private async *handleO1FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		// o1 supports developer prompt with formatting
+		// o1-preview and o1-mini only support user messages
+		const isOriginalO1 = modelId === "o1"
+		const response = await this.client.chat.completions.create({
+			model: modelId,
+			messages: [
+				{
+					role: isOriginalO1 ? "developer" : "user",
+					content: isOriginalO1 ? `Formatting re-enabled\n${systemPrompt}` : systemPrompt,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+		})
+
+		yield* this.yieldResponseData(response)
+	}
+
+	private async *handleO3FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: "o3-mini",
+			messages: [
+				{
+					role: "developer",
+					content: `Formatting re-enabled\n${systemPrompt}`,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *handleDefaultModelMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: modelId,
+			temperature: 0,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *yieldResponseData(
+		response: OpenAI.Chat.Completions.ChatCompletion
+	): ApiStream {
+		yield {
+			type: "text",
+			text: response.choices[0]?.message.content || "",
+		}
+		yield {
+			type: "usage",
+			inputTokens: response.usage?.prompt_tokens || 0,
+			outputTokens: response.usage?.completion_tokens || 0,
+		}
+	}
+
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
+	): ApiStream {
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
 				yield {
 					type: "text",
-					text: response.choices[0]?.message.content || "",
+					text: delta.content,
 				}
+			}
+
+			if (chunk.usage) {
 				yield {
 					type: "usage",
-					inputTokens: response.usage?.prompt_tokens || 0,
-					outputTokens: response.usage?.completion_tokens || 0,
-				}
-				break
-			}
-			default: {
-				const stream = await this.client.chat.completions.create({
-					model: this.getModel().id,
-					// max_completion_tokens: this.getModel().info.maxTokens,
-					temperature: 0,
-					messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-					stream: true,
-					stream_options: { include_usage: true },
-				})
-
-				for await (const chunk of stream) {
-					const delta = chunk.choices[0]?.delta
-					if (delta?.content) {
-						yield {
-							type: "text",
-							text: delta.content,
-						}
-					}
-
-					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-					if (chunk.usage) {
-						yield {
-							type: "usage",
-							inputTokens: chunk.usage.prompt_tokens || 0,
-							outputTokens: chunk.usage.completion_tokens || 0,
-						}
-					}
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
 				}
 			}
 		}
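handleStreamResponse now centralizes the chunk loop shared by the o3-mini and default paths: text deltas are re-emitted as they arrive, and the usage field, which is null on every chunk except the last when stream_options.include_usage is set, becomes one trailing usage event. A rough sketch of consuming the resulting ApiStream, assuming it is an async generator of the tagged chunks yielded above:

// Hypothetical consumer of the repo's ApiStream type.
async function collect(stream: ApiStream) {
	let text = ""
	let inputTokens = 0
	let outputTokens = 0
	for await (const chunk of stream) {
		if (chunk.type === "text") {
			text += chunk.text // accumulate streamed deltas in order
		} else if (chunk.type === "usage") {
			inputTokens = chunk.inputTokens // arrives once, at the end of the stream
			outputTokens = chunk.outputTokens
		}
	}
	return { text, inputTokens, outputTokens }
}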
@@ -94,22 +148,12 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 			const modelId = this.getModel().id
 			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-			switch (modelId) {
-				case "o1":
-				case "o1-preview":
-				case "o1-mini":
-					// o1 doesn't support non-1 temp
-					requestOptions = {
-						model: modelId,
-						messages: [{ role: "user", content: prompt }],
-					}
-					break
-				default:
-					requestOptions = {
-						model: modelId,
-						messages: [{ role: "user", content: prompt }],
-						temperature: 0,
-					}
+			if (modelId.startsWith("o1")) {
+				requestOptions = this.getO1CompletionOptions(modelId, prompt)
+			} else if (modelId.startsWith("o3-mini")) {
+				requestOptions = this.getO3CompletionOptions(modelId, prompt)
+			} else {
+				requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
@@ -121,4 +165,36 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 			throw error
 		}
 	}
+
+	private getO1CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+	}
+
+	private getO3CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: "o3-mini",
+			messages: [{ role: "user", content: prompt }],
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		}
+	}
+
+	private getDefaultCompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+			temperature: 0,
+		}
+	}
 }
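End to end, the public surface is unchanged: callers still construct the handler and either stream via createMessage or call completePrompt, and the prefix checks route each model to the right family. A hedged usage sketch (the API-key option name and options shape are assumptions based on the tests' mockOptions):

// Assumed construction; the real options type is defined elsewhere in the repo.
const handler = new OpenAiNativeHandler({
	openAiNativeApiKey: "sk-...", // hypothetical option name
	apiModelId: "o3-mini",
})

// Streams through handleO3FamilyMessage via the startsWith("o3-mini") check.
for await (const chunk of handler.createMessage("You are terse.", [{ role: "user", content: "Hello!" }])) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
}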