Browse Source

Add Mistral API provider

Saoud Rizwan 11 months ago
parent
commit
077fa84374

+ 9 - 0
package-lock.json

@@ -13,6 +13,7 @@
         "@anthropic-ai/vertex-sdk": "^0.4.1",
         "@aws-sdk/client-bedrock-runtime": "^3.706.0",
         "@google/generative-ai": "^0.18.0",
+        "@mistralai/mistralai": "^1.3.6",
         "@modelcontextprotocol/sdk": "^1.0.1",
         "@types/clone-deep": "^4.0.4",
         "@types/pdf-parse": "^1.1.4",
@@ -4254,6 +4255,14 @@
         "node": ">=8"
       }
     },
+    "node_modules/@mistralai/mistralai": {
+      "version": "1.3.6",
+      "resolved": "https://registry.npmjs.org/@mistralai/mistralai/-/mistralai-1.3.6.tgz",
+      "integrity": "sha512-2y7U5riZq+cIjKpxGO9y417XuZv9CpBXEAvbjRMzWPGhXY7U1ZXj4VO4H9riS2kFZqTR2yLEKSE6/pGWVVIqgQ==",
+      "peerDependencies": {
+        "zod": ">= 3"
+      }
+    },
     "node_modules/@mixmark-io/domino": {
       "version": "2.2.0",
       "resolved": "https://registry.npmjs.org/@mixmark-io/domino/-/domino-2.2.0.tgz",

+ 1 - 0
package.json

@@ -226,6 +226,7 @@
     "@anthropic-ai/vertex-sdk": "^0.4.1",
     "@aws-sdk/client-bedrock-runtime": "^3.706.0",
     "@google/generative-ai": "^0.18.0",
+    "@mistralai/mistralai": "^1.3.6",
     "@modelcontextprotocol/sdk": "^1.0.1",
     "@types/clone-deep": "^4.0.4",
     "@types/pdf-parse": "^1.1.4",

+ 3 - 0
src/api/index.ts

@@ -11,6 +11,7 @@ import { LmStudioHandler } from "./providers/lmstudio"
 import { GeminiHandler } from "./providers/gemini"
 import { OpenAiNativeHandler } from "./providers/openai-native"
 import { DeepSeekHandler } from "./providers/deepseek"
+import { MistralHandler } from "./providers/mistral"
 import { VsCodeLmHandler } from "./providers/vscode-lm"
 import { ApiStream } from "./transform/stream"
 
@@ -50,6 +51,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new DeepSeekHandler(options)
 		case "vscode-lm":
 			return new VsCodeLmHandler(options)
+		case "mistral":
+			return new MistralHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}

+ 74 - 0
src/api/providers/mistral.ts

