Browse Source

Add LM Studio provider

Saoud Rizwan 1 year ago
parent
commit
39bc35eec1

+ 1 - 0
.github/ISSUE_TEMPLATE/bug_report.yml

@@ -15,6 +15,7 @@ body:
               - AWS Bedrock
               - OpenAI
               - OpenAI Compatible
+              - LM Studio
               - Ollama
       validations:
           required: true

+ 2 - 2
README.md

@@ -44,7 +44,7 @@ Thanks to [Claude 3.5 Sonnet's agentic coding capabilities](https://www-cdn.ant
 
 ### Use any API and Model
 
-Cline supports API providers like OpenRouter, Anthropic, OpenAI, Google Gemini, AWS Bedrock, Azure, and GCP Vertex. You can also configure any OpenAI compatible API, or use a local model through Ollama. If you're using OpenRouter, the extension fetches their latest model list, allowing you to use the newest models as soon as they're available.
+Cline supports API providers like OpenRouter, Anthropic, OpenAI, Google Gemini, AWS Bedrock, Azure, and GCP Vertex. You can also configure any OpenAI compatible API, or use a local model through LM Studio/Ollama. If you're using OpenRouter, the extension fetches their latest model list, allowing you to use the newest models as soon as they're available.
 
 The extension also keeps track of total tokens and API usage cost for the entire task loop and individual requests, keeping you informed of spend every step of the way.
 
@@ -104,7 +104,7 @@ To contribute to the project, start by exploring [open issues](https://github.co
 <details>
 <summary>Local Development Instructions</summary>
 
-1. Clone the repository *(Requires [git-lfs](https://git-lfs.com/))*:
+1. Clone the repository _(Requires [git-lfs](https://git-lfs.com/))_:
     ```bash
     git clone https://github.com/cline/cline.git
     ```

+ 3 - 0
src/api/index.ts

@@ -6,6 +6,7 @@ import { OpenRouterHandler } from "./providers/openrouter"
 import { VertexHandler } from "./providers/vertex"
 import { OpenAiHandler } from "./providers/openai"
 import { OllamaHandler } from "./providers/ollama"
+import { LmStudioHandler } from "./providers/lmstudio"
 import { GeminiHandler } from "./providers/gemini"
 import { OpenAiNativeHandler } from "./providers/openai-native"
 import { ApiStream } from "./transform/stream"
@@ -30,6 +31,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new OpenAiHandler(options)
 		case "ollama":
 			return new OllamaHandler(options)
+		case "lmstudio":
+			return new LmStudioHandler(options)
 		case "gemini":
 			return new GeminiHandler(options)
 		case "openai-native":

+ 56 - 0
src/api/providers/lmstudio.ts

@@ -0,0 +1,56 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+import { ApiHandler } from "../"
+import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
+
+export class LmStudioHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: OpenAI
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		this.client = new OpenAI({
+			baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
+			apiKey: "noop",
+		})
+	}
+
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			...convertToOpenAiMessages(messages),
+		]
+
+		try {
+			const stream = await this.client.chat.completions.create({
+				model: this.getModel().id,
+				messages: openAiMessages,
+				temperature: 0,
+				stream: true,
+			})
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+				if (delta?.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+			}
+		} catch (error) {
+			// LM Studio doesn't return an error code/body for now
+			throw new Error(
+				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Cline's prompts."
+			)
+		}
+	}
+
+	getModel(): { id: string; info: ModelInfo } {
+		return {
+			id: this.options.lmStudioModelId || "",
+			info: openAiModelInfoSaneDefaults,
+		}
+	}
+}

+ 35 - 0
src/core/webview/ClineProvider.ts

@@ -51,6 +51,8 @@ type GlobalStateKey =
 	| "openAiModelId"
 	| "ollamaModelId"
 	| "ollamaBaseUrl"
+	| "lmStudioModelId"
+	| "lmStudioBaseUrl"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
 	| "openRouterModelId"
@@ -359,6 +361,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 								openAiModelId,
 								ollamaModelId,
 								ollamaBaseUrl,
+								lmStudioModelId,
+								lmStudioBaseUrl,
 								anthropicBaseUrl,
 								geminiApiKey,
 								openAiNativeApiKey,
@@ -382,6 +386,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							await this.updateGlobalState("openAiModelId", openAiModelId)
 							await this.updateGlobalState("ollamaModelId", ollamaModelId)
 							await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
+							await this.updateGlobalState("lmStudioModelId", lmStudioModelId)
+							await this.updateGlobalState("lmStudioBaseUrl", lmStudioBaseUrl)
 							await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
 							await this.storeSecret("geminiApiKey", geminiApiKey)
 							await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
@@ -442,6 +448,10 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						const ollamaModels = await this.getOllamaModels(message.text)
 						this.postMessageToWebview({ type: "ollamaModels", ollamaModels })
 						break
+					case "requestLmStudioModels":
+						const lmStudioModels = await this.getLmStudioModels(message.text)
+						this.postMessageToWebview({ type: "lmStudioModels", lmStudioModels })
+						break
 					case "refreshOpenRouterModels":
 						await this.refreshOpenRouterModels()
 						break
@@ -509,6 +519,25 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		}
 	}
 
