Просмотр исходного кода

fix: improve LM Studio model detection to show all downloaded models (#5047)

Daniel 6 месяцев назад
Родитель
Сommit
041c28d8e5

+ 40 - 2
src/api/providers/fetchers/__tests__/lmstudio.test.ts

@@ -1,6 +1,6 @@
 import axios from "axios"
 import { vi, describe, it, expect, beforeEach } from "vitest"
-import { LMStudioClient, LLM, LLMInstanceInfo } from "@lmstudio/sdk" // LLMInfo is a type
+import { LMStudioClient, LLM, LLMInstanceInfo, LLMInfo } from "@lmstudio/sdk"
 import { getLMStudioModels, parseLMStudioModel } from "../lmstudio"
 import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types" // ModelInfo is a type
 
@@ -11,12 +11,16 @@ const mockedAxios = axios as any
 // Mock @lmstudio/sdk
 const mockGetModelInfo = vi.fn()
 const mockListLoaded = vi.fn()
+const mockListDownloadedModels = vi.fn()
 vi.mock("@lmstudio/sdk", () => {
 	return {
 		LMStudioClient: vi.fn().mockImplementation(() => ({
 			llm: {
 				listLoaded: mockListLoaded,
 			},
+			system: {
+				listDownloadedModels: mockListDownloadedModels,
+			},
 		})),
 	}
 })
@@ -28,6 +32,7 @@ describe("LMStudio Fetcher", () => {
 		MockedLMStudioClientConstructor.mockClear()
 		mockListLoaded.mockClear()
 		mockGetModelInfo.mockClear()
+		mockListDownloadedModels.mockClear()
 	})
 
 	describe("parseLMStudioModel", () => {
@@ -88,8 +93,40 @@ describe("LMStudio Fetcher", () => {
 			trainedForToolUse: false, // Added
 		}
 
-		it("should fetch and parse models successfully", async () => {
+		it("should fetch downloaded models using system.listDownloadedModels", async () => {
+			const mockLLMInfo: LLMInfo = {
+				type: "llm" as const,
+				modelKey: "mistralai/devstral-small-2505",
+				format: "safetensors",
+				displayName: "Devstral Small 2505",
+				path: "mistralai/devstral-small-2505",
+				sizeBytes: 13277565112,
+				architecture: "mistral",
+				vision: false,
+				trainedForToolUse: false,
+				maxContextLength: 131072,
+			}
+
+			mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
+			mockListDownloadedModels.mockResolvedValueOnce([mockLLMInfo])
+
+			const result = await getLMStudioModels(baseUrl)
+
+			expect(mockedAxios.get).toHaveBeenCalledTimes(1)
+			expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
+			expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
+			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
+			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
+			expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
+			expect(mockListLoaded).not.toHaveBeenCalled()
+
+			const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
+			expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
+		})
+
+		it("should fall back to listLoaded when listDownloadedModels fails", async () => {
 			mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
+			mockListDownloadedModels.mockRejectedValueOnce(new Error("Method not available"))
 			mockListLoaded.mockResolvedValueOnce([{ getModelInfo: mockGetModelInfo }])
 			mockGetModelInfo.mockResolvedValueOnce(mockRawModel)
 
@@ -99,6 +136,7 @@ describe("LMStudio Fetcher", () => {
 			expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
+			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
 			expect(mockListLoaded).toHaveBeenCalledTimes(1)
 
 			const expectedParsedModel = parseLMStudioModel(mockRawModel)

+ 24 - 8
src/api/providers/fetchers/lmstudio.ts

@@ -2,14 +2,17 @@ import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
 import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
 import axios from "axios"
 
-export const parseLMStudioModel = (rawModel: LLMInstanceInfo): ModelInfo => {
+export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
+	// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
+	const contextLength = "contextLength" in rawModel ? rawModel.contextLength : rawModel.maxContextLength
+
 	const modelInfo: ModelInfo = Object.assign({}, lMStudioDefaultModelInfo, {
 		description: `${rawModel.displayName} - ${rawModel.path}`,
-		contextWindow: rawModel.contextLength,
+		contextWindow: contextLength,
 		supportsPromptCache: true,
 		supportsImages: rawModel.vision,
 		supportsComputerUse: false,
-		maxTokens: rawModel.contextLength,
+		maxTokens: contextLength,
 	})
 
 	return modelInfo
@@ -33,12 +36,25 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
 		await axios.get(`${baseUrl}/v1/models`)
 
 		const client = new LMStudioClient({ baseUrl: lmsUrl })
-		const response = (await client.llm.listLoaded().then((models: LLM[]) => {
-			return Promise.all(models.map((m) => m.getModelInfo()))
-		})) as Array<LLMInstanceInfo>
 
-		for (const lmstudioModel of response) {
-			models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+		// First, try to get all downloaded models
+		try {
+			const downloadedModels = await client.system.listDownloadedModels("llm")
+			for (const model of downloadedModels) {
+				// Use the model path as the key since that's what users select
+				models[model.path] = parseLMStudioModel(model)
+			}
+		} catch (error) {
+			console.warn("Failed to list downloaded models, falling back to loaded models only")
+
+			// Fall back to listing only loaded models
+			const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
+				return Promise.all(models.map((m) => m.getModelInfo()))
+			})) as Array<LLMInstanceInfo>
+
+			for (const lmstudioModel of loadedModels) {
+				models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+			}
 		}
 	} catch (error) {
 		if (error.code === "ECONNREFUSED") {

+ 6 - 0
src/core/webview/webviewMessageHandler.ts

@@ -448,6 +448,9 @@ export const webviewMessageHandler = async (
 			// Specific handler for Ollama models only
 			const { apiConfiguration: ollamaApiConfig } = await provider.getState()
 			try {
+				// Flush cache first to ensure fresh models
+				await flushModels("ollama")
+
 				const ollamaModels = await getModels({
 					provider: "ollama",
 					baseUrl: ollamaApiConfig.ollamaBaseUrl,
@@ -469,6 +472,9 @@ export const webviewMessageHandler = async (
 			// Specific handler for LM Studio models only
 			const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
 			try {
+				// Flush cache first to ensure fresh models
+				await flushModels("lmstudio")
+
 				const lmStudioModels = await getModels({
 					provider: "lmstudio",
 					baseUrl: lmStudioApiConfig.lmStudioBaseUrl,

+ 8 - 1
webview-ui/src/components/settings/providers/LMStudio.tsx

@@ -1,4 +1,4 @@
-import { useCallback, useState, useMemo } from "react"
+import { useCallback, useState, useMemo, useEffect } from "react"
 import { useEvent } from "react-use"
 import { Trans } from "react-i18next"
 import { Checkbox } from "vscrui"
@@ -9,6 +9,7 @@ import type { ProviderSettings } from "@roo-code/types"
 import { useAppTranslation } from "@src/i18n/TranslationContext"
 import { ExtensionMessage } from "@roo/ExtensionMessage"
 import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
+import { vscode } from "@src/utils/vscode"
 
 import { inputEventTransform } from "../transforms"
 
@@ -49,6 +50,12 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 
 	useEvent("message", onMessage)
 
+	// Refresh models on mount
+	useEffect(() => {
+		// Request fresh models - the handler now flushes cache automatically
+		vscode.postMessage({ type: "requestLmStudioModels" })
+	}, [])
+
 	// Check if the selected model exists in the fetched models
 	const modelNotAvailable = useMemo(() => {
 		const selectedModel = apiConfiguration?.lmStudioModelId

+ 8 - 1
webview-ui/src/components/settings/providers/Ollama.tsx

@@ -1,4 +1,4 @@
-import { useState, useCallback, useMemo } from "react"
+import { useState, useCallback, useMemo, useEffect } from "react"
 import { useEvent } from "react-use"
 import { VSCodeTextField, VSCodeRadioGroup, VSCodeRadio } from "@vscode/webview-ui-toolkit/react"
 
@@ -8,6 +8,7 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
 
 import { useAppTranslation } from "@src/i18n/TranslationContext"
 import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
+import { vscode } from "@src/utils/vscode"
 
 import { inputEventTransform } from "../transforms"
 
@@ -48,6 +49,12 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
 
 	useEvent("message", onMessage)
 
+	// Refresh models on mount
+	useEffect(() => {
+		// Request fresh models - the handler now flushes cache automatically
+		vscode.postMessage({ type: "requestOllamaModels" })
+	}, [])
+
 	// Check if the selected model exists in the fetched models
 	const modelNotAvailable = useMemo(() => {
 		const selectedModel = apiConfiguration?.ollamaModelId