2
0
Эх сурвалжийг харах

Merge pull request #977 from d-oit/mistral

Additional models for mistral api provider
Matt Rubens 10 сар өмнө
parent
commit
2928c802fc

+ 126 - 0
src/api/providers/__tests__/mistral.test.ts

@@ -0,0 +1,126 @@
+import { MistralHandler } from "../mistral"
+import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { ApiStreamTextChunk } from "../../transform/stream"
+
+// Mock Mistral client
+const mockCreate = jest.fn()
+jest.mock("@mistralai/mistralai", () => {
+	return {
+		Mistral: jest.fn().mockImplementation(() => ({
+			chat: {
+				stream: mockCreate.mockImplementation(async (options) => {
+					const stream = {
+						[Symbol.asyncIterator]: async function* () {
+							yield {
+								data: {
+									choices: [
+										{
+											delta: { content: "Test response" },
+											index: 0,
+										},
+									],
+								},
+							}
+						},
+					}
+					return stream
+				}),
+			},
+		})),
+	}
+})
+
+describe("MistralHandler", () => {
+	let handler: MistralHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			apiModelId: "codestral-latest", // Update to match the actual model ID
+			mistralApiKey: "test-api-key",
+			includeMaxTokens: true,
+			modelTemperature: 0,
+		}
+		handler = new MistralHandler(mockOptions)
+		mockCreate.mockClear()
+	})
+
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(MistralHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
+
+		it("should throw error if API key is missing", () => {
+			expect(() => {
+				new MistralHandler({
+					...mockOptions,
+					mistralApiKey: undefined,
+				})
+			}).toThrow("Mistral API key is required")
+		})
+
+		it("should use custom base URL if provided", () => {
+			const customBaseUrl = "https://custom.mistral.ai/v1"
+			const handlerWithCustomUrl = new MistralHandler({
+				...mockOptions,
+				mistralCodestralUrl: customBaseUrl,
+			})
+			expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
+		})
+	})
+
+	describe("getModel", () => {
+		it("should return correct model info", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe(mockOptions.apiModelId)
+			expect(model.info).toBeDefined()
+			expect(model.info.supportsPromptCache).toBe(false)
+		})
+	})
+
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text", text: "Hello!" }],
+			},
+		]
+
+		it("should create message successfully", async () => {
+			const iterator = handler.createMessage(systemPrompt, messages)
+			const result = await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: mockOptions.apiModelId,
+				messages: expect.any(Array),
+				maxTokens: expect.any(Number),
+				temperature: 0,
+			})
+
+			expect(result.value).toBeDefined()
+			expect(result.done).toBe(false)
+		})
+
+		it("should handle streaming response correctly", async () => {
+			const iterator = handler.createMessage(systemPrompt, messages)
+			const results: ApiStreamTextChunk[] = []
+
+			for await (const chunk of iterator) {
+				if ("text" in chunk) {
+					results.push(chunk as ApiStreamTextChunk)
+				}
+			}
+
+			expect(results.length).toBeGreaterThan(0)
+			expect(results[0].text).toBe("Test response")
+		})
+
+		it("should handle errors gracefully", async () => {
+			mockCreate.mockRejectedValueOnce(new Error("API Error"))
+			await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
+		})
+	})
+})

+ 20 - 7
src/api/providers/mistral.ts

