|
@@ -58,23 +58,47 @@ vi.mock("openai", () => {
|
|
|
}
|
|
}
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
|
|
+// Mock LM Studio fetcher
|
|
|
|
|
+vi.mock("../fetchers/lmstudio", () => ({
|
|
|
|
|
+ getLMStudioModels: vi.fn(),
|
|
|
|
|
+}))
|
|
|
|
|
+
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
|
|
+import type { ModelInfo } from "@roo-code/types"
|
|
|
|
|
|
|
|
import { LmStudioHandler } from "../lm-studio"
|
|
import { LmStudioHandler } from "../lm-studio"
|
|
|
import type { ApiHandlerOptions } from "../../../shared/api"
|
|
import type { ApiHandlerOptions } from "../../../shared/api"
|
|
|
|
|
+import { getLMStudioModels } from "../fetchers/lmstudio"
|
|
|
|
|
+
|
|
|
|
|
+// Get the mocked function
|
|
|
|
|
+const mockGetLMStudioModels = vi.mocked(getLMStudioModels)
|
|
|
|
|
|
|
|
describe("LmStudioHandler", () => {
|
|
describe("LmStudioHandler", () => {
|
|
|
let handler: LmStudioHandler
|
|
let handler: LmStudioHandler
|
|
|
let mockOptions: ApiHandlerOptions
|
|
let mockOptions: ApiHandlerOptions
|
|
|
|
|
|
|
|
|
|
+ const mockModelInfo: ModelInfo = {
|
|
|
|
|
+ maxTokens: 8192,
|
|
|
|
|
+ contextWindow: 32768,
|
|
|
|
|
+ supportsImages: false,
|
|
|
|
|
+ supportsComputerUse: false,
|
|
|
|
|
+ supportsPromptCache: true,
|
|
|
|
|
+ inputPrice: 0,
|
|
|
|
|
+ outputPrice: 0,
|
|
|
|
|
+ cacheWritesPrice: 0,
|
|
|
|
|
+ cacheReadsPrice: 0,
|
|
|
|
|
+ description: "Test Model - local-model",
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
beforeEach(() => {
|
|
beforeEach(() => {
|
|
|
mockOptions = {
|
|
mockOptions = {
|
|
|
apiModelId: "local-model",
|
|
apiModelId: "local-model",
|
|
|
lmStudioModelId: "local-model",
|
|
lmStudioModelId: "local-model",
|
|
|
- lmStudioBaseUrl: "http://localhost:1234/v1",
|
|
|
|
|
|
|
+ lmStudioBaseUrl: "http://localhost:1234",
|
|
|
}
|
|
}
|
|
|
handler = new LmStudioHandler(mockOptions)
|
|
handler = new LmStudioHandler(mockOptions)
|
|
|
mockCreate.mockClear()
|
|
mockCreate.mockClear()
|
|
|
|
|
+ mockGetLMStudioModels.mockClear()
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
describe("constructor", () => {
|
|
describe("constructor", () => {
|
|
@@ -156,12 +180,71 @@ describe("LmStudioHandler", () => {
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
describe("getModel", () => {
|
|
describe("getModel", () => {
|
|
|
- it("should return model info", () => {
|
|
|
|
|
|
|
+ it("should return default model info when no models fetched", () => {
|
|
|
const modelInfo = handler.getModel()
|
|
const modelInfo = handler.getModel()
|
|
|
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
|
expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
|
|
expect(modelInfo.info).toBeDefined()
|
|
expect(modelInfo.info).toBeDefined()
|
|
|
expect(modelInfo.info.maxTokens).toBe(-1)
|
|
expect(modelInfo.info.maxTokens).toBe(-1)
|
|
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
|
expect(modelInfo.info.contextWindow).toBe(128_000)
|
|
|
})
|
|
})
|
|
|
|
|
+
|
|
|
|
|
+ it("should return fetched model info when available", async () => {
|
|
|
|
|
+ // Mock the fetched models
|
|
|
|
|
+ mockGetLMStudioModels.mockResolvedValueOnce({
|
|
|
|
|
+ "local-model": mockModelInfo,
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ await handler.fetchModel()
|
|
|
|
|
+ const modelInfo = handler.getModel()
|
|
|
|
|
+
|
|
|
|
|
+ expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
|
|
|
|
+ expect(modelInfo.info).toEqual(mockModelInfo)
|
|
|
|
|
+ expect(modelInfo.info.contextWindow).toBe(32768)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ it("should fallback to default when model not found in fetched models", async () => {
|
|
|
|
|
+ // Mock fetched models without our target model
|
|
|
|
|
+ mockGetLMStudioModels.mockResolvedValueOnce({
|
|
|
|
|
+ "other-model": mockModelInfo,
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ await handler.fetchModel()
|
|
|
|
|
+ const modelInfo = handler.getModel()
|
|
|
|
|
+
|
|
|
|
|
+ expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
|
|
|
|
|
+ expect(modelInfo.info.maxTokens).toBe(-1)
|
|
|
|
|
+ expect(modelInfo.info.contextWindow).toBe(128_000)
|
|
|
|
|
+ })
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ describe("fetchModel", () => {
|
|
|
|
|
+ it("should fetch models successfully", async () => {
|
|
|
|
|
+ mockGetLMStudioModels.mockResolvedValueOnce({
|
|
|
|
|
+ "local-model": mockModelInfo,
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ const result = await handler.fetchModel()
|
|
|
|
|
+
|
|
|
|
|
+ expect(mockGetLMStudioModels).toHaveBeenCalledWith(mockOptions.lmStudioBaseUrl)
|
|
|
|
|
+ expect(result.id).toBe(mockOptions.lmStudioModelId)
|
|
|
|
|
+ expect(result.info).toEqual(mockModelInfo)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ it("should handle fetch errors gracefully", async () => {
|
|
|
|
|
+ const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {})
|
|
|
|
|
+ mockGetLMStudioModels.mockRejectedValueOnce(new Error("Connection failed"))
|
|
|
|
|
+
|
|
|
|
|
+ const result = await handler.fetchModel()
|
|
|
|
|
+
|
|
|
|
|
+ expect(consoleSpy).toHaveBeenCalledWith(
|
|
|
|
|
+ "Failed to fetch LM Studio models, using defaults:",
|
|
|
|
|
+ expect.any(Error),
|
|
|
|
|
+ )
|
|
|
|
|
+ expect(result.id).toBe(mockOptions.lmStudioModelId)
|
|
|
|
|
+ expect(result.info.maxTokens).toBe(-1)
|
|
|
|
|
+ expect(result.info.contextWindow).toBe(128_000)
|
|
|
|
|
+
|
|
|
|
|
+ consoleSpy.mockRestore()
|
|
|
|
|
+ })
|
|
|
})
|
|
})
|
|
|
})
|
|
})
|