Просмотр исходного кода

first pass on speculative decoding with LMStudio

Adam Larson 10 месяцев назад
Родитель
Сommit
1c8f9ed683

+ 24 - 5
src/api/providers/lmstudio.ts

@@ -30,13 +30,24 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 		]
 
 		try {
-			const stream = await this.client.chat.completions.create({
+			// Create params object with optional draft model
+			const params: any = {
 				model: this.getModel().id,
 				messages: openAiMessages,
 				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
 				stream: true,
-			})
-			for await (const chunk of stream) {
+			}
+
+			// Add draft model if speculative decoding is enabled and a draft model is specified
+			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
+				params.draft_model = this.options.lmStudioDraftModelId
+			}
+
+			const results = await this.client.chat.completions.create(params)
+
+			// Stream handling
+			// @ts-ignore
+			for await (const chunk of results) {
 				const delta = chunk.choices[0]?.delta
 				if (delta?.content) {
 					yield {
@@ -62,12 +73,20 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const response = await this.client.chat.completions.create({
+			// Create params object with optional draft model
+			const params: any = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 				temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
 				stream: false,
-			})
+			}
+
+			// Add draft model if speculative decoding is enabled and a draft model is specified
+			if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
+				params.draft_model = this.options.lmStudioDraftModelId
+			}
+
+			const response = await this.client.chat.completions.create(params)
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
 			throw new Error(

+ 10 - 0
src/core/webview/ClineProvider.ts

@@ -1676,6 +1676,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			modelTemperature,
 			modelMaxTokens,
 			modelMaxThinkingTokens,
+			lmStudioDraftModelId,
+			lmStudioSpeculativeDecodingEnabled,
 		} = apiConfiguration
 		await Promise.all([
 			this.updateGlobalState("apiProvider", apiProvider),
@@ -1725,6 +1727,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.updateGlobalState("modelTemperature", modelTemperature),
 			this.updateGlobalState("modelMaxTokens", modelMaxTokens),
 			this.updateGlobalState("anthropicThinking", modelMaxThinkingTokens),
+			this.updateGlobalState("lmStudioDraftModelId", lmStudioDraftModelId),
+			this.updateGlobalState("lmStudioSpeculativeDecodingEnabled", lmStudioSpeculativeDecodingEnabled),
 		])
 		if (this.cline) {
 			this.cline.api = buildApiHandler(apiConfiguration)
@@ -2221,6 +2225,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			modelMaxThinkingTokens,
 			maxOpenTabsContext,
 			browserToolEnabled,
+			lmStudioSpeculativeDecodingEnabled,
+			lmStudioDraftModelId,
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
 			this.getGlobalState("apiModelId") as Promise<string | undefined>,
@@ -2306,6 +2312,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("anthropicThinking") as Promise<number | undefined>,
 			this.getGlobalState("maxOpenTabsContext") as Promise<number | undefined>,
 			this.getGlobalState("browserToolEnabled") as Promise<boolean | undefined>,
+			this.getGlobalState("lmStudioSpeculativeDecodingEnabled") as Promise<boolean | undefined>,
+			this.getGlobalState("lmStudioDraftModelId") as Promise<string | undefined>,
 		])
 
 		let apiProvider: ApiProvider
@@ -2371,6 +2379,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				modelTemperature,
 				modelMaxTokens,
 				modelMaxThinkingTokens,
+				lmStudioSpeculativeDecodingEnabled,
+				lmStudioDraftModelId,
 			},
 			lastShownAnnouncementId,
 			customInstructions,

+ 2 - 0
src/shared/api.ts

@@ -49,6 +49,8 @@ export interface ApiHandlerOptions {
 	ollamaBaseUrl?: string
 	lmStudioModelId?: string
 	lmStudioBaseUrl?: string
+	lmStudioDraftModelId?: string
+	lmStudioSpeculativeDecodingEnabled?: boolean
 	geminiApiKey?: string
 	openAiNativeApiKey?: string
 	mistralApiKey?: string

+ 2 - 0
src/shared/globalState.ts

@@ -41,6 +41,8 @@ export type GlobalStateKey =
 	| "ollamaBaseUrl"
 	| "lmStudioModelId"
 	| "lmStudioBaseUrl"
+	| "lmStudioDraftModelId"
+	| "lmStudioSpeculativeDecodingEnabled"
 	| "anthropicBaseUrl"
 	| "azureApiVersion"
 	| "openAiStreamingEnabled"

+ 73 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -1107,6 +1107,79 @@ const ApiOptions = ({
 							))}
 						</VSCodeRadioGroup>
 					)}
+					<div style={{ display: "flex", alignItems: "center", marginTop: "16px", marginBottom: "8px" }}>
+						<Checkbox
+							checked={apiConfiguration?.lmStudioSpeculativeDecodingEnabled === true}
+							onChange={(checked) => {
+								// Explicitly set the boolean value using direct method
+								setApiConfigurationField("lmStudioSpeculativeDecodingEnabled", checked)
+							}}>
+							Enable Speculative Decoding
+						</Checkbox>
+					</div>
+					{apiConfiguration?.lmStudioSpeculativeDecodingEnabled && (
+						<>
+							<VSCodeTextField
+								value={apiConfiguration?.lmStudioDraftModelId || ""}
+								style={{ width: "100%" }}
+								onInput={handleInputChange("lmStudioDraftModelId")}
+								placeholder={"e.g. lmstudio-community/llama-3.2-1b-instruct"}>
+								<span className="font-medium">Draft Model ID</span>
+							</VSCodeTextField>
+							<div
+								style={{
+									fontSize: "11px",
+									color: "var(--vscode-descriptionForeground)",
+									marginTop: 4,
+									display: "flex",
+									alignItems: "center",
+									gap: 4,
+								}}>
+								<i className="codicon codicon-info" style={{ fontSize: "12px" }}></i>
+								<span>
+									Draft model must be from the same model family for speculative decoding to work
+									correctly.
+								</span>
+							</div>
+							{lmStudioModels.length > 0 && (
+								<>
+									<div style={{ marginTop: "8px" }}>
+										<span className="font-medium">Select Draft Model</span>
+									</div>
+									<VSCodeRadioGroup
+										value={
+											lmStudioModels.includes(apiConfiguration?.lmStudioDraftModelId || "")
+												? apiConfiguration?.lmStudioDraftModelId
+												: ""
+										}
+										onChange={handleInputChange("lmStudioDraftModelId")}>
+										{lmStudioModels.map((model) => (
+											<VSCodeRadio key={`draft-${model}`} value={model}>
+												{model}
+											</VSCodeRadio>
+										))}
+									</VSCodeRadioGroup>
+									{lmStudioModels.length === 0 && (
+										<div
+											style={{
+												fontSize: "12px",
+												marginTop: "8px",
+												padding: "6px",
+												backgroundColor: "var(--vscode-inputValidation-infoBackground)",
+												border: "1px solid var(--vscode-inputValidation-infoBorder)",
+												borderRadius: "3px",
+												color: "var(--vscode-inputValidation-infoForeground)",
+											}}>
+											<i className="codicon codicon-info" style={{ marginRight: "5px" }}></i>
+											No draft models found. Please ensure LM Studio is running with Server Mode
+											enabled.
+										</div>
+									)}
+								</>
+							)}
+						</>
+					)}
+
 					<p
 						style={{
 							fontSize: "12px",