Browse Source

fix: resolve LM Studio context length detection (#5075) (#5076)

Daniel 8 months ago
parent
commit
5bf7d006a2
2 changed files with 104 additions and 4 deletions
  1. 85 2
      src/api/providers/__tests__/lmstudio.spec.ts
  2. 19 2
      src/api/providers/lm-studio.ts

+ 85 - 2
src/api/providers/__tests__/lmstudio.spec.ts

@@ -58,23 +58,47 @@ vi.mock("openai", () => {
 	}
 	}
 })
 })
 
 
+// Mock the LM Studio model fetcher so each test controls which models the
+// handler believes are available on the local server
+vi.mock("../fetchers/lmstudio", () => ({
+	getLMStudioModels: vi.fn(),
+}))
+
 import type { Anthropic } from "@anthropic-ai/sdk"
 import type { Anthropic } from "@anthropic-ai/sdk"
+import type { ModelInfo } from "@roo-code/types"
 
 
 import { LmStudioHandler } from "../lm-studio"
 import { LmStudioHandler } from "../lm-studio"
 import type { ApiHandlerOptions } from "../../../shared/api"
 import type { ApiHandlerOptions } from "../../../shared/api"
+import { getLMStudioModels } from "../fetchers/lmstudio"
+
+// Typed handle to the mocked fetcher, used to stub per-test resolved/rejected values
+const mockGetLMStudioModels = vi.mocked(getLMStudioModels)
 
 
 describe("LmStudioHandler", () => {
 describe("LmStudioHandler", () => {
 	let handler: LmStudioHandler
 	let handler: LmStudioHandler
 	let mockOptions: ApiHandlerOptions
 	let mockOptions: ApiHandlerOptions
 
 
+	// Fixture returned by the mocked fetcher. Its contextWindow (32768) is
+	// deliberately different from the 128_000 fallback default so tests can
+	// distinguish fetched metadata from the default path.
+	const mockModelInfo: ModelInfo = {
+		maxTokens: 8192,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsComputerUse: false,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+		cacheWritesPrice: 0,
+		cacheReadsPrice: 0,
+		description: "Test Model - local-model",
+	}
+
 	beforeEach(() => {
 	beforeEach(() => {
 		mockOptions = {
 		mockOptions = {
 			apiModelId: "local-model",
 			apiModelId: "local-model",
 			lmStudioModelId: "local-model",
 			lmStudioModelId: "local-model",
-			lmStudioBaseUrl: "http://localhost:1234/v1",
+			// NOTE(review): base URL now omits the "/v1" suffix — presumably the
+			// fetcher appends its own API path; confirm against fetchers/lmstudio
+			lmStudioBaseUrl: "http://localhost:1234",
 		}
 		}
 		handler = new LmStudioHandler(mockOptions)
 		handler = new LmStudioHandler(mockOptions)
 		mockCreate.mockClear()
 		mockCreate.mockClear()
+		mockGetLMStudioModels.mockClear()
 	})
 	})
 
 
 	describe("constructor", () => {
 	describe("constructor", () => {
@@ -156,12 +180,71 @@ describe("LmStudioHandler", () => {
 	})
 	})
 
 
 	describe("getModel", () => {
 	describe("getModel", () => {
-		it("should return model info", () => {
+		it("should return default model info when no models fetched", () => {
 			const modelInfo = handler.getModel()
 			const modelInfo = handler.getModel()
 			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
 			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
 			expect(modelInfo.info).toBeDefined()
 			expect(modelInfo.info).toBeDefined()
 			expect(modelInfo.info.maxTokens).toBe(-1)
 			expect(modelInfo.info.maxTokens).toBe(-1)
 			expect(modelInfo.info.contextWindow).toBe(128_000)
 			expect(modelInfo.info.contextWindow).toBe(128_000)
 		})
 		})
+
+		it("should return fetched model info when available", async () => {
+			// Stub the fetcher to report metadata for the configured model id
+			mockGetLMStudioModels.mockResolvedValueOnce({
+				"local-model": mockModelInfo,
+			})
+
+			await handler.fetchModel()
+			const modelInfo = handler.getModel()
+
+			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
+			expect(modelInfo.info).toEqual(mockModelInfo)
+			expect(modelInfo.info.contextWindow).toBe(32768)
+		})
+
+		it("should fallback to default when model not found in fetched models", async () => {
+			// Fetch succeeds, but the map does not contain the configured model id
+			mockGetLMStudioModels.mockResolvedValueOnce({
+				"other-model": mockModelInfo,
+			})
+
+			await handler.fetchModel()
+			const modelInfo = handler.getModel()
+
+			expect(modelInfo.id).toBe(mockOptions.lmStudioModelId)
+			expect(modelInfo.info.maxTokens).toBe(-1)
+			expect(modelInfo.info.contextWindow).toBe(128_000)
+		})
+	})
+
+	describe("fetchModel", () => {
+		it("should fetch models successfully", async () => {
+			mockGetLMStudioModels.mockResolvedValueOnce({
+				"local-model": mockModelInfo,
+			})
+
+			const result = await handler.fetchModel()
+
+			// fetchModel must query the configured base URL and surface the
+			// fetched metadata for the active model id
+			expect(mockGetLMStudioModels).toHaveBeenCalledWith(mockOptions.lmStudioBaseUrl)
+			expect(result.id).toBe(mockOptions.lmStudioModelId)
+			expect(result.info).toEqual(mockModelInfo)
+		})
+
+		it("should handle fetch errors gracefully", async () => {
+			// Silence the expected warning while still asserting it was emitted
+			const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {})
+			mockGetLMStudioModels.mockRejectedValueOnce(new Error("Connection failed"))
+
+			const result = await handler.fetchModel()
+
+			expect(consoleSpy).toHaveBeenCalledWith(
+				"Failed to fetch LM Studio models, using defaults:",
+				expect.any(Error),
+			)
+			// A failed fetch must not throw; the handler reports the defaults
+			expect(result.id).toBe(mockOptions.lmStudioModelId)
+			expect(result.info.maxTokens).toBe(-1)
+			expect(result.info.contextWindow).toBe(128_000)
+
+			consoleSpy.mockRestore()
+		})
 	})
 	})
 })
 })

+ 19 - 2
src/api/providers/lm-studio.ts

@@ -13,10 +13,12 @@ import { ApiStream } from "../transform/stream"
 
 
 import { BaseProvider } from "./base-provider"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { getLMStudioModels } from "./fetchers/lmstudio"
 
 
 export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	protected options: ApiHandlerOptions
 	private client: OpenAI
 	private client: OpenAI
+	private models: Record<string, ModelInfo> = {}
 
 
 	constructor(options: ApiHandlerOptions) {
 	constructor(options: ApiHandlerOptions) {
 		super()
 		super()
@@ -130,10 +132,25 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 		}
 		}
 	}
 	}
 
 
+	/**
+	 * Refresh the cached model map from the local LM Studio server, then
+	 * resolve and return the active model via getModel(). A failed fetch is
+	 * logged and clears the cache, so getModel() falls back to
+	 * openAiModelInfoSaneDefaults instead of throwing.
+	 */
+	public async fetchModel() {
+		try {
+			this.models = await getLMStudioModels(this.options.lmStudioBaseUrl)
+		} catch (error) {
+			console.warn("Failed to fetch LM Studio models, using defaults:", error)
+			this.models = {}
+		}
+		return this.getModel()
+	}
+
+	/** Resolve the configured LM Studio model id and its best-known ModelInfo. */
 	override getModel(): { id: string; info: ModelInfo } {
 	override getModel(): { id: string; info: ModelInfo } {
+		const id = this.options.lmStudioModelId || ""
+
+		// Prefer metadata cached by fetchModel(); fall back to the generic
+		// OpenAI-compatible defaults when the id is unknown or nothing was fetched
+		const info = this.models[id] || openAiModelInfoSaneDefaults
+
 		return {
 		return {
-			id: this.options.lmStudioModelId || "",
-			info: openAiModelInfoSaneDefaults,
+			id,
+			info,
 		}
 		}
 	}
 	}