
feat: optimize router model fetching with single-provider filtering (#8956)

Daniel 2 months ago
Parent
Commit
31de103e7d

+ 167 - 0
src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts

@@ -0,0 +1,167 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import { webviewMessageHandler } from "../webviewMessageHandler"
+import type { ClineProvider } from "../ClineProvider"
+
+// Mock vscode (minimal)
+vi.mock("vscode", () => ({
+	window: {
+		showErrorMessage: vi.fn(),
+		showWarningMessage: vi.fn(),
+		showInformationMessage: vi.fn(),
+	},
+	workspace: {
+		workspaceFolders: undefined,
+		getConfiguration: vi.fn(() => ({
+			get: vi.fn(),
+			update: vi.fn(),
+		})),
+	},
+	env: {
+		clipboard: { writeText: vi.fn() },
+		openExternal: vi.fn(),
+	},
+	commands: {
+		executeCommand: vi.fn(),
+	},
+	Uri: {
+		parse: vi.fn((s: string) => ({ toString: () => s })),
+		file: vi.fn((p: string) => ({ fsPath: p })),
+	},
+	ConfigurationTarget: {
+		Global: 1,
+		Workspace: 2,
+		WorkspaceFolder: 3,
+	},
+}))
+
+// Mock modelCache getModels/flushModels used by the handler
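+// (getModels forwards through getModelsMock so each test can swap the implementation
+// in beforeEach; the indirection defers the reference until call time, sidestepping
+// vi.mock's hoisting order.)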
+const getModelsMock = vi.fn()
+vi.mock("../../../api/providers/fetchers/modelCache", () => ({
+	getModels: (...args: any[]) => getModelsMock(...args),
+	flushModels: vi.fn(),
+}))
+
+describe("webviewMessageHandler - requestRouterModels provider filter", () => {
+	let mockProvider: ClineProvider & {
+		postMessageToWebview: ReturnType<typeof vi.fn>
+		getState: ReturnType<typeof vi.fn>
+		contextProxy: any
+		log: ReturnType<typeof vi.fn>
+	}
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+
+		mockProvider = {
+			// Only methods used by this code path
+			postMessageToWebview: vi.fn(),
+			getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
+			contextProxy: {
+				getValue: vi.fn(),
+				setValue: vi.fn(),
+				globalStorageUri: { fsPath: "/mock/storage" },
+			},
+			log: vi.fn(),
+		} as any
+
+		// Default mock: return distinct model maps per provider so we can verify keys
+		getModelsMock.mockImplementation(async (options: any) => {
+			switch (options?.provider) {
+				case "roo":
+					return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
+				case "openrouter":
+					return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
+				case "requesty":
+					return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "deepinfra":
+					return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "glama":
+					return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "unbound":
+					return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "vercel-ai-gateway":
+					return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "io-intelligence":
+					return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "litellm":
+					return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
+				default:
+					return {}
+			}
+		})
+	})
+
+	it("fetches only requested provider when values.provider is present ('roo')", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+				values: { provider: "roo" },
+			} as any,
+		)
+
+		// Should post a single routerModels message
+		expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
+			expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const payload = call[0]
+		const routerModels = payload.routerModels as Record<string, Record<string, any>>
+
+		// Only "roo" key should be present
+		const keys = Object.keys(routerModels)
+		expect(keys).toEqual(["roo"])
+		expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")
+
+		// getModels should have been called exactly once for roo
+		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
+		expect(providersCalled).toEqual(["roo"])
+	})
+
+	it("defaults to aggregate fetching when no provider filter is sent", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+			} as any,
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const routerModels = call[0].routerModels as Record<string, Record<string, any>>
+
+		// Aggregate handler initializes many known routers - ensure a few expected keys exist
+		expect(routerModels).toHaveProperty("openrouter")
+		expect(routerModels).toHaveProperty("roo")
+		expect(routerModels).toHaveProperty("requesty")
+	})
+
+	it("supports filtering another single provider ('openrouter')", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+				values: { provider: "openrouter" },
+			} as any,
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const routerModels = call[0].routerModels as Record<string, Record<string, any>>
+		const keys = Object.keys(routerModels)
+
+		expect(keys).toEqual(["openrouter"])
+		expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")
+
+		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
+		expect(providersCalled).toEqual(["openrouter"])
+	})
+})
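
The suite above pins down key filtering but not the correlation tag itself. A minimal additional case, assuming the same `describe` block and mocks, could assert that a filtered response echoes the requested provider back in `values`, which is the contract the webview-side handler in `useRouterModels.ts` (below) relies on:

```ts
// Sketch only: assumes the mocks and describe block above.
it("echoes the requested provider back in values for response correlation", async () => {
	await webviewMessageHandler(
		mockProvider as any,
		{ type: "requestRouterModels", values: { provider: "roo" } } as any,
	)

	const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
		(c: any[]) => c[0]?.type === "routerModels",
	)
	expect(call).toBeTruthy()
	// The handler includes values: { provider } only for filtered requests.
	expect(call[0].values).toEqual({ provider: "roo" })
})
```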

+ 39 - 38
src/core/webview/webviewMessageHandler.ts

@@ -757,20 +757,26 @@ export const webviewMessageHandler = async (
 		case "requestRouterModels":
 			const { apiConfiguration } = await provider.getState()
 
-			const routerModels: Record<RouterName, ModelRecord> = {
-				openrouter: {},
-				"vercel-ai-gateway": {},
-				huggingface: {},
-				litellm: {},
-				deepinfra: {},
-				"io-intelligence": {},
-				requesty: {},
-				unbound: {},
-				glama: {},
-				ollama: {},
-				lmstudio: {},
-				roo: {},
-			}
+			// Optional single-provider filter from the webview
+			const requestedProvider = message?.values?.provider
+			const providerFilter = requestedProvider ? toRouterName(requestedProvider) : undefined
+
+			const routerModels: Record<RouterName, ModelRecord> = providerFilter
+				? ({} as Record<RouterName, ModelRecord>)
+				: {
+						openrouter: {},
+						"vercel-ai-gateway": {},
+						huggingface: {},
+						litellm: {},
+						deepinfra: {},
+						"io-intelligence": {},
+						requesty: {},
+						unbound: {},
+						glama: {},
+						ollama: {},
+						lmstudio: {},
+						roo: {},
+					}
 
 			const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 				try {
@@ -785,7 +791,8 @@ export const webviewMessageHandler = async (
 				}
 			}
 
-			const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
+			// Base candidates (only those handled by this aggregate fetcher)
+			const candidates: { key: RouterName; options: GetModelsOptions }[] = [
 				{ key: "openrouter", options: { provider: "openrouter" } },
 				{
 					key: "requesty",
@@ -818,29 +825,30 @@ export const webviewMessageHandler = async (
 				},
 			]
 
-			// Add IO Intelligence if API key is provided.
-			const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey
-
-			if (ioIntelligenceApiKey) {
-				modelFetchPromises.push({
+			// IO Intelligence is only fetched when an API key is configured
+			if (apiConfiguration.ioIntelligenceApiKey) {
+				candidates.push({
 					key: "io-intelligence",
-					options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
+					options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
 				})
 			}
 
-			// Don't fetch Ollama and LM Studio models by default anymore.
-			// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.
-
+			// LiteLLM is only fetched when both a base URL and an API key are configured
 			const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
 			const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl
 
 			if (litellmApiKey && litellmBaseUrl) {
-				modelFetchPromises.push({
+				candidates.push({
 					key: "litellm",
 					options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
 				})
 			}
 
+			// Apply the single-provider filter if one was specified
+			const modelFetchPromises = providerFilter
+				? candidates.filter(({ key }) => key === providerFilter)
+				: candidates
+
 			const results = await Promise.allSettled(
 				modelFetchPromises.map(async ({ key, options }) => {
 					const models = await safeGetModels(options)
@@ -854,18 +862,7 @@ export const webviewMessageHandler = async (
 				if (result.status === "fulfilled") {
 					routerModels[routerName] = result.value.models
 
-					// Ollama and LM Studio settings pages still need these events.
-					if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
-						provider.postMessageToWebview({
-							type: "ollamaModels",
-							ollamaModels: result.value.models,
-						})
-					} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
-						provider.postMessageToWebview({
-							type: "lmStudioModels",
-							lmStudioModels: result.value.models,
-						})
-					}
+					// Ollama and LM Studio are no longer fetched here; their settings pages
+					// use the dedicated requestOllamaModels / requestLmStudioModels handlers.
 				} else {
 					// Handle rejection: Post a specific error message for this provider.
 					const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
@@ -882,7 +879,11 @@ export const webviewMessageHandler = async (
 				}
 			})
 
-			provider.postMessageToWebview({ type: "routerModels", routerModels })
+			provider.postMessageToWebview({
+				type: "routerModels",
+				routerModels,
+				values: providerFilter ? { provider: requestedProvider } : undefined,
+			})
 			break
 		case "requestOllamaModels": {
 			// Specific handler for Ollama models only.
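
The fan-out above relies on `Promise.allSettled` so a single failing provider degrades to an empty record (plus a per-provider error message) instead of failing the whole batch. A self-contained sketch of that pattern, with illustrative names in place of the extension's real fetchers:

```ts
// Illustrative stand-in for the extension's per-provider fetchers.
async function fetchModelsFor(key: string): Promise<Record<string, unknown>> {
	if (key === "broken") throw new Error(`fetch failed for ${key}`)
	return { [`${key}/example-model`]: { contextWindow: 8192 } }
}

async function fetchAll(keys: string[]): Promise<Record<string, Record<string, unknown>>> {
	const results = await Promise.allSettled(keys.map((key) => fetchModelsFor(key)))
	const out: Record<string, Record<string, unknown>> = {}
	results.forEach((result, i) => {
		// Fulfilled fetches land in the map; rejected ones degrade to an empty
		// record, mirroring how the handler posts a provider-specific error
		// instead of rejecting the whole batch.
		out[keys[i]] = result.status === "fulfilled" ? result.value : {}
	})
	return out
}

// fetchAll(["roo", "broken"]) resolves with roo's models and {} for "broken".
```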

+ 9 - 6
webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts

@@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
 	})
 
 	describe("loading and error states", () => {
-		it("should return loading state when router models are loading", () => {
+		it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: undefined,
 				isLoading: true,
@@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isLoading).toBe(true)
+			// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
+			expect(result.current.isLoading).toBe(false)
 		})
 
-		it("should return loading state when open router model providers are loading", () => {
+		it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
 				isLoading: false,
@@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isLoading).toBe(true)
+			// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
+			expect(result.current.isLoading).toBe(false)
 		})
 
-		it("should return error state when either hook has an error", () => {
+		it("should NOT set error when hooks error but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: undefined,
 				isLoading: false,
@@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isError).toBe(true)
+			// Error from gated routerModels should not bubble for static provider default
+			expect(result.current.isError).toBe(false)
 		})
 	})
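
These three inversions cover the static-provider path. A hedged sketch of the complementary case, assuming the suite's existing mocks and that `useSelectedModel` receives its provider via the `apiConfiguration` argument:

```ts
it("still reports loading for a dynamic provider (openrouter)", () => {
	mockUseRouterModels.mockReturnValue({
		data: undefined,
		isLoading: true,
		isError: false,
	} as any)

	const wrapper = createWrapper()
	// Passing apiConfiguration directly is illustrative; the suite may wire
	// provider state differently.
	const { result } = renderHook(() => useSelectedModel({ apiProvider: "openrouter" }), { wrapper })

	expect(result.current.isLoading).toBe(true)
})
```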
 

+ 27 - 3
webview-ui/src/components/ui/hooks/useRouterModels.ts

@@ -5,7 +5,12 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
 
 import { vscode } from "@src/utils/vscode"
 
-const getRouterModels = async () =>
+type UseRouterModelsOptions = {
+	provider?: string // single provider filter (e.g. "roo")
+	enabled?: boolean // gate fetching entirely
+}
+
+const getRouterModels = async (provider?: string) =>
 	new Promise<RouterModels>((resolve, reject) => {
 		const cleanup = () => {
 			window.removeEventListener("message", handler)
@@ -20,6 +25,14 @@ const getRouterModels = async () =>
 			const message: ExtensionMessage = event.data
 
 			if (message.type === "routerModels") {
+				const msgProvider = message?.values?.provider as string | undefined
+
+				// Verify response matches request
+				if (provider !== msgProvider) {
+					// Not our response; ignore and wait for the matching one
+					return
+				}
+
 				clearTimeout(timeout)
 				cleanup()
 
@@ -32,7 +45,18 @@ const getRouterModels = async () =>
 		}
 
 		window.addEventListener("message", handler)
-		vscode.postMessage({ type: "requestRouterModels" })
+		if (provider) {
+			vscode.postMessage({ type: "requestRouterModels", values: { provider } })
+		} else {
+			vscode.postMessage({ type: "requestRouterModels" })
+		}
 	})
 
-export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
+export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
+	const provider = opts.provider || undefined
+	return useQuery({
+		queryKey: ["routerModels", provider || "all"],
+		queryFn: () => getRouterModels(provider),
+		enabled: opts.enabled !== false,
+	})
+}
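
A usage sketch of the reworked hook; the import path here is an assumption based on the repo's `@src` alias:

```ts
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"

// Scoped fetch: only the "roo" router is requested, and the distinct query key
// ["routerModels", "roo"] keeps it cached separately from the aggregate fetch.
function useRooModels() {
	return useRouterModels({ provider: "roo" })
}

// Gated fetch: enabled: false means no postMessage round trip happens at all,
// which is how useSelectedModel (below) skips fetching for static providers.
function useModelsIfDynamic(isDynamic: boolean) {
	return useRouterModels({ enabled: isDynamic })
}
```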

+ 33 - 16
webview-ui/src/components/ui/hooks/useSelectedModel.ts

@@ -58,6 +58,7 @@ import {
 	vercelAiGatewayDefaultModelId,
 	BEDROCK_1M_CONTEXT_MODEL_IDS,
 	deepInfraDefaultModelId,
+	isDynamicProvider,
 } from "@roo-code/types"
 
 import type { ModelRecord, RouterModels } from "@roo/api"
@@ -73,24 +74,38 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
 	const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined
 
-	const routerModels = useRouterModels()
+	// Only fetch router models for dynamic providers
+	const shouldFetchRouterModels = isDynamicProvider(provider)
+	const routerModels = useRouterModels({
+		provider: shouldFetchRouterModels ? provider : undefined,
+		enabled: shouldFetchRouterModels,
+	})
+
 	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
 	const lmStudioModels = useLmStudioModels(lmStudioModelId)
 	const ollamaModels = useOllamaModels(ollamaModelId)
 
+	// Compute readiness only for the data actually needed for the selected provider
+	const needRouterModels = shouldFetchRouterModels
+	const needOpenRouterProviders = provider === "openrouter"
+	const needLmStudio = typeof lmStudioModelId !== "undefined"
+	const needOllama = typeof ollamaModelId !== "undefined"
+
+	const isReady =
+		(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
+		(!needOllama || typeof ollamaModels.data !== "undefined") &&
+		(!needRouterModels || typeof routerModels.data !== "undefined") &&
+		(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")
+
 	const { id, info } =
-		apiConfiguration &&
-		(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
-		(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
-		typeof routerModels.data !== "undefined" &&
-		typeof openRouterModelProviders.data !== "undefined"
+		apiConfiguration && isReady
 			? getSelectedModel({
 					provider,
 					apiConfiguration,
-					routerModels: routerModels.data,
-					openRouterModelProviders: openRouterModelProviders.data,
-					lmStudioModels: lmStudioModels.data,
-					ollamaModels: ollamaModels.data,
+					routerModels: (routerModels.data || {}) as RouterModels,
+					openRouterModelProviders: (openRouterModelProviders.data || {}) as Record<string, ModelInfo>,
+					lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
+					ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
 				})
 			: { id: anthropicDefaultModelId, info: undefined }
 
@@ -99,13 +114,15 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 		id,
 		info,
 		isLoading:
-			routerModels.isLoading ||
-			openRouterModelProviders.isLoading ||
-			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isLoading),
+			(needRouterModels && routerModels.isLoading) ||
+			(needOpenRouterProviders && openRouterModelProviders.isLoading) ||
+			(needLmStudio && lmStudioModels!.isLoading) ||
+			(needOllama && ollamaModels!.isLoading),
 		isError:
-			routerModels.isError ||
-			openRouterModelProviders.isError ||
-			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isError),
+			(needRouterModels && routerModels.isError) ||
+			(needOpenRouterProviders && openRouterModelProviders.isError) ||
+			(needLmStudio && lmStudioModels!.isError) ||
+			(needOllama && ollamaModels!.isError),
 	}
 }
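
The gating above reduces to a simple rule: each data source carries a "needed" flag, and only needed sources may block readiness or surface loading and error state. A self-contained sketch of that rule:

```ts
type Source = { needed: boolean; ready: boolean; loading: boolean; error: boolean }

function summarize(sources: Source[]) {
	return {
		// Ready once every *needed* source has data; un-needed sources are ignored.
		isReady: sources.every((s) => !s.needed || s.ready),
		// Only needed sources may surface loading/error, which is why a static
		// provider such as anthropic now reports neither.
		isLoading: sources.some((s) => s.needed && s.loading),
		isError: sources.some((s) => s.needed && s.error),
	}
}

// A static provider needs none of the dynamic sources, so nothing bubbles up:
console.log(summarize([{ needed: false, ready: false, loading: true, error: true }]))
// -> { isReady: true, isLoading: false, isError: false }
```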