- import { Anthropic } from "@anthropic-ai/sdk"
- import OpenAI, { AzureOpenAI } from "openai"
- import axios from "axios"
- import {
- type ModelInfo,
- azureOpenAiDefaultApiVersion,
- openAiModelInfoSaneDefaults,
- DEEP_SEEK_DEFAULT_TEMPERATURE,
- OPENAI_AZURE_AI_INFERENCE_PATH,
- } from "@roo-code/types"
- import type { ApiHandlerOptions } from "../../shared/api"
- import { XmlMatcher } from "../../utils/xml-matcher"
- import { convertToOpenAiMessages } from "../transform/openai-format"
- import { convertToR1Format } from "../transform/r1-format"
- import { convertToSimpleMessages } from "../transform/simple-format"
- import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
- import { getModelParams } from "../transform/model-params"
- import { DEFAULT_HEADERS } from "./constants"
- import { BaseProvider } from "./base-provider"
- import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
- // TODO: Rename this to OpenAICompatibleHandler. Also, I think the
- // `OpenAINativeHandler` can subclass from this, since it's obviously
- // compatible with the OpenAI API. We can also rename it to `OpenAIHandler`.
- export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
- protected options: ApiHandlerOptions
- private client: OpenAI
- constructor(options: ApiHandlerOptions) {
- super()
- this.options = options
- const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
- const apiKey = this.options.openAiApiKey ?? "not-provided"
- const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
- const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
- const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure
- const headers = {
- ...DEFAULT_HEADERS,
- ...(this.options.openAiHeaders || {}),
- }
- if (isAzureAiInference) {
- // Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
- this.client = new OpenAI({
- baseURL,
- apiKey,
- defaultHeaders: headers,
- defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
- })
- } else if (isAzureOpenAi) {
- // Azure API shape slightly differs from the core API shape:
- // https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
- this.client = new AzureOpenAI({
- baseURL,
- apiKey,
- apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
- defaultHeaders: headers,
- })
- } else {
- this.client = new OpenAI({
- baseURL,
- apiKey,
- defaultHeaders: headers,
- })
- }
- }
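- // Streams a chat completion for Anthropic-style input: converts the messages to the OpenAI
- // format (or the R1 / simple variants), routes o1/o3/o4 models to handleO3FamilyMessage, and
- // yields text, reasoning, and usage chunks.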
- override async *createMessage(
- systemPrompt: string,
- messages: Anthropic.Messages.MessageParam[],
- metadata?: ApiHandlerCreateMessageMetadata,
- ): ApiStream {
- const { info: modelInfo, reasoning } = this.getModel()
- const modelUrl = this.options.openAiBaseUrl ?? ""
- const modelId = this.options.openAiModelId ?? ""
- const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
- const enabledLegacyFormat = this.options.openAiLegacyFormat ?? false
- const isAzureAiInference = this._isAzureAiInference(modelUrl)
- const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
- const ark = modelUrl.includes(".volces.com")
- if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
- yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
- return
- }
- if (this.options.openAiStreamingEnabled ?? true) {
- let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
- role: "system",
- content: systemPrompt,
- }
- let convertedMessages
- if (deepseekReasoner) {
- convertedMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
- } else if (ark || enabledLegacyFormat) {
- convertedMessages = [systemMessage, ...convertToSimpleMessages(messages)]
- } else {
- if (modelInfo.supportsPromptCache) {
- systemMessage = {
- role: "system",
- content: [
- {
- type: "text",
- text: systemPrompt,
- // @ts-ignore-next-line
- cache_control: { type: "ephemeral" },
- },
- ],
- }
- }
- convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]
- if (modelInfo.supportsPromptCache) {
- // Note: the following logic is copied from the OpenRouter handler.
- // Add cache_control to the last two user messages.
- // (This works because we only ever add one user message at a time; if we added several at once, we would need to mark the user message before the last assistant message instead.)
- const lastTwoUserMessages = convertedMessages.filter((msg) => msg.role === "user").slice(-2)
- lastTwoUserMessages.forEach((msg) => {
- if (typeof msg.content === "string") {
- msg.content = [{ type: "text", text: msg.content }]
- }
- if (Array.isArray(msg.content)) {
- // NOTE: this is fine since env details will always be added at the end, but if they weren't, and the user added an image_url-type message, this would pop a text part from before it and move it to the end.
- let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
- if (!lastTextPart) {
- lastTextPart = { type: "text", text: "..." }
- msg.content.push(lastTextPart)
- }
- // @ts-ignore-next-line
- lastTextPart["cache_control"] = { type: "ephemeral" }
- }
- })
- }
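- // Illustrative (assumed) shape of a user message after the cache_control tagging above:
- // { role: "user", content: [{ type: "text", text: "...", cache_control: { type: "ephemeral" } }] }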
- }
- const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
- const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
- model: modelId,
- temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
- messages: convertedMessages,
- stream: true as const,
- ...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
- ...(reasoning && reasoning),
- }
- // Add max_tokens if needed
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
- const stream = await this.client.chat.completions.create(
- requestOptions,
- isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
- )
- const matcher = new XmlMatcher(
- "think",
- (chunk) =>
- ({
- type: chunk.matched ? "reasoning" : "text",
- text: chunk.data,
- }) as const,
- )
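- // Illustrative behavior: content the model wraps in <think>...</think> is emitted by the
- // matcher as "reasoning" chunks, while everything outside the tags is emitted as "text" chunks.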
- let lastUsage
- for await (const chunk of stream) {
- const delta = chunk.choices[0]?.delta ?? {}
- if (delta.content) {
- for (const chunk of matcher.update(delta.content)) {
- yield chunk
- }
- }
- if ("reasoning_content" in delta && delta.reasoning_content) {
- yield {
- type: "reasoning",
- text: (delta.reasoning_content as string | undefined) || "",
- }
- }
- if (chunk.usage) {
- lastUsage = chunk.usage
- }
- }
- for (const chunk of matcher.final()) {
- yield chunk
- }
- if (lastUsage) {
- yield this.processUsageMetrics(lastUsage, modelInfo)
- }
- } else {
- // o1, for instance, doesn't support streaming, a non-1 temperature, or a system prompt
- const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
- role: "user",
- content: systemPrompt,
- }
- const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
- model: modelId,
- messages: deepseekReasoner
- ? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
- : enabledLegacyFormat
- ? [systemMessage, ...convertToSimpleMessages(messages)]
- : [systemMessage, ...convertToOpenAiMessages(messages)],
- }
- // Add max_tokens if needed
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
- const response = await this.client.chat.completions.create(
- requestOptions,
- this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
- )
- yield {
- type: "text",
- text: response.choices[0]?.message.content || "",
- }
- yield this.processUsageMetrics(response.usage, modelInfo)
- }
- }
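- // Maps the provider's usage payload onto an ApiStreamUsageChunk; the cache token fields are
- // only populated when the upstream API reports them.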
- protected processUsageMetrics(usage: any, _modelInfo?: ModelInfo): ApiStreamUsageChunk {
- return {
- type: "usage",
- inputTokens: usage?.prompt_tokens || 0,
- outputTokens: usage?.completion_tokens || 0,
- cacheWriteTokens: usage?.cache_creation_input_tokens || undefined,
- cacheReadTokens: usage?.cache_read_input_tokens || undefined,
- }
- }
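- // Resolves the configured model id and metadata (falling back to sane defaults for custom
- // OpenAI-compatible endpoints) and merges in the derived model parameters.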
- override getModel() {
- const id = this.options.openAiModelId ?? ""
- const info = this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults
- const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
- return { id, info, ...params }
- }
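- // Sends a single non-streaming completion for a bare prompt and returns the first choice's
- // text; failures are re-thrown with an "OpenAI completion error" prefix.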
- async completePrompt(prompt: string): Promise<string> {
- try {
- const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
- const model = this.getModel()
- const modelInfo = model.info
- const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
- model: model.id,
- messages: [{ role: "user", content: prompt }],
- }
- // Add max_tokens if needed
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
- const response = await this.client.chat.completions.create(
- requestOptions,
- isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
- )
- return response.choices[0]?.message.content || ""
- } catch (error) {
- if (error instanceof Error) {
- throw new Error(`OpenAI completion error: ${error.message}`)
- }
- throw error
- }
- }
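- // Dedicated path for the o1/o3/o4 ("O3 family") models: the system prompt is sent via the
- // "developer" role, temperature is omitted, and reasoning_effort is passed through.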
- private async *handleO3FamilyMessage(
- modelId: string,
- systemPrompt: string,
- messages: Anthropic.Messages.MessageParam[],
- ): ApiStream {
- const modelInfo = this.getModel().info
- const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
- if (this.options.openAiStreamingEnabled ?? true) {
- const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
- const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
- model: modelId,
- messages: [
- {
- role: "developer",
- content: `Formatting re-enabled\n${systemPrompt}`,
- },
- ...convertToOpenAiMessages(messages),
- ],
- stream: true,
- ...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
- reasoning_effort: modelInfo.reasoningEffort,
- temperature: undefined,
- }
- // O3 family models do not support the deprecated max_tokens parameter
- // but they do support max_completion_tokens (the modern OpenAI parameter)
- // This allows O3 models to limit response length when includeMaxTokens is enabled
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
- const stream = await this.client.chat.completions.create(
- requestOptions,
- methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
- )
- yield* this.handleStreamResponse(stream)
- } else {
- const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
- model: modelId,
- messages: [
- {
- role: "developer",
- content: `Formatting re-enabled\n${systemPrompt}`,
- },
- ...convertToOpenAiMessages(messages),
- ],
- reasoning_effort: modelInfo.reasoningEffort,
- temperature: undefined,
- }
- // O3 family models do not support the deprecated max_tokens parameter
- // but they do support max_completion_tokens (the modern OpenAI parameter)
- // This allows O3 models to limit response length when includeMaxTokens is enabled
- this.addMaxTokensIfNeeded(requestOptions, modelInfo)
- const response = await this.client.chat.completions.create(
- requestOptions,
- methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
- )
- yield {
- type: "text",
- text: response.choices[0]?.message.content || "",
- }
- yield this.processUsageMetrics(response.usage)
- }
- }
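- // Relays raw streaming chunks as plain text and usage events (no <think>-tag parsing on this path).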
- private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
- for await (const chunk of stream) {
- const delta = chunk.choices[0]?.delta
- if (delta?.content) {
- yield {
- type: "text",
- text: delta.content,
- }
- }
- if (chunk.usage) {
- yield {
- type: "usage",
- inputTokens: chunk.usage.prompt_tokens || 0,
- outputTokens: chunk.usage.completion_tokens || 0,
- }
- }
- }
- }
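- // Returns the host portion of the given base URL, or an empty string if it cannot be parsed.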
- private _getUrlHost(baseUrl?: string): string {
- try {
- return new URL(baseUrl ?? "").host
- } catch (error) {
- return ""
- }
- }
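- // Detects xAI (Grok) endpoints by host; requests to these are sent without stream_options above.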
- private _isGrokXAI(baseUrl?: string): boolean {
- const urlHost = this._getUrlHost(baseUrl)
- return urlHost.includes("x.ai")
- }
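- // Detects Azure AI Inference endpoints (*.services.ai.azure.com), which need an explicit
- // api-version query parameter and a custom request path.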
- private _isAzureAiInference(baseUrl?: string): boolean {
- const urlHost = this._getUrlHost(baseUrl)
- return urlHost.endsWith(".services.ai.azure.com")
- }
- /**
- * Adds max_completion_tokens to the request body if needed based on provider configuration
- * Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation
- * O3 family models handle max_tokens separately in handleO3FamilyMessage
- */
- private addMaxTokensIfNeeded(
- requestOptions:
- | OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
- | OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
- modelInfo: ModelInfo,
- ): void {
- // Only add max_completion_tokens if includeMaxTokens is true
- if (this.options.includeMaxTokens === true) {
- // Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
- // Using max_completion_tokens as max_tokens is deprecated
- requestOptions.max_completion_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
- }
- }
- }
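- // Fetches the model list from an OpenAI-compatible /models endpoint and returns the unique
- // model ids; returns an empty array when the URL is missing/invalid or the request fails.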
- export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiHeaders?: Record<string, string>) {
- try {
- if (!baseUrl) {
- return []
- }
- // Trim whitespace from baseUrl to handle cases where users accidentally include spaces
- const trimmedBaseUrl = baseUrl.trim()
- if (!URL.canParse(trimmedBaseUrl)) {
- return []
- }
- const config: Record<string, any> = {}
- const headers: Record<string, string> = {
- ...DEFAULT_HEADERS,
- ...(openAiHeaders || {}),
- }
- if (apiKey) {
- headers["Authorization"] = `Bearer ${apiKey}`
- }
- if (Object.keys(headers).length > 0) {
- config["headers"] = headers
- }
- const response = await axios.get(`${trimmedBaseUrl}/models`, config)
- const modelsArray = response.data?.data?.map((model: any) => model.id) || []
- return [...new Set<string>(modelsArray)]
- } catch (error) {
- return []
- }
- }
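- // Example usage (illustrative sketch only; the option values are placeholders and it is assumed
- // that no other ApiHandlerOptions fields are required):
- //
- // const handler = new OpenAiHandler({
- //   openAiBaseUrl: "https://api.openai.com/v1",
- //   openAiApiKey: process.env.OPENAI_API_KEY,
- //   openAiModelId: "gpt-4o-mini",
- // })
- //
- // for await (const chunk of handler.createMessage(systemPrompt, anthropicMessages)) {
- //   if (chunk.type === "text") process.stdout.write(chunk.text)
- // }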
|