@@ -0,0 +1,74 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Mistral } from "@mistralai/mistralai"
+import { ApiHandler } from "../"
+import {
+	ApiHandlerOptions,
+	mistralDefaultModelId,
+	MistralModelId,
+	mistralModels,
+	ModelInfo,
+	openAiNativeDefaultModelId,
+	OpenAiNativeModelId,
+	openAiNativeModels,
+} from "../../shared/api"
+import { convertToMistralMessages } from "../transform/mistral-format"
+import { ApiStream } from "../transform/stream"
+
+// ApiHandler implementation backed by the official Mistral SDK.
+// Streams chat completions and adapts them into the extension's ApiStream chunks.
+export class MistralHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: Mistral
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		// NOTE(review): serverURL is hardcoded to the Codestral endpoint, so
+		// mistralApiKey must be a Codestral key (console.mistral.ai/codestral);
+		// keys for the general api.mistral.ai endpoint will presumably fail here — confirm intended.
+		this.client = new Mistral({
+			serverURL: "https://codestral.mistral.ai",
+			apiKey: this.options.mistralApiKey,
+		})
+	}
+
+	// Streams a completion for the given system prompt + Anthropic-format history.
+	// Yields "text" chunks as deltas arrive and a "usage" chunk when the API
+	// reports token counts. Anthropic messages are converted via convertToMistralMessages.
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const stream = await this.client.chat.stream({
+			model: this.getModel().id,
+			// max_completion_tokens: this.getModel().info.maxTokens,
+			temperature: 0,
+			// System prompt is prepended as a dedicated system message.
+			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
+			stream: true,
+		})
+
+		for await (const chunk of stream) {
+			const delta = chunk.data.choices[0]?.delta
+			if (delta?.content) {
+				let content: string = ""
+				// The SDK's delta.content may be a plain string or an array of
+				// content chunks; non-text chunks (e.g. images) are ignored.
+				if (typeof delta.content === "string") {
+					content = delta.content
+				} else if (Array.isArray(delta.content)) {
+					content = delta.content.map((c) => (c.type === "text" ? c.text : "")).join("")
+				}
+				yield {
+					type: "text",
+					text: content,
+				}
+			}
+
+			// Usage arrives on a (typically final) chunk; fall back to 0 when absent.
+			if (chunk.data.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.data.usage.promptTokens || 0,
+					outputTokens: chunk.data.usage.completionTokens || 0,
+				}
+			}
+		}
+	}
+
+	// Resolves the configured apiModelId against the known Mistral model table,
+	// falling back to mistralDefaultModelId for unknown/unset ids.
+	getModel(): { id: MistralModelId; info: ModelInfo } {
+		const modelId = this.options.apiModelId
+		if (modelId && modelId in mistralModels) {
+			const id = modelId as MistralModelId
+			return { id, info: mistralModels[id] }
+		}
+		return {
+			id: mistralDefaultModelId,
+			info: mistralModels[mistralDefaultModelId],
+		}
+	}
+}

+ 92 - 0
src/api/transform/mistral-format.ts

@@ -0,0 +1,92 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Mistral } from "@mistralai/mistralai"
+import { AssistantMessage } from "@mistralai/mistralai/models/components/assistantmessage"
+import { SystemMessage } from "@mistralai/mistralai/models/components/systemmessage"
+import { ToolMessage } from "@mistralai/mistralai/models/components/toolmessage"
+import { UserMessage } from "@mistralai/mistralai/models/components/usermessage"
+
+// Union of the Mistral SDK's per-role message shapes, with the role field
+// narrowed to a literal so the array below type-checks as a discriminated union.
+export type MistralMessage =
+	| (SystemMessage & { role: "system" })
+	| (UserMessage & { role: "user" })
+	| (AssistantMessage & { role: "assistant" })
+	| (ToolMessage & { role: "tool" })
+
+// Converts Anthropic-format message params into Mistral chat messages.
+// String content passes through as-is; block content is flattened:
+// user image blocks become data-URI image_url chunks, text blocks become
+// text chunks, and assistant text blocks are joined with newlines.
+// NOTE(review): tool_use/tool_result blocks are collected into toolMessages
+// in both branches but never emitted — tool content is silently dropped. Confirm intended.
+export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): MistralMessage[] {
+	const mistralMessages: MistralMessage[] = []
+	for (const anthropicMessage of anthropicMessages) {
+		if (typeof anthropicMessage.content === "string") {
+			// Simple string content maps 1:1 for either role.
+			mistralMessages.push({
+				role: anthropicMessage.role,
+				content: anthropicMessage.content,
+			})
+		} else {
+			if (anthropicMessage.role === "user") {
+				// Partition the user's content blocks into tool results vs. text/images.
+				const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
+					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
+					toolMessages: Anthropic.ToolResultBlockParam[]
+				}>(
+					(acc, part) => {
+						if (part.type === "tool_result") {
+							acc.toolMessages.push(part)
+						} else if (part.type === "text" || part.type === "image") {
+							acc.nonToolMessages.push(part)
+						} // user cannot send tool_use messages
+						return acc
+					},
+					{ nonToolMessages: [], toolMessages: [] },
+				)
+
+				if (nonToolMessages.length > 0) {
+					mistralMessages.push({
+						role: "user",
+						content: nonToolMessages.map((part) => {
+							if (part.type === "image") {
+								// Re-encode Anthropic base64 image sources as data-URI image_url chunks.
+								return {
+									type: "image_url",
+									imageUrl: {
+										url: `data:${part.source.media_type};base64,${part.source.data}`,
+									},
+								}
+							}
+							return { type: "text", text: part.text }
+						}),
+					})
+				}
+			} else if (anthropicMessage.role === "assistant") {
+				// Partition the assistant's content blocks into tool calls vs. text/images.
+				const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
+					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
+					toolMessages: Anthropic.ToolUseBlockParam[]
+				}>(
+					(acc, part) => {
+						if (part.type === "tool_use") {
+							acc.toolMessages.push(part)
+						} else if (part.type === "text" || part.type === "image") {
+							acc.nonToolMessages.push(part)
+						} // assistant cannot send tool_result messages
+						return acc
+					},
+					{ nonToolMessages: [], toolMessages: [] },
+				)
+
+				// Assistant content is flattened to a single newline-joined string
+				// (left undefined when there are no text blocks).
+				let content: string | undefined
+				if (nonToolMessages.length > 0) {
+					content = nonToolMessages
+						.map((part) => {
+							if (part.type === "image") {
+								return "" // impossible as the assistant cannot send images
+							}
+							return part.text
+						})
+						.join("\n")
+				}
+
+				mistralMessages.push({
+					role: "assistant",
+					content,
+				})
+			}
+		}
+	}
+
+	return mistralMessages
+}