+	// LM Studio
+
+	async getLmStudioModels(baseUrl?: string) {
+		try {
+			if (!baseUrl) {
+				baseUrl = "http://localhost:1234"
+			}
+			if (!URL.canParse(baseUrl)) {
+				return []
+			}
+			const response = await axios.get(`${baseUrl}/v1/models`)
+			const modelsArray = response.data?.data?.map((model: any) => model.id) || []
+			const models = [...new Set<string>(modelsArray)]
+			return models
+		} catch (error) {
+			return []
+		}
+	}
+
 	// OpenRouter
 
 	async handleOpenRouterCallback(code: string) {
@@ -835,6 +864,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			openAiModelId,
 			ollamaModelId,
 			ollamaBaseUrl,
+			lmStudioModelId,
+			lmStudioBaseUrl,
 			anthropicBaseUrl,
 			geminiApiKey,
 			openAiNativeApiKey,
@@ -862,6 +893,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("openAiModelId") as Promise<string | undefined>,
 			this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
 			this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
+			this.getGlobalState("lmStudioModelId") as Promise<string | undefined>,
+			this.getGlobalState("lmStudioBaseUrl") as Promise<string | undefined>,
 			this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
 			this.getSecret("geminiApiKey") as Promise<string | undefined>,
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
@@ -906,6 +939,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				openAiModelId,
 				ollamaModelId,
 				ollamaBaseUrl,
+				lmStudioModelId,
+				lmStudioBaseUrl,
 				anthropicBaseUrl,
 				geminiApiKey,
 				openAiNativeApiKey,

+ 2 - 0
src/shared/ExtensionMessage.ts

@@ -10,6 +10,7 @@ export interface ExtensionMessage {
 		| "state"
 		| "selectedImages"
 		| "ollamaModels"
+		| "lmStudioModels"
 		| "theme"
 		| "workspaceUpdated"
 		| "invoke"
@@ -21,6 +22,7 @@ export interface ExtensionMessage {
 	state?: ExtensionState
 	images?: string[]
 	ollamaModels?: string[]
+	lmStudioModels?: string[]
 	filePaths?: string[]
 	partialMessage?: ClineMessage
 	openRouterModels?: Record<string, ModelInfo>

+ 1 - 0
src/shared/WebviewMessage.ts

@@ -17,6 +17,7 @@ export interface WebviewMessage {
 		| "exportTaskWithId"
 		| "resetState"
 		| "requestOllamaModels"
+		| "requestLmStudioModels"
 		| "openImage"
 		| "openFile"
 		| "openMention"

+ 3 - 0
src/shared/api.ts

@@ -5,6 +5,7 @@ export type ApiProvider =
 	| "vertex"
 	| "openai"
 	| "ollama"
+	| "lmstudio"
 	| "gemini"
 	| "openai-native"
 
@@ -27,6 +28,8 @@ export interface ApiHandlerOptions {
 	openAiModelId?: string
 	ollamaModelId?: string
 	ollamaBaseUrl?: string
+	lmStudioModelId?: string
+	lmStudioBaseUrl?: string
 	geminiApiKey?: string
 	openAiNativeApiKey?: string
 	azureApiVersion?: string

+ 1 - 0
webview-ui/src/components/chat/TaskHeader.tsx

@@ -96,6 +96,7 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 		return (
 			apiConfiguration?.apiProvider !== "openai" &&
 			apiConfiguration?.apiProvider !== "ollama" &&
+			apiConfiguration?.apiProvider !== "lmstudio" &&
 			apiConfiguration?.apiProvider !== "gemini"
 		)
 	}, [apiConfiguration?.apiProvider])

+ 89 - 7
webview-ui/src/components/settings/ApiOptions.tsx

@@ -45,6 +45,7 @@ interface ApiOptionsProps {
 const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) => {
 	const { apiConfiguration, setApiConfiguration, uriScheme } = useExtensionState()
 	const [ollamaModels, setOllamaModels] = useState<string[]>([])
+	const [lmStudioModels, setLmStudioModels] = useState<string[]>([])
 	const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl)
 	const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
 	const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
@@ -57,23 +58,27 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 		return normalizeApiConfiguration(apiConfiguration)
 	}, [apiConfiguration])
 
-	// Poll ollama models
-	const requestOllamaModels = useCallback(() => {
+	// Poll ollama/lmstudio models
+	const requestLocalModels = useCallback(() => {
 		if (selectedProvider === "ollama") {
 			vscode.postMessage({ type: "requestOllamaModels", text: apiConfiguration?.ollamaBaseUrl })
+		} else if (selectedProvider === "lmstudio") {
+			vscode.postMessage({ type: "requestLmStudioModels", text: apiConfiguration?.lmStudioBaseUrl })
 		}
-	}, [selectedProvider, apiConfiguration?.ollamaBaseUrl])
+	}, [selectedProvider, apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl])
 	useEffect(() => {
-		if (selectedProvider === "ollama") {
-			requestOllamaModels()
+		if (selectedProvider === "ollama" || selectedProvider === "lmstudio") {
+			requestLocalModels()
 		}
-	}, [selectedProvider, requestOllamaModels])
-	useInterval(requestOllamaModels, selectedProvider === "ollama" ? 2000 : null)
+	}, [selectedProvider, requestLocalModels])
+	useInterval(requestLocalModels, selectedProvider === "ollama" || selectedProvider === "lmstudio" ? 2000 : null)
 
 	const handleMessage = useCallback((event: MessageEvent) => {
 		const message: ExtensionMessage = event.data
 		if (message.type === "ollamaModels" && message.ollamaModels) {
 			setOllamaModels(message.ollamaModels)
+		} else if (message.type === "lmStudioModels" && message.lmStudioModels) {
+			setLmStudioModels(message.lmStudioModels)
 		}
 	}, [])
 	useEvent("message", handleMessage)
