@@ -15,8 +15,7 @@ import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 import { XmlMatcher } from "../../utils/xml-matcher"
-
-const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

 export const defaultHeaders = {
 	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
@@ -25,6 +24,8 @@ export const defaultHeaders = {

 export interface OpenAiHandlerOptions extends ApiHandlerOptions {}

+const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"
+
 export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -35,17 +36,19 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
 		const apiKey = this.options.openAiApiKey ?? "not-provided"
-		let urlHost: string
-
-		try {
-			urlHost = new URL(this.options.openAiBaseUrl ?? "").host
-		} catch (error) {
-			// Likely an invalid `openAiBaseUrl`; we're still working on
-			// proper settings validation.
-			urlHost = ""
-		}
+		const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+		const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
+		const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

-		if (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure) {
+		if (isAzureAiInference) {
+			// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
+			this.client = new OpenAI({
+				baseURL,
+				apiKey,
+				defaultHeaders,
+				defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
+			})
+		} else if (isAzureOpenAi) {
 			// Azure API shape slightly differs from the core API shape:
 			// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
 			this.client = new AzureOpenAI({
@@ -64,6 +67,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const modelUrl = this.options.openAiBaseUrl ?? ""
 		const modelId = this.options.openAiModelId ?? ""
 		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
+		const isAzureAiInference = this._isAzureAiInference(modelUrl)
+		const urlHost = this._getUrlHost(modelUrl)
 		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
 		const ark = modelUrl.includes(".volces.com")
 		if (modelId.startsWith("o3-mini")) {
@@ -132,7 +137,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				requestOptions.max_tokens = modelInfo.maxTokens
 			}

-			const stream = await this.client.chat.completions.create(requestOptions)
+			const stream = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			const matcher = new XmlMatcher(
 				"think",
@@ -185,7 +193,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					: [systemMessage, ...convertToOpenAiMessages(messages)],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				this._isAzureAiInference(modelUrl) ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield {
 				type: "text",
@@ -212,12 +223,16 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
 			if (error instanceof Error) {
@@ -233,19 +248,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		if (this.options.openAiStreamingEnabled ?? true) {
-			const stream = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [
-					{
-						role: "developer",
-						content: `Formatting re-enabled\n${systemPrompt}`,
-					},
-					...convertToOpenAiMessages(messages),
-				],
-				stream: true,
-				stream_options: { include_usage: true },
-				reasoning_effort: this.getModel().info.reasoningEffort,
-			})
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const stream = await this.client.chat.completions.create(
+				{
+					model: modelId,
+					messages: [
+						{
+							role: "developer",
+							content: `Formatting re-enabled\n${systemPrompt}`,
+						},
+						...convertToOpenAiMessages(messages),
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					reasoning_effort: this.getModel().info.reasoningEffort,
+				},
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield* this.handleStreamResponse(stream)
 		} else {
@@ -260,7 +280,12 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield {
 				type: "text",
@@ -289,6 +314,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 	}
+	private _getUrlHost(baseUrl?: string): string {
+		try {
+			return new URL(baseUrl ?? "").host
+		} catch (error) {
+			return ""
+		}
+	}
+
+	private _isAzureAiInference(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return urlHost.endsWith(".services.ai.azure.com")
+	}
 }

 export async function getOpenAiModels(baseUrl?: string, apiKey?: string) {
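
For context, the diff above routes Azure AI Inference traffic through the OpenAI SDK's per-request `path` option rather than introducing a second client class. Below is a minimal standalone sketch of that routing rule, assuming the `openai` npm package (v4); the endpoint, model name, and environment variable are hypothetical placeholders, not values from this PR:

// Sketch only: demonstrates host detection plus the per-request `path`
// override used in the diff above. Endpoint, model, and key are placeholders.
import OpenAI from "openai"

const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

// Azure AI Inference endpoints live under *.services.ai.azure.com;
// an unparseable base URL simply fails the check.
function isAzureAiInference(baseUrl?: string): boolean {
	try {
		return new URL(baseUrl ?? "").host.endsWith(".services.ai.azure.com")
	} catch {
		return false
	}
}

async function main() {
	const baseURL = "https://my-resource.services.ai.azure.com" // hypothetical resource
	const client = new OpenAI({
		baseURL,
		apiKey: process.env.AZURE_AI_API_KEY ?? "not-provided",
		// Azure AI Inference expects an api-version query parameter.
		defaultQuery: { "api-version": "2024-05-01-preview" },
	})

	// The second argument to create() takes per-request options; `path`
	// replaces the SDK's default "/chat/completions" route with Azure AI
	// Inference's "/models/chat/completions" only when the host matches.
	const response = await client.chat.completions.create(
		{ model: "DeepSeek-R1", messages: [{ role: "user", content: "Hello" }] },
		isAzureAiInference(baseURL) ? { path: AZURE_AI_INFERENCE_PATH } : {},
	)
	console.log(response.choices[0]?.message.content ?? "")
}

main().catch(console.error)

Overriding the path per request keeps a single OpenAI-compatible client for every provider and confines the Azure AI Inference special case to the call sites, which is why the diff threads the same ternary through each `create()` call.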