openai.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI, { AzureOpenAI } from "openai"
import axios from "axios"

import {
	type ModelInfo,
	azureOpenAiDefaultApiVersion,
	openAiModelInfoSaneDefaults,
	DEEP_SEEK_DEFAULT_TEMPERATURE,
	OPENAI_AZURE_AI_INFERENCE_PATH,
} from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { XmlMatcher } from "../../utils/xml-matcher"

import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToR1Format } from "../transform/r1-format"
import { convertToSimpleMessages } from "../transform/simple-format"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { getModelParams } from "../transform/model-params"

import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"

// TODO: Rename this to OpenAICompatibleHandler. Also, I think the
// `OpenAINativeHandler` can subclass from this, since it's obviously
// compatible with the OpenAI API. We can also rename it to `OpenAIHandler`.
export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: ApiHandlerOptions
	private client: OpenAI

	constructor(options: ApiHandlerOptions) {
		super()
		this.options = options

		const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
		const apiKey = this.options.openAiApiKey ?? "not-provided"
		const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
		const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
		const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

		const headers = {
			...DEFAULT_HEADERS,
			...(this.options.openAiHeaders || {}),
		}

		if (isAzureAiInference) {
			// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
			this.client = new OpenAI({
				baseURL,
				apiKey,
				defaultHeaders: headers,
				defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
			})
		} else if (isAzureOpenAi) {
			// Azure API shape slightly differs from the core API shape:
			// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
			this.client = new AzureOpenAI({
				baseURL,
				apiKey,
				apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
				defaultHeaders: headers,
			})
		} else {
			this.client = new OpenAI({
				baseURL,
				apiKey,
				defaultHeaders: headers,
			})
		}
	}
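
	/**
	 * Converts an Anthropic-shaped conversation into OpenAI chat messages and
	 * streams the response back as "text", "reasoning", and "usage" chunks.
	 * o1/o3/o4 models are routed to handleO3FamilyMessage; DeepSeek R1 and
	 * legacy/Ark-style endpoints get their own message conversions.
	 */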
	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		const { info: modelInfo, reasoning } = this.getModel()
		const modelUrl = this.options.openAiBaseUrl ?? ""
		const modelId = this.options.openAiModelId ?? ""
		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
		const enabledLegacyFormat = this.options.openAiLegacyFormat ?? false
		const isAzureAiInference = this._isAzureAiInference(modelUrl)
		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
		const ark = modelUrl.includes(".volces.com")

		if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
			return
		}

		if (this.options.openAiStreamingEnabled ?? true) {
			let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
				role: "system",
				content: systemPrompt,
			}

			let convertedMessages

			if (deepseekReasoner) {
				convertedMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
			} else if (ark || enabledLegacyFormat) {
				convertedMessages = [systemMessage, ...convertToSimpleMessages(messages)]
			} else {
				if (modelInfo.supportsPromptCache) {
					systemMessage = {
						role: "system",
						content: [
							{
								type: "text",
								text: systemPrompt,
								// @ts-ignore-next-line
								cache_control: { type: "ephemeral" },
							},
						],
					}
				}

				convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]

				if (modelInfo.supportsPromptCache) {
					// Note: the following logic is copied from openrouter:
					// Add cache_control to the last two user messages
					// (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message)
					const lastTwoUserMessages = convertedMessages.filter((msg) => msg.role === "user").slice(-2)

					lastTwoUserMessages.forEach((msg) => {
						if (typeof msg.content === "string") {
							msg.content = [{ type: "text", text: msg.content }]
						}

						if (Array.isArray(msg.content)) {
							// NOTE: this is fine since env details will always be added at the end, but if they weren't there and the user added an image_url type message, it would pop a text part from before it and then move it to the end.
							let lastTextPart = msg.content.filter((part) => part.type === "text").pop()

							if (!lastTextPart) {
								lastTextPart = { type: "text", text: "..." }
								msg.content.push(lastTextPart)
							}

							// @ts-ignore-next-line
							lastTextPart["cache_control"] = { type: "ephemeral" }
						}
					})
				}
			}

			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
				model: modelId,
				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
				messages: convertedMessages,
				stream: true as const,
				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
				...(reasoning && reasoning),
			}

			// Add max_tokens if needed
			this.addMaxTokensIfNeeded(requestOptions, modelInfo)

			const stream = await this.client.chat.completions.create(
				requestOptions,
				isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
			)
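
			// Split any inline <think>...</think> spans out of the streamed content:
			// matched spans are emitted as "reasoning" chunks, everything else as plain "text".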
			const matcher = new XmlMatcher(
				"think",
				(chunk) =>
					({
						type: chunk.matched ? "reasoning" : "text",
						text: chunk.data,
					}) as const,
			)

			let lastUsage

			for await (const chunk of stream) {
				const delta = chunk.choices[0]?.delta ?? {}

				if (delta.content) {
					for (const chunk of matcher.update(delta.content)) {
						yield chunk
					}
				}

				if ("reasoning_content" in delta && delta.reasoning_content) {
					yield {
						type: "reasoning",
						text: (delta.reasoning_content as string | undefined) || "",
					}
				}

				if (chunk.usage) {
					lastUsage = chunk.usage
				}
			}

			for (const chunk of matcher.final()) {
				yield chunk
			}

			if (lastUsage) {
				yield this.processUsageMetrics(lastUsage, modelInfo)
			}
		} else {
			// o1, for instance, doesn't support streaming, a temperature other than 1, or a system prompt.
			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
				role: "user",
				content: systemPrompt,
			}

			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
				model: modelId,
				messages: deepseekReasoner
					? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
					: enabledLegacyFormat
						? [systemMessage, ...convertToSimpleMessages(messages)]
						: [systemMessage, ...convertToOpenAiMessages(messages)],
			}

			// Add max_tokens if needed
			this.addMaxTokensIfNeeded(requestOptions, modelInfo)

			const response = await this.client.chat.completions.create(
				requestOptions,
				this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
			)

			yield {
				type: "text",
				text: response.choices[0]?.message.content || "",
			}

			yield this.processUsageMetrics(response.usage, modelInfo)
		}
	}

	protected processUsageMetrics(usage: any, _modelInfo?: ModelInfo): ApiStreamUsageChunk {
		return {
			type: "usage",
			inputTokens: usage?.prompt_tokens || 0,
			outputTokens: usage?.completion_tokens || 0,
			cacheWriteTokens: usage?.cache_creation_input_tokens || undefined,
			cacheReadTokens: usage?.cache_read_input_tokens || undefined,
		}
	}

	override getModel() {
		const id = this.options.openAiModelId ?? ""
		const info = this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults
		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
		return { id, info, ...params }
	}
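
	/**
	 * Sends a single non-streaming chat completion for `prompt` and returns the
	 * text of the first choice, reusing the same Azure AI Inference routing and
	 * max-token handling as the streaming path.
	 */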
	async completePrompt(prompt: string): Promise<string> {
		try {
			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
			const model = this.getModel()
			const modelInfo = model.info
			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
				model: model.id,
				messages: [{ role: "user", content: prompt }],
			}

			// Add max_tokens if needed
			this.addMaxTokensIfNeeded(requestOptions, modelInfo)

			const response = await this.client.chat.completions.create(
				requestOptions,
				isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
			)

			return response.choices[0]?.message.content || ""
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(`OpenAI completion error: ${error.message}`)
			}

			throw error
		}
	}
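
	/**
	 * Handles o1/o3/o4 reasoning models: the system prompt is sent via the
	 * "developer" role (prefixed with "Formatting re-enabled"), the model's
	 * reasoningEffort is passed as reasoning_effort, and no temperature is sent.
	 */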
	private async *handleO3FamilyMessage(
		modelId: string,
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
	): ApiStream {
		const modelInfo = this.getModel().info
		const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

		if (this.options.openAiStreamingEnabled ?? true) {
			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
				model: modelId,
				messages: [
					{
						role: "developer",
						content: `Formatting re-enabled\n${systemPrompt}`,
					},
					...convertToOpenAiMessages(messages),
				],
				stream: true,
				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
				reasoning_effort: modelInfo.reasoningEffort,
				temperature: undefined,
			}

			// O3 family models do not support the deprecated max_tokens parameter
			// but they do support max_completion_tokens (the modern OpenAI parameter)
			// This allows O3 models to limit response length when includeMaxTokens is enabled
			this.addMaxTokensIfNeeded(requestOptions, modelInfo)

			const stream = await this.client.chat.completions.create(
				requestOptions,
				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
			)

			yield* this.handleStreamResponse(stream)
		} else {
			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
				model: modelId,
				messages: [
					{
						role: "developer",
						content: `Formatting re-enabled\n${systemPrompt}`,
					},
					...convertToOpenAiMessages(messages),
				],
				reasoning_effort: modelInfo.reasoningEffort,
				temperature: undefined,
			}

			// O3 family models do not support the deprecated max_tokens parameter
			// but they do support max_completion_tokens (the modern OpenAI parameter)
			// This allows O3 models to limit response length when includeMaxTokens is enabled
			this.addMaxTokensIfNeeded(requestOptions, modelInfo)

			const response = await this.client.chat.completions.create(
				requestOptions,
				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
			)

			yield {
				type: "text",
				text: response.choices[0]?.message.content || "",
			}

			yield this.processUsageMetrics(response.usage)
		}
	}
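
	/**
	 * Yields plain text and usage chunks from a raw chat-completions stream;
	 * unlike createMessage, it does no <think>-tag splitting.
	 */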
	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta

			if (delta?.content) {
				yield {
					type: "text",
					text: delta.content,
				}
			}

			if (chunk.usage) {
				yield {
					type: "usage",
					inputTokens: chunk.usage.prompt_tokens || 0,
					outputTokens: chunk.usage.completion_tokens || 0,
				}
			}
		}
	}

	private _getUrlHost(baseUrl?: string): string {
		try {
			return new URL(baseUrl ?? "").host
		} catch (error) {
			return ""
		}
	}

	private _isGrokXAI(baseUrl?: string): boolean {
		const urlHost = this._getUrlHost(baseUrl)
		return urlHost.includes("x.ai")
	}

	private _isAzureAiInference(baseUrl?: string): boolean {
		const urlHost = this._getUrlHost(baseUrl)
		return urlHost.endsWith(".services.ai.azure.com")
	}

	/**
	 * Adds max_completion_tokens to the request body if needed based on provider configuration
	 * Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation
	 * O3 family models handle max_tokens separately in handleO3FamilyMessage
	 */
	private addMaxTokensIfNeeded(
		requestOptions:
			| OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
			| OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
		modelInfo: ModelInfo,
	): void {
		// Only add max_completion_tokens if includeMaxTokens is true
		if (this.options.includeMaxTokens === true) {
			// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
			// Using max_completion_tokens as max_tokens is deprecated
			requestOptions.max_completion_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
		}
	}
}
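
// Usage sketch (illustrative only, not part of this module): constructing the handler
// and consuming the stream returned by createMessage. The option values below are
// placeholders, and the object must satisfy ApiHandlerOptions.
//
//   const handler = new OpenAiHandler({
//       openAiBaseUrl: "https://api.openai.com/v1",
//       openAiApiKey: process.env.OPENAI_API_KEY ?? "",
//       openAiModelId: "gpt-4o-mini",
//   } as ApiHandlerOptions)
//
//   for await (const chunk of handler.createMessage("You are a helpful assistant.", [
//       { role: "user", content: "Hello!" },
//   ])) {
//       if (chunk.type === "text") process.stdout.write(chunk.text)
//   }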
export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiHeaders?: Record<string, string>) {
	try {
		if (!baseUrl) {
			return []
		}

		// Trim whitespace from baseUrl to handle cases where users accidentally include spaces
		const trimmedBaseUrl = baseUrl.trim()

		if (!URL.canParse(trimmedBaseUrl)) {
			return []
		}

		const config: Record<string, any> = {}
		const headers: Record<string, string> = {
			...DEFAULT_HEADERS,
			...(openAiHeaders || {}),
		}

		if (apiKey) {
			headers["Authorization"] = `Bearer ${apiKey}`
		}

		if (Object.keys(headers).length > 0) {
			config["headers"] = headers
		}

		const response = await axios.get(`${trimmedBaseUrl}/models`, config)
		const modelsArray = response.data?.data?.map((model: any) => model.id) || []
		return [...new Set<string>(modelsArray)]
	} catch (error) {
		return []
	}
}
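
// Usage sketch (illustrative only): listing model IDs from an OpenAI-compatible endpoint
// inside an async context. The URL and key are placeholders; on any failure the function
// above resolves to an empty array rather than throwing.
//
//   const modelIds = await getOpenAiModels("https://api.openai.com/v1", process.env.OPENAI_API_KEY)
//   console.log(modelIds) // e.g. ["gpt-4o", "gpt-4o-mini", ...]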