|
|
@@ -1,6 +1,6 @@
|
|
|
import axios from "axios"
|
|
|
import { vi, describe, it, expect, beforeEach } from "vitest"
|
|
|
-import { LMStudioClient, LLM, LLMInstanceInfo } from "@lmstudio/sdk" // LLMInfo is a type
|
|
|
+import { LMStudioClient, LLM, LLMInstanceInfo, LLMInfo } from "@lmstudio/sdk"
|
|
|
import { getLMStudioModels, parseLMStudioModel } from "../lmstudio"
|
|
|
import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types" // ModelInfo is a type
|
|
|
|
|
|
@@ -11,12 +11,16 @@ const mockedAxios = axios as any
|
|
|
// Mock @lmstudio/sdk
|
|
|
const mockGetModelInfo = vi.fn()
|
|
|
const mockListLoaded = vi.fn()
|
|
|
+const mockListDownloadedModels = vi.fn()
|
|
|
vi.mock("@lmstudio/sdk", () => {
|
|
|
return {
|
|
|
LMStudioClient: vi.fn().mockImplementation(() => ({
|
|
|
llm: {
|
|
|
listLoaded: mockListLoaded,
|
|
|
},
|
|
|
+ system: {
|
|
|
+ listDownloadedModels: mockListDownloadedModels,
|
|
|
+ },
|
|
|
})),
|
|
|
}
|
|
|
})
|
|
|
@@ -28,6 +32,7 @@ describe("LMStudio Fetcher", () => {
|
|
|
MockedLMStudioClientConstructor.mockClear()
|
|
|
mockListLoaded.mockClear()
|
|
|
mockGetModelInfo.mockClear()
|
|
|
+ mockListDownloadedModels.mockClear()
|
|
|
})
|
|
|
|
|
|
describe("parseLMStudioModel", () => {
|
|
|
@@ -88,8 +93,40 @@ describe("LMStudio Fetcher", () => {
|
|
|
trainedForToolUse: false, // Added
|
|
|
}
|
|
|
|
|
|
- it("should fetch and parse models successfully", async () => {
|
|
|
+ it("should fetch downloaded models using system.listDownloadedModels", async () => {
|
|
|
+ const mockLLMInfo: LLMInfo = {
|
|
|
+ type: "llm" as const,
|
|
|
+ modelKey: "mistralai/devstral-small-2505",
|
|
|
+ format: "safetensors",
|
|
|
+ displayName: "Devstral Small 2505",
|
|
|
+ path: "mistralai/devstral-small-2505",
|
|
|
+ sizeBytes: 13277565112,
|
|
|
+ architecture: "mistral",
|
|
|
+ vision: false,
|
|
|
+ trainedForToolUse: false,
|
|
|
+ maxContextLength: 131072,
|
|
|
+ }
|
|
|
+
|
|
|
+ mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
|
|
|
+ mockListDownloadedModels.mockResolvedValueOnce([mockLLMInfo])
|
|
|
+
|
|
|
+ const result = await getLMStudioModels(baseUrl)
|
|
|
+
|
|
|
+ expect(mockedAxios.get).toHaveBeenCalledTimes(1)
|
|
|
+ expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
|
|
|
+ expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
|
|
|
+ expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
|
|
|
+ expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
|
|
|
+ expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
|
|
|
+ expect(mockListLoaded).not.toHaveBeenCalled()
|
|
|
+
|
|
|
+ const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
|
|
|
+ expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should fall back to listLoaded when listDownloadedModels fails", async () => {
|
|
|
mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
|
|
|
+ mockListDownloadedModels.mockRejectedValueOnce(new Error("Method not available"))
|
|
|
mockListLoaded.mockResolvedValueOnce([{ getModelInfo: mockGetModelInfo }])
|
|
|
mockGetModelInfo.mockResolvedValueOnce(mockRawModel)
|
|
|
|
|
|
@@ -99,6 +136,7 @@ describe("LMStudio Fetcher", () => {
|
|
|
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
|
|
|
expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
|
|
|
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
|
|
|
+ expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
|
|
|
expect(mockListLoaded).toHaveBeenCalledTimes(1)
|
|
|
|
|
|
const expectedParsedModel = parseLMStudioModel(mockRawModel)
|