Explorar o código

feat: Add support for Azure AI Inference Service with DeepSeek-V3 model (#2241)

* feat: Add support for Azure AI Inference Service with DeepSeek-V3 model

* refactor: extract Azure AI inference path to constant to avoid duplication

* fix(tests): update RequestyHandler tests to properly handle Azure inference and streaming

* fix(api): remove duplicate constant and update requesty tests

* refactor: remove unused isAzure property from OpenAiHandler

* refactor(openai): remove unused isAzure and extract Azure check
Thomas Jeung hai 9 meses
pai
achega
e9bcee5e2a

+ 115 - 4
src/api/providers/__tests__/openai.test.ts

@@ -1,6 +1,7 @@
 import { OpenAiHandler } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "../constants"
 
 // Mock OpenAI client
 const mockCreate = jest.fn()
@@ -202,10 +203,13 @@ describe("OpenAiHandler", () => {
 		it("should complete prompt successfully", async () => {
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(mockCreate).toHaveBeenCalledWith({
-				model: mockOptions.openAiModelId,
-				messages: [{ role: "user", content: "Test prompt" }],
-			})
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: mockOptions.openAiModelId,
+					messages: [{ role: "user", content: "Test prompt" }],
+				},
+				{},
+			)
 		})
 
 		it("should handle API errors", async () => {
@@ -241,4 +245,111 @@ describe("OpenAiHandler", () => {
 			expect(model.info).toBeDefined()
 		})
 	})
+
+	describe("Azure AI Inference Service", () => {
+		const azureOptions = {
+			...mockOptions,
+			openAiBaseUrl: "https://test.services.ai.azure.com",
+			openAiModelId: "deepseek-v3",
+			azureApiVersion: "2024-05-01-preview",
+		}
+
+		it("should initialize with Azure AI Inference Service configuration", () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			expect(azureHandler).toBeInstanceOf(OpenAiHandler)
+			expect(azureHandler.getModel().id).toBe(azureOptions.openAiModelId)
+		})
+
+		it("should handle streaming responses with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = azureHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+
+			// Verify the API call was made with correct Azure AI Inference Service path
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Hello!" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					temperature: 0,
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+
+		it("should handle non-streaming responses with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler({
+				...azureOptions,
+				openAiStreamingEnabled: false,
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = azureHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunk = chunks.find((chunk) => chunk.type === "text")
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+
+			expect(textChunk).toBeDefined()
+			expect(textChunk?.text).toBe("Test response")
+			expect(usageChunk).toBeDefined()
+			expect(usageChunk?.inputTokens).toBe(10)
+			expect(usageChunk?.outputTokens).toBe(5)
+
+			// Verify the API call was made with correct Azure AI Inference Service path
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [
+						{ role: "user", content: systemPrompt },
+						{ role: "user", content: "Hello!" },
+					],
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+
+		it("should handle completePrompt with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			const result = await azureHandler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [{ role: "user", content: "Test prompt" }],
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+	})
 })

+ 36 - 4
src/api/providers/__tests__/requesty.test.ts

@@ -38,8 +38,29 @@ describe("RequestyHandler", () => {
 		// Clear mocks
 		jest.clearAllMocks()
 
-		// Setup mock create function
-		mockCreate = jest.fn()
+		// Setup mock create function that preserves params
+		let lastParams: any
+		mockCreate = jest.fn().mockImplementation((params) => {
+			lastParams = params
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [{ delta: { content: "Hello" } }],
+					}
+					yield {
+						choices: [{ delta: { content: " world" } }],
+						usage: {
+							prompt_tokens: 30,
+							completion_tokens: 10,
+							prompt_tokens_details: {
+								cached_tokens: 15,
+								caching_tokens: 5,
+							},
+						},
+					}
+				},
+			}
+		})
 
 		// Mock OpenAI constructor
 		;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(
@@ -47,7 +68,13 @@ describe("RequestyHandler", () => {
 				({
 					chat: {
 						completions: {
-							create: mockCreate,
+							create: (params: any) => {
+								// Store params for verification
+								const result = mockCreate(params)
+								// Make params available for test assertions
+								;(result as any).params = params
+								return result
+							},
 						},
 					},
 				}) as unknown as OpenAI,
@@ -122,7 +149,12 @@ describe("RequestyHandler", () => {
 					},
 				])
 
-				expect(mockCreate).toHaveBeenCalledWith({
+				// Get the actual params that were passed
+				const calls = mockCreate.mock.calls
+				expect(calls.length).toBe(1)
+				const actualParams = calls[0][0]
+
+				expect(actualParams).toEqual({
 					model: defaultOptions.requestyModelId,
 					temperature: 0,
 					messages: [

+ 66 - 29
src/api/providers/openai.ts

@@ -15,8 +15,7 @@ import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 import { XmlMatcher } from "../../utils/xml-matcher"
-
-const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 
 export const defaultHeaders = {
 	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
@@ -25,6 +24,8 @@ export const defaultHeaders = {
 
 export interface OpenAiHandlerOptions extends ApiHandlerOptions {}
 
+const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"
+
 export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -35,17 +36,19 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 		const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
 		const apiKey = this.options.openAiApiKey ?? "not-provided"
-		let urlHost: string
-
-		try {
-			urlHost = new URL(this.options.openAiBaseUrl ?? "").host
-		} catch (error) {
-			// Likely an invalid `openAiBaseUrl`; we're still working on
-			// proper settings validation.
-			urlHost = ""
-		}
+		const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+		const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
+		const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure
 
-		if (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure) {
+		if (isAzureAiInference) {
+			// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
+			this.client = new OpenAI({
+				baseURL,
+				apiKey,
+				defaultHeaders,
+				defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
+			})
+		} else if (isAzureOpenAi) {
 			// Azure API shape slightly differs from the core API shape:
 			// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
 			this.client = new AzureOpenAI({
@@ -64,6 +67,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const modelUrl = this.options.openAiBaseUrl ?? ""
 		const modelId = this.options.openAiModelId ?? ""
 		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
+		const isAzureAiInference = this._isAzureAiInference(modelUrl)
+		const urlHost = this._getUrlHost(modelUrl)
 		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
 		const ark = modelUrl.includes(".volces.com")
 		if (modelId.startsWith("o3-mini")) {
@@ -132,7 +137,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				requestOptions.max_tokens = modelInfo.maxTokens
 			}
 
-			const stream = await this.client.chat.completions.create(requestOptions)
+			const stream = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 
 			const matcher = new XmlMatcher(
 				"think",
@@ -185,7 +193,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					: [systemMessage, ...convertToOpenAiMessages(messages)],
 			}
 
-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				this._isAzureAiInference(modelUrl) ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 
 			yield {
 				type: "text",
@@ -212,12 +223,16 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 			}
 
-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
 			if (error instanceof Error) {
@@ -233,19 +248,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		if (this.options.openAiStreamingEnabled ?? true) {
-			const stream = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [
-					{
-						role: "developer",
-						content: `Formatting re-enabled\n${systemPrompt}`,
-					},
-					...convertToOpenAiMessages(messages),
-				],
-				stream: true,
-				stream_options: { include_usage: true },
-				reasoning_effort: this.getModel().info.reasoningEffort,
-			})
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const stream = await this.client.chat.completions.create(
+				{
+					model: modelId,
+					messages: [
+						{
+							role: "developer",
+							content: `Formatting re-enabled\n${systemPrompt}`,
+						},
+						...convertToOpenAiMessages(messages),
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					reasoning_effort: this.getModel().info.reasoningEffort,
+				},
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 
 			yield* this.handleStreamResponse(stream)
 		} else {
@@ -260,7 +280,12 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 			}
 
-			const response = await this.client.chat.completions.create(requestOptions)
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 
 			yield {
 				type: "text",
@@ -289,6 +314,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 	}
+	private _getUrlHost(baseUrl?: string): string {
+		try {
+			return new URL(baseUrl ?? "").host
+		} catch (error) {
+			return ""
+		}
+	}
+
+	private _isAzureAiInference(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return urlHost.endsWith(".services.ai.azure.com")
+	}
 }
 
 export async function getOpenAiModels(baseUrl?: string, apiKey?: string) {