Просмотр исходного кода

Make ollama models info transport work like lmstudio (#7679)

ItsOnlyBinary 3 месяцев назад
Родитель
Commit
76c6745649

+ 42 - 0
src/core/webview/__tests__/webviewMessageHandler.spec.ts

@@ -136,6 +136,48 @@ describe("webviewMessageHandler - requestLmStudioModels", () => {
 	})
 })
 
+describe("webviewMessageHandler - requestOllamaModels", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockClineProvider.getState = vi.fn().mockResolvedValue({
+			apiConfiguration: {
+				ollamaModelId: "model-1",
+				ollamaBaseUrl: "http://localhost:1234",
+			},
+		})
+	})
+
+	it("successfully fetches models from Ollama", async () => {
+		const mockModels: ModelRecord = {
+			"model-1": {
+				maxTokens: 4096,
+				contextWindow: 8192,
+				supportsPromptCache: false,
+				description: "Test model 1",
+			},
+			"model-2": {
+				maxTokens: 8192,
+				contextWindow: 16384,
+				supportsPromptCache: false,
+				description: "Test model 2",
+			},
+		}
+
+		mockGetModels.mockResolvedValue(mockModels)
+
+		await webviewMessageHandler(mockClineProvider, {
+			type: "requestOllamaModels",
+		})
+
+		expect(mockGetModels).toHaveBeenCalledWith({ provider: "ollama", baseUrl: "http://localhost:1234" })
+
+		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+			type: "ollamaModels",
+			ollamaModels: mockModels,
+		})
+	})
+})
+
 describe("webviewMessageHandler - requestRouterModels", () => {
 	beforeEach(() => {
 		vi.clearAllMocks()

+ 2 - 2
src/core/webview/webviewMessageHandler.ts

@@ -797,7 +797,7 @@ export const webviewMessageHandler = async (
 					if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
 						provider.postMessageToWebview({
 							type: "ollamaModels",
-							ollamaModels: Object.keys(result.value.models),
+							ollamaModels: result.value.models,
 						})
 					} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
 						provider.postMessageToWebview({
@@ -842,7 +842,7 @@ export const webviewMessageHandler = async (
 				if (Object.keys(ollamaModels).length > 0) {
 					provider.postMessageToWebview({
 						type: "ollamaModels",
-						ollamaModels: Object.keys(ollamaModels),
+						ollamaModels: ollamaModels,
 					})
 				}
 			} catch (error) {

+ 1 - 1
src/shared/ExtensionMessage.ts

@@ -148,7 +148,7 @@ export interface ExtensionMessage {
 	clineMessage?: ClineMessage
 	routerModels?: RouterModels
 	openAiModels?: string[]
-	ollamaModels?: string[]
+	ollamaModels?: ModelRecord
 	lmStudioModels?: ModelRecord
 	vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
 	huggingFaceModels?: Array<{

+ 7 - 8
webview-ui/src/components/settings/providers/Ollama.tsx

@@ -11,6 +11,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
 import { vscode } from "@src/utils/vscode"
 
 import { inputEventTransform } from "../transforms"
+import { ModelRecord } from "@roo/api"
 
 type OllamaProps = {
 	apiConfiguration: ProviderSettings
@@ -20,7 +21,7 @@ type OllamaProps = {
 export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaProps) => {
 	const { t } = useAppTranslation()
 
-	const [ollamaModels, setOllamaModels] = useState<string[]>([])
+	const [ollamaModels, setOllamaModels] = useState<ModelRecord>({})
 	const routerModels = useRouterModels()
 
 	const handleInputChange = useCallback(
@@ -40,7 +41,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
 		switch (message.type) {
 			case "ollamaModels":
 				{
-					const newModels = message.ollamaModels ?? []
+					const newModels = message.ollamaModels ?? {}
 					setOllamaModels(newModels)
 				}
 				break
@@ -61,7 +62,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
 		if (!selectedModel) return false
 
 		// Check if model exists in local ollama models
-		if (ollamaModels.length > 0 && ollamaModels.includes(selectedModel)) {
+		if (Object.keys(ollamaModels).length > 0 && selectedModel in ollamaModels) {
 			return false // Model is available locally
 		}
 
@@ -116,15 +117,13 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
 					</div>
 				</div>
 			)}
-			{ollamaModels.length > 0 && (
+			{Object.keys(ollamaModels).length > 0 && (
 				<VSCodeRadioGroup
 					value={
-						ollamaModels.includes(apiConfiguration?.ollamaModelId || "")
-							? apiConfiguration?.ollamaModelId
-							: ""
+						(apiConfiguration?.ollamaModelId || "") in ollamaModels ? apiConfiguration?.ollamaModelId : ""
 					}
 					onChange={handleInputChange("ollamaModelId")}>
-					{ollamaModels.map((model) => (
+					{Object.keys(ollamaModels).map((model) => (
 						<VSCodeRadio key={model} value={model} checked={apiConfiguration?.ollamaModelId === model}>
 							{model}
 						</VSCodeRadio>

+ 39 - 0
webview-ui/src/components/ui/hooks/useOllamaModels.ts

@@ -0,0 +1,39 @@
+import { useQuery } from "@tanstack/react-query"
+
+import { ModelRecord } from "@roo/api"
+import { ExtensionMessage } from "@roo/ExtensionMessage"
+
+import { vscode } from "@src/utils/vscode"
+
+const getOllamaModels = async () =>
+	new Promise<ModelRecord>((resolve, reject) => {
+		const cleanup = () => {
+			window.removeEventListener("message", handler)
+		}
+
+		const timeout = setTimeout(() => {
+			cleanup()
+			reject(new Error("Ollama models request timed out"))
+		}, 10000)
+
+		const handler = (event: MessageEvent) => {
+			const message: ExtensionMessage = event.data
+
+			if (message.type === "ollamaModels") {
+				clearTimeout(timeout)
+				cleanup()
+
+				if (message.ollamaModels) {
+					resolve(message.ollamaModels)
+				} else {
+					reject(new Error("No Ollama models in response"))
+				}
+			}
+		}
+
+		window.addEventListener("message", handler)
+		vscode.postMessage({ type: "requestOllamaModels" })
+	})
+
+export const useOllamaModels = (modelId?: string) =>
+	useQuery({ queryKey: ["ollamaModels"], queryFn: () => (modelId ? getOllamaModels() : {}) })

+ 8 - 1
webview-ui/src/components/ui/hooks/useSelectedModel.ts

@@ -64,19 +64,23 @@ import type { ModelRecord, RouterModels } from "@roo/api"
 import { useRouterModels } from "./useRouterModels"
 import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
 import { useLmStudioModels } from "./useLmStudioModels"
+import { useOllamaModels } from "./useOllamaModels"
 
 export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const provider = apiConfiguration?.apiProvider || "anthropic"
 	const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
 	const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
+	const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined
 
 	const routerModels = useRouterModels()
 	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
 	const lmStudioModels = useLmStudioModels(lmStudioModelId)
+	const ollamaModels = useOllamaModels(ollamaModelId)
 
 	const { id, info } =
 		apiConfiguration &&
 		(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
+		(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
 		typeof routerModels.data !== "undefined" &&
 		typeof openRouterModelProviders.data !== "undefined"
 			? getSelectedModel({
@@ -85,6 +89,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 					routerModels: routerModels.data,
 					openRouterModelProviders: openRouterModelProviders.data,
 					lmStudioModels: lmStudioModels.data,
+					ollamaModels: ollamaModels.data,
 				})
 			: { id: anthropicDefaultModelId, info: undefined }
 
@@ -109,12 +114,14 @@ function getSelectedModel({
 	routerModels,
 	openRouterModelProviders,
 	lmStudioModels,
+	ollamaModels,
 }: {
 	provider: ProviderName
 	apiConfiguration: ProviderSettings
 	routerModels: RouterModels
 	openRouterModelProviders: Record<string, ModelInfo>
 	lmStudioModels: ModelRecord | undefined
+	ollamaModels: ModelRecord | undefined
 }): { id: string; info: ModelInfo | undefined } {
 	// the `undefined` case are used to show the invalid selection to prevent
 	// users from seeing the default model if their selection is invalid
@@ -255,7 +262,7 @@ function getSelectedModel({
 		}
 		case "ollama": {
 			const id = apiConfiguration.ollamaModelId ?? ""
-			const info = routerModels.ollama && routerModels.ollama[id]
+			const info = ollamaModels && ollamaModels[apiConfiguration.ollamaModelId!]
 			return {
 				id,
 				info: info || undefined,