@@ -21,23 +21,36 @@ export class MistralHandler implements ApiHandler {
 	private client: Mistral
 
 	constructor(options: ApiHandlerOptions) {
+		if (!options.mistralApiKey) {
+			throw new Error("Mistral API key is required")
+		}
+
 		this.options = options
+		const baseUrl = this.getBaseUrl()
+		console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`)
 		this.client = new Mistral({
-			serverURL: "https://codestral.mistral.ai",
+			serverURL: baseUrl,
 			apiKey: this.options.mistralApiKey,
 		})
 	}
 
+	private getBaseUrl(): string {
+		const modelId = this.options.apiModelId
+		if (modelId?.startsWith("codestral-")) {
+			return this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
+		}
+		return "https://api.mistral.ai"
+	}
+
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const stream = await this.client.chat.stream({
-			model: this.getModel().id,
-			// max_completion_tokens: this.getModel().info.maxTokens,
+		const response = await this.client.chat.stream({
+			model: this.options.apiModelId || mistralDefaultModelId,
+			messages: convertToMistralMessages(messages),
+			maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined,
 			temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
-			stream: true,
 		})
 
-		for await (const chunk of stream) {
+		for await (const chunk of response) {
 			const delta = chunk.data.choices[0]?.delta
 			if (delta?.content) {
 				let content: string = ""

+ 6 - 0
src/core/webview/ClineProvider.ts

@@ -127,6 +127,7 @@ type GlobalStateKey =
 	| "requestyModelInfo"
 	| "unboundModelInfo"
 	| "modelTemperature"
+	| "mistralCodestralUrl"
 	| "maxOpenTabsContext"
 
 export const GlobalFileNames = {
@@ -1637,6 +1638,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openRouterUseMiddleOutTransform,
 			vsCodeLmModelSelector,
 			mistralApiKey,
+			mistralCodestralUrl,
 			unboundApiKey,
 			unboundModelId,
 			unboundModelInfo,
@@ -1682,6 +1684,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
 		await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector)
 		await this.storeSecret("mistralApiKey", mistralApiKey)
+		await this.updateGlobalState("mistralCodestralUrl", mistralCodestralUrl)
 		await this.storeSecret("unboundApiKey", unboundApiKey)
 		await this.updateGlobalState("unboundModelId", unboundModelId)
 		await this.updateGlobalState("unboundModelInfo", unboundModelInfo)
@@ -2521,6 +2524,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiNativeApiKey,
 			deepSeekApiKey,
 			mistralApiKey,
+			mistralCodestralUrl,
 			azureApiVersion,
 			openAiStreamingEnabled,
 			openRouterModelId,
@@ -2602,6 +2606,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
 			this.getSecret("mistralApiKey") as Promise<string | undefined>,
+			this.getGlobalState("mistralCodestralUrl") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
 			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
@@ -2700,6 +2705,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				openAiNativeApiKey,
 				deepSeekApiKey,
 				mistralApiKey,
+				mistralCodestralUrl,
 				azureApiVersion,
 				openAiStreamingEnabled,
 				openRouterModelId,

+ 42 - 1
src/shared/api.ts

@@ -52,6 +52,7 @@ export interface ApiHandlerOptions {
 	geminiApiKey?: string
 	openAiNativeApiKey?: string
 	mistralApiKey?: string
+	mistralCodestralUrl?: string // New option for Codestral URL
 	azureApiVersion?: string
 	openRouterUseMiddleOutTransform?: boolean
 	openAiStreamingEnabled?: boolean
@@ -670,13 +671,53 @@ export type MistralModelId = keyof typeof mistralModels
 export const mistralDefaultModelId: MistralModelId = "codestral-latest"
 export const mistralModels = {
 	"codestral-latest": {
-		maxTokens: 32_768,
+		maxTokens: 256_000,
 		contextWindow: 256_000,
 		supportsImages: false,
 		supportsPromptCache: false,
 		inputPrice: 0.3,
 		outputPrice: 0.9,
 	},
+	"mistral-large-latest": {
+		maxTokens: 131_000,
+		contextWindow: 131_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 2.0,
+		outputPrice: 6.0,
+	},
+	"ministral-8b-latest": {
+		maxTokens: 131_000,
+		contextWindow: 131_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.1,
+		outputPrice: 0.1,
+	},
+	"ministral-3b-latest": {
+		maxTokens: 131_000,
+		contextWindow: 131_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.04,
+		outputPrice: 0.04,
+	},
+	"mistral-small-latest": {
+		maxTokens: 32_000,
+		contextWindow: 32_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.2,
+		outputPrice: 0.6,
+	},
+	"pixtral-large-latest": {
+		maxTokens: 131_000,
+		contextWindow: 131_000,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 2.0,
+		outputPrice: 6.0,
+	},
 } as const satisfies Record<string, ModelInfo>
 
 // Unbound Security

+ 25 - 2
webview-ui/src/components/settings/ApiOptions.tsx

@@ -314,6 +314,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
 						placeholder="Enter API Key...">
 						<span style={{ fontWeight: 500 }}>Mistral API Key</span>
 					</VSCodeTextField>
+
 					<p
 						style={{
 							fontSize: "12px",
@@ -323,15 +324,37 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
 						This key is stored locally and only used to make API requests from this extension.
 						{!apiConfiguration?.mistralApiKey && (
 							<VSCodeLink
-								href="https://console.mistral.ai/codestral/"
+								href="https://console.mistral.ai/"
 								style={{
 									display: "inline",
 									fontSize: "inherit",
 								}}>
-								You can get a Mistral API key by signing up here.
+								You can get a La Plateforme (api.mistral.ai) / Codestral (codestral.mistral.ai) API key
+								by signing up here.
 							</VSCodeLink>
 						)}
 					</p>
+
+					{apiConfiguration?.apiModelId?.startsWith("codestral-") && (
+						<div>
+							<VSCodeTextField
+								value={apiConfiguration?.mistralCodestralUrl || ""}
+								style={{ width: "100%", marginTop: "10px" }}
+								type="url"
+								onBlur={handleInputChange("mistralCodestralUrl")}
+								placeholder="Default: https://codestral.mistral.ai">
+								<span style={{ fontWeight: 500 }}>Codestral Base URL (Optional)</span>
+							</VSCodeTextField>
+							<p
+								style={{
+									fontSize: "12px",
+									marginTop: 3,
+									color: "var(--vscode-descriptionForeground)",
+								}}>
+								Set alternative URL for Codestral model: https://api.mistral.ai
+							</p>
+						</div>
+					)}
 				</div>
 			)}