@@ -128,6 +133,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
 					<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
+					<VSCodeOption value="lmstudio">LM Studio</VSCodeOption>
 					<VSCodeOption value="ollama">Ollama</VSCodeOption>
 				</VSCodeDropdown>
 			</div>
@@ -463,6 +469,75 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 				</div>
 			)}
 
+			{selectedProvider === "lmstudio" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.lmStudioBaseUrl || ""}
+						style={{ width: "100%" }}
+						type="url"
+						onInput={handleInputChange("lmStudioBaseUrl")}
+						placeholder={"Default: http://localhost:1234"}>
+						<span style={{ fontWeight: 500 }}>Base URL (optional)</span>
+					</VSCodeTextField>
+					<VSCodeTextField
+						value={apiConfiguration?.lmStudioModelId || ""}
+						style={{ width: "100%" }}
+						onInput={handleInputChange("lmStudioModelId")}
+						placeholder={"e.g. meta-llama-3.1-8b-instruct"}>
+						<span style={{ fontWeight: 500 }}>Model ID</span>
+					</VSCodeTextField>
+					{lmStudioModels.length > 0 && (
+						<VSCodeRadioGroup
+							value={
+								lmStudioModels.includes(apiConfiguration?.lmStudioModelId || "")
+									? apiConfiguration?.lmStudioModelId
+									: ""
+							}
+							onChange={(e) => {
+								const value = (e.target as HTMLInputElement)?.value
+								// need to check value first since radio group returns empty string sometimes
+								if (value) {
+									handleInputChange("lmStudioModelId")({
+										target: { value },
+									})
+								}
+							}}>
+							{lmStudioModels.map((model) => (
+								<VSCodeRadio
+									key={model}
+									value={model}
+									checked={apiConfiguration?.lmStudioModelId === model}>
+									{model}
+								</VSCodeRadio>
+							))}
+						</VSCodeRadioGroup>
+					)}
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: "5px",
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						LM Studio allows you to run models locally on your computer. For instructions on how to get
+						started, see their
+						<VSCodeLink href="https://lmstudio.ai/docs" style={{ display: "inline", fontSize: "inherit" }}>
+							quickstart guide.
+						</VSCodeLink>
+						You will also need to start LM Studio's{" "}
+						<VSCodeLink
+							href="https://lmstudio.ai/docs/basics/server"
+							style={{ display: "inline", fontSize: "inherit" }}>
+							local server
+						</VSCodeLink>{" "}
+						feature to use it with this extension.{" "}
+						<span style={{ color: "var(--vscode-errorForeground)" }}>
+							(<span style={{ fontWeight: 500 }}>Note:</span> Cline uses complex prompts and works best
+							with Claude models. Less capable models may not work as expected.)
+						</span>
+					</p>
+				</div>
+			)}
+
 			{selectedProvider === "ollama" && (
 				<div>
 					<VSCodeTextField
@@ -543,6 +618,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
 			{selectedProvider !== "openrouter" &&
 				selectedProvider !== "openai" &&
 				selectedProvider !== "ollama" &&
+				selectedProvider !== "lmstudio" &&
 				showModelOptions && (
 					<>
 						<div className="dropdown-container">
@@ -758,6 +834,12 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 				selectedModelId: apiConfiguration?.ollamaModelId || "",
 				selectedModelInfo: openAiModelInfoSaneDefaults,
 			}
+		case "lmstudio":
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.lmStudioModelId || "",
+				selectedModelInfo: openAiModelInfoSaneDefaults,
+			}
 		default:
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 	}

+ 1 - 0
webview-ui/src/context/ExtensionStateContext.tsx

@@ -54,6 +54,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 							config.vertexProjectId,
 							config.openAiApiKey,
 							config.ollamaModelId,
+							config.lmStudioModelId,
 							config.geminiApiKey,
 							config.openAiNativeApiKey,
 					  ].some((key) => key !== undefined)

+ 5 - 0
webview-ui/src/utils/validate.ts

@@ -47,6 +47,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 					return "You must provide a valid model ID."
 				}
 				break
+			case "lmstudio":
+				if (!apiConfiguration.lmStudioModelId) {
+					return "You must provide a valid model ID."
+				}
+				break
 		}
 	}
 	return undefined