@@ -11,9 +11,16 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { calculateApiCostOpenAI } from "../../utils/cost"
 
 const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0
 
+// Define a type for the model object returned by getModel
+export type OpenAiNativeModel = {
+	id: OpenAiNativeModelId
+	info: ModelInfo
+}
+
 export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelId = this.getModel().id
+		const model = this.getModel()
 
-		if (modelId.startsWith("o1")) {
-			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		if (modelId.startsWith("o3-mini")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+		yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
 	}
 
 	private async *handleO1FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		// o1 supports developer prompt with formatting
 		// o1-preview and o1-mini only support user messages
-		const isOriginalO1 = modelId === "o1"
+		const isOriginalO1 = model.id === "o1"
 		const response = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			messages: [
 				{
 					role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(response)
+		yield* this.handleStreamResponse(response, model)
 	}
 
 	private async *handleO3FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			reasoning_effort: this.getModel().info.reasoningEffort,
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *handleDefaultModelMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		const stream = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		}
 	}
 
-	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
+		model: OpenAiNativeModel,
+	): ApiStream {
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			}
 
 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				yield* this.yieldUsage(model.info, chunk.usage)
 			}
 		}
 	}
 
-	override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
+	private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
+		const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+		const cacheWriteTokens = 0
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens: outputTokens,
+			cacheWriteTokens: cacheWriteTokens,
+			cacheReadTokens: cacheReadTokens,
+			totalCost: totalCost,
+		}
+	}
+
+	override getModel(): OpenAiNativeModel {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in openAiNativeModels) {
 			const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const modelId = this.getModel().id
+			const model = this.getModel()
 			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-			if (modelId.startsWith("o1")) {
-				requestOptions = this.getO1CompletionOptions(modelId, prompt)
-			} else if (modelId.startsWith("o3-mini")) {
-				requestOptions = this.getO3CompletionOptions(modelId, prompt)
+			if (model.id.startsWith("o1")) {
+				requestOptions = this.getO1CompletionOptions(model, prompt)
+			} else if (model.id.startsWith("o3-mini")) {
+				requestOptions = this.getO3CompletionOptions(model, prompt)
 			} else {
-				requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
+				requestOptions = this.getDefaultCompletionOptions(model, prompt)
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getO1CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 		}
 	}
 
 	private getO3CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getDefaultCompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 		}