+ 7 - 0
src/core/webview/ClineProvider.ts

@@ -49,6 +49,7 @@ type SecretKey =
 	| "geminiApiKey"
 	| "openAiNativeApiKey"
 	| "deepSeekApiKey"
+	| "mistralApiKey"
 type GlobalStateKey =
 	| "apiProvider"
 	| "apiModelId"
@@ -1120,6 +1121,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openRouterModelInfo,
 			openRouterUseMiddleOutTransform,
 			vsCodeLmModelSelector,
+			mistralApiKey,
 		} = apiConfiguration
 		await this.updateGlobalState("apiProvider", apiProvider)
 		await this.updateGlobalState("apiModelId", apiModelId)
@@ -1152,6 +1154,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
 		await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
 		await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector)
+		await this.storeSecret("mistralApiKey", mistralApiKey)
 		if (this.cline) {
 			this.cline.api = buildApiHandler(apiConfiguration)
 		} 
@@ -1766,6 +1769,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			geminiApiKey,
 			openAiNativeApiKey,
 			deepSeekApiKey,
+			mistralApiKey,
 			azureApiVersion,
 			openAiStreamingEnabled,
 			openRouterModelId,
@@ -1826,6 +1830,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getSecret("geminiApiKey") as Promise<string | undefined>,
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
+			this.getSecret("mistralApiKey") as Promise<string | undefined>,
 			this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
 			this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
 			this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
@@ -1903,6 +1908,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				geminiApiKey,
 				openAiNativeApiKey,
 				deepSeekApiKey,
+				mistralApiKey,
 				azureApiVersion,
 				openAiStreamingEnabled,
 				openRouterModelId,
@@ -2041,6 +2047,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			"geminiApiKey",
 			"openAiNativeApiKey",
 			"deepSeekApiKey",
+			"mistralApiKey",
 		]
 		for (const key of secretKeys) {
 			await this.storeSecret(key, undefined)

+ 1 - 0
src/shared/__tests__/checkExistApiConfig.test.ts

@@ -49,6 +49,7 @@ describe('checkExistKey', () => {
       geminiApiKey: undefined,
       openAiNativeApiKey: undefined,
       deepSeekApiKey: undefined,
+      mistralApiKey: undefined,
       vsCodeLmModelSelector: undefined
     };
     expect(checkExistKey(config)).toBe(false);

+ 17 - 0
src/shared/api.ts

@@ -13,6 +13,7 @@ export type ApiProvider =
 	| "openai-native"
 	| "deepseek"
 	| "vscode-lm"
+	| "mistral"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -43,6 +44,7 @@ export interface ApiHandlerOptions {
 	lmStudioBaseUrl?: string
 	geminiApiKey?: string
 	openAiNativeApiKey?: string
+	mistralApiKey?: string
 	azureApiVersion?: string
 	openRouterUseMiddleOutTransform?: boolean
 	openAiStreamingEnabled?: boolean
@@ -549,3 +551,18 @@ export const deepSeekModels = {
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
 export const azureOpenAiDefaultApiVersion = "2024-08-01-preview"
 
+
+// Mistral
+// https://docs.mistral.ai/getting-started/models/models_overview/
+// Model-id union derived from the table below; default is Codestral.
+export type MistralModelId = keyof typeof mistralModels
+export const mistralDefaultModelId: MistralModelId = "codestral-latest"
+export const mistralModels = {
+	"codestral-latest": {
+		maxTokens: 32_768,
+		// NOTE(review): 256k context — verify against current Codestral specs.
+		contextWindow: 256_000,
+		supportsImages: false,
+		supportsPromptCache: false,
+		// Prices presumably USD per million tokens, matching the other model tables — confirm.
+		inputPrice: 0.3,
+		outputPrice: 0.9,
+	},
+} as const satisfies Record<string, ModelInfo>

+ 1 - 0
src/shared/checkExistApiConfig.ts

@@ -14,6 +14,7 @@ export function checkExistKey(config: ApiConfiguration | undefined) {
 			config.geminiApiKey,
 			config.openAiNativeApiKey,
 			config.deepSeekApiKey,
+			config.mistralApiKey,
 			config.vsCodeLmModelSelector,
 		].some((key) => key !== undefined)
 		: false;

+ 37 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -22,6 +22,8 @@ import {
 	geminiModels,
 	glamaDefaultModelId,
 	glamaDefaultModelInfo,
+	mistralDefaultModelId,
+	mistralModels,
 	openAiModelInfoSaneDefaults,
 	openAiNativeDefaultModelId,
 	openAiNativeModels,
@@ -145,6 +147,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 						{ value: "bedrock", label: "AWS Bedrock" },
 						{ value: "glama", label: "Glama" },
 						{ value: "vscode-lm", label: "VS Code LM API" },
+						{ value: "mistral", label: "Mistral" },
 						{ value: "lmstudio", label: "LM Studio" },
 						{ value: "ollama", label: "Ollama" }
 					]}
@@ -258,6 +261,37 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 				</div>
 			)}
 
+			{selectedProvider === "mistral" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.mistralApiKey || ""}
+						style={{ width: "100%" }}
+						type="password"
+						onInput={handleInputChange("mistralApiKey")}
+						placeholder="Enter API Key...">
+						<span style={{ fontWeight: 500 }}>Mistral API Key</span>
+					</VSCodeTextField>
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: 3,
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						This key is stored locally and only used to make API requests from this extension.
+						{!apiConfiguration?.mistralApiKey && (
+							<VSCodeLink
+								href="https://console.mistral.ai/codestral/"
+								style={{
+									display: "inline",
+									fontSize: "inherit",
+								}}>
+								You can get a Mistral API key by signing up here.
+							</VSCodeLink>
+						)}
+					</p>
+				</div>
+			)}
+
 			{selectedProvider === "openrouter" && (
 				<div>
 					<VSCodeTextField
@@ -778,6 +812,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 							{selectedProvider === "gemini" && createDropdown(geminiModels)}
 							{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
 							{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
+							{selectedProvider === "mistral" && createDropdown(mistralModels)}
 						</div>
 
 						<ModelInfoView
@@ -978,6 +1013,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 				selectedModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId,
 				selectedModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo,
 			}
+		case "mistral":
+			return getProviderData(mistralModels, mistralDefaultModelId)
 		case "openrouter":
 			return {
 				selectedProvider: provider,

+ 5 - 0
webview-ui/src/utils/validate.ts

@@ -38,6 +38,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 					return "You must provide a valid API key or choose a different provider."
 				}
 				break
+			case "mistral":
+				if (!apiConfiguration.mistralApiKey) {
+					return "You must provide a valid API key or choose a different provider."
+				}
+				break
 			case "openai":
 				if (
 					!apiConfiguration.openAiBaseUrl ||