Browse Source

Fixed the issue that Model ID cannot be saved

System233 10 months ago
parent
commit
f065f039be

+ 167 - 27
webview-ui/src/components/settings/ApiOptions.tsx

@@ -38,18 +38,14 @@ import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"
 
 import { vscode } from "../../utils/vscode"
 import VSCodeButtonLink from "../common/VSCodeButtonLink"
-import { OpenRouterModelPicker } from "./OpenRouterModelPicker"
-import OpenAiModelPicker from "./OpenAiModelPicker"
-import { GlamaModelPicker } from "./GlamaModelPicker"
-import { UnboundModelPicker } from "./UnboundModelPicker"
 import { ModelInfoView } from "./ModelInfoView"
 import { DROPDOWN_Z_INDEX } from "./styles"
-import { RequestyModelPicker } from "./RequestyModelPicker"
+import { ModelPicker } from "./ModelPicker"
 import { TemperatureControl } from "./TemperatureControl"
 
 interface ApiOptionsProps {
 	uriScheme: string | undefined
-	apiConfiguration: ApiConfiguration | undefined
+	apiConfiguration: ApiConfiguration
 	setApiConfigurationField: <K extends keyof ApiConfiguration>(field: K, value: ApiConfiguration[K]) => void
 	apiErrorMessage?: string
 	modelIdErrorMessage?: string
@@ -67,6 +63,20 @@ const ApiOptions = ({
 	const [ollamaModels, setOllamaModels] = useState<string[]>([])
 	const [lmStudioModels, setLmStudioModels] = useState<string[]>([])
 	const [vsCodeLmModels, setVsCodeLmModels] = useState<vscodemodels.LanguageModelChatSelector[]>([])
+	const [openRouterModels, setOpenRouterModels] = useState<Record<string, ModelInfo>>({
+		[openRouterDefaultModelId]: openRouterDefaultModelInfo,
+	})
+	const [glamaModels, setGlamaModels] = useState<Record<string, ModelInfo>>({
+		[glamaDefaultModelId]: glamaDefaultModelInfo,
+	})
+	const [unboundModels, setUnboundModels] = useState<Record<string, ModelInfo>>({
+		[unboundDefaultModelId]: unboundDefaultModelInfo,
+	})
+	const [requestyModels, setRequestyModels] = useState<Record<string, ModelInfo>>({
+		[requestyDefaultModelId]: requestyDefaultModelInfo,
+	})
+	const [openAiModels, setOpenAiModels] = useState<Record<string, ModelInfo> | null>(null)
+
 	const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl)
 	const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
 	const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl)
@@ -104,24 +114,93 @@ const ApiOptions = ({
 				vscode.postMessage({ type: "requestLmStudioModels", text: apiConfiguration?.lmStudioBaseUrl })
 			} else if (selectedProvider === "vscode-lm") {
 				vscode.postMessage({ type: "requestVsCodeLmModels" })
+			} else if (selectedProvider === "openai") {
+				vscode.postMessage({
+					type: "refreshOpenAiModels",
+					values: {
+						baseUrl: apiConfiguration?.openAiBaseUrl,
+						apiKey: apiConfiguration?.openAiApiKey,
+					},
+				})
+			} else if (selectedProvider === "openrouter") {
+				vscode.postMessage({ type: "refreshOpenRouterModels", values: {} })
+			} else if (selectedProvider === "glama") {
+				vscode.postMessage({ type: "refreshGlamaModels", values: {} })
+			} else if (selectedProvider === "requesty") {
+				vscode.postMessage({
+					type: "refreshRequestyModels",
+					values: {
+						apiKey: apiConfiguration?.requestyApiKey,
+					},
+				})
 			}
 		},
 		250,
-		[selectedProvider, apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl],
+		[
+			selectedProvider,
+			apiConfiguration?.ollamaBaseUrl,
+			apiConfiguration?.lmStudioBaseUrl,
+			apiConfiguration?.openAiBaseUrl,
+			apiConfiguration?.openAiApiKey,
+			apiConfiguration?.requestyApiKey,
+		],
 	)
 
 	const handleMessage = useCallback((event: MessageEvent) => {
 		const message: ExtensionMessage = event.data
-
-		if (message.type === "ollamaModels" && Array.isArray(message.ollamaModels)) {
-			const newModels = message.ollamaModels
-			setOllamaModels(newModels)
-		} else if (message.type === "lmStudioModels" && Array.isArray(message.lmStudioModels)) {
-			const newModels = message.lmStudioModels
-			setLmStudioModels(newModels)
-		} else if (message.type === "vsCodeLmModels" && Array.isArray(message.vsCodeLmModels)) {
-			const newModels = message.vsCodeLmModels
-			setVsCodeLmModels(newModels)
+		switch (message.type) {
+			case "ollamaModels":
+				{
+					const newModels = message.ollamaModels ?? []
+					setOllamaModels(newModels)
+				}
+				break
+			case "lmStudioModels":
+				{
+					const newModels = message.lmStudioModels ?? []
+					setLmStudioModels(newModels)
+				}
+				break
+			case "vsCodeLmModels":
+				{
+					const newModels = message.vsCodeLmModels ?? []
+					setVsCodeLmModels(newModels)
+				}
+				break
+			case "glamaModels": {
+				const updatedModels = message.glamaModels ?? {}
+				setGlamaModels({
+					[glamaDefaultModelId]: glamaDefaultModelInfo, // in case the extension sent a model list without the default model
+					...updatedModels,
+				})
+				break
+			}
+			case "openRouterModels": {
+				const updatedModels = message.openRouterModels ?? {}
+				setOpenRouterModels({
+					[openRouterDefaultModelId]: openRouterDefaultModelInfo, // in case the extension sent a model list without the default model
+					...updatedModels,
+				})
+				break
+			}
+			case "openAiModels": {
+				const updatedModels = message.openAiModels ?? []
+				setOpenAiModels(Object.fromEntries(updatedModels.map((item) => [item, openAiModelInfoSaneDefaults])))
+				break
+			}
+			case "unboundModels": {
+				const updatedModels = message.unboundModels ?? {}
+				setUnboundModels(updatedModels)
+				break
+			}
+			case "requestyModels": {
+				const updatedModels = message.requestyModels ?? {}
+				setRequestyModels({
+					[requestyDefaultModelId]: requestyDefaultModelInfo, // in case the extension sent a model list without the default model
+					...updatedModels,
+				})
+				break
+			}
 		}
 	}, [])
 
@@ -616,7 +695,17 @@ const ApiOptions = ({
 						placeholder="Enter API Key...">
 						<span style={{ fontWeight: 500 }}>API Key</span>
 					</VSCodeTextField>
-					<OpenAiModelPicker />
+					<ModelPicker
+						apiConfiguration={apiConfiguration}
+						modelIdKey="openAiModelId"
+						modelInfoKey="openAiCustomModelInfo"
+						serviceName="OpenAI"
+						serviceUrl="https://platform.openai.com"
+						recommendedModel="gpt-4-turbo-preview"
+						models={openAiModels}
+						setApiConfigurationField={setApiConfigurationField}
+						defaultModelInfo={openAiModelInfoSaneDefaults}
+					/>
 					<div style={{ display: "flex", alignItems: "center" }}>
 						<Checkbox
 							checked={apiConfiguration?.openAiStreamingEnabled ?? true}
@@ -704,7 +793,7 @@ const ApiOptions = ({
 												})(),
 											}}
 											title="Maximum number of tokens the model can generate in a single response"
-											onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+											onInput={handleInputChange("openAiCustomModelInfo", (e) => {
 												const value = parseInt((e.target as HTMLInputElement).value)
 												return {
 													...(apiConfiguration?.openAiCustomModelInfo ||
@@ -751,7 +840,7 @@ const ApiOptions = ({
 												})(),
 											}}
 											title="Total number of tokens (input + output) the model can process in a single request"
-											onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+											onInput={handleInputChange("openAiCustomModelInfo", (e) => {
 												const value = (e.target as HTMLInputElement).value
 												const parsed = parseInt(value)
 												return {
@@ -897,7 +986,7 @@ const ApiOptions = ({
 														: "var(--vscode-errorForeground)"
 												})(),
 											}}
-											onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+											onInput={handleInputChange("openAiCustomModelInfo", (e) => {
 												const value = (e.target as HTMLInputElement).value
 												const parsed = parseInt(value)
 												return {
@@ -942,7 +1031,7 @@ const ApiOptions = ({
 														: "var(--vscode-errorForeground)"
 												})(),
 											}}
-											onChange={handleInputChange("openAiCustomModelInfo", (e) => {
+											onInput={handleInputChange("openAiCustomModelInfo", (e) => {
 												const value = (e.target as HTMLInputElement).value
 												const parsed = parseInt(value)
 												return {
@@ -1011,6 +1100,7 @@ const ApiOptions = ({
 						placeholder={"e.g. meta-llama-3.1-8b-instruct"}>
 						<span style={{ fontWeight: 500 }}>Model ID</span>
 					</VSCodeTextField>
+
 					{lmStudioModels.length > 0 && (
 						<VSCodeRadioGroup
 							value={
@@ -1220,7 +1310,18 @@ const ApiOptions = ({
 						}}>
 						This key is stored locally and only used to make API requests from this extension.
 					</p>
-					<UnboundModelPicker />
+					<ModelPicker
+						apiConfiguration={apiConfiguration}
+						defaultModelId={unboundDefaultModelId}
+						defaultModelInfo={unboundDefaultModelInfo}
+						models={unboundModels}
+						modelInfoKey="unboundModelInfo"
+						modelIdKey="unboundModelId"
+						serviceName="Unbound"
+						serviceUrl="https://api.getunbound.ai/models"
+						recommendedModel={unboundDefaultModelId}
+						setApiConfigurationField={setApiConfigurationField}
+					/>
 				</div>
 			)}
 
@@ -1236,9 +1337,49 @@ const ApiOptions = ({
 				</p>
 			)}
 
-			{selectedProvider === "glama" && <GlamaModelPicker />}
-			{selectedProvider === "openrouter" && <OpenRouterModelPicker />}
-			{selectedProvider === "requesty" && <RequestyModelPicker />}
+			{selectedProvider === "glama" && (
+				<ModelPicker
+					apiConfiguration={apiConfiguration ?? {}}
+					defaultModelId={glamaDefaultModelId}
+					defaultModelInfo={glamaDefaultModelInfo}
+					models={glamaModels}
+					modelInfoKey="glamaModelInfo"
+					modelIdKey="glamaModelId"
+					serviceName="Glama"
+					serviceUrl="https://glama.ai/models"
+					recommendedModel="anthropic/claude-3-7-sonnet"
+					setApiConfigurationField={setApiConfigurationField}
+				/>
+			)}
+
+			{selectedProvider === "openrouter" && (
+				<ModelPicker
+					apiConfiguration={apiConfiguration}
+					setApiConfigurationField={setApiConfigurationField}
+					defaultModelId={openRouterDefaultModelId}
+					defaultModelInfo={openRouterDefaultModelInfo}
+					models={openRouterModels}
+					modelIdKey="openRouterModelId"
+					modelInfoKey="openRouterModelInfo"
+					serviceName="OpenRouter"
+					serviceUrl="https://openrouter.ai/models"
+					recommendedModel="anthropic/claude-3.7-sonnet"
+				/>
+			)}
+			{selectedProvider === "requesty" && (
+				<ModelPicker
+					apiConfiguration={apiConfiguration}
+					setApiConfigurationField={setApiConfigurationField}
+					defaultModelId={requestyDefaultModelId}
+					defaultModelInfo={requestyDefaultModelInfo}
+					models={requestyModels}
+					modelIdKey="requestyModelId"
+					modelInfoKey="requestyModelInfo"
+					serviceName="Requesty"
+					serviceUrl="https://requesty.ai"
+					recommendedModel="anthropic/claude-3-7-sonnet-latest"
+				/>
+			)}
 
 			{selectedProvider !== "glama" &&
 				selectedProvider !== "openrouter" &&
@@ -1260,7 +1401,6 @@ const ApiOptions = ({
 							{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
 							{selectedProvider === "mistral" && createDropdown(mistralModels)}
 						</div>
-
 						<ModelInfoView
 							selectedModelId={selectedModelId}
 							modelInfo={selectedModelInfo}

+ 0 - 15
webview-ui/src/components/settings/GlamaModelPicker.tsx

@@ -1,15 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { glamaDefaultModelId } from "../../../../src/shared/api"
-
-export const GlamaModelPicker = () => (
-	<ModelPicker
-		defaultModelId={glamaDefaultModelId}
-		modelsKey="glamaModels"
-		configKey="glamaModelId"
-		infoKey="glamaModelInfo"
-		refreshMessageType="refreshGlamaModels"
-		serviceName="Glama"
-		serviceUrl="https://glama.ai/models"
-		recommendedModel="anthropic/claude-3-7-sonnet"
-	/>
-)

+ 55 - 172
webview-ui/src/components/settings/ModelPicker.tsx

@@ -1,185 +1,90 @@
 import { VSCodeLink } from "@vscode/webview-ui-toolkit/react"
-import debounce from "debounce"
-import { useMemo, useState, useCallback, useEffect, useRef } from "react"
-import { useMount } from "react-use"
-import { CaretSortIcon, CheckIcon } from "@radix-ui/react-icons"
+import { useMemo, useState, useCallback, useEffect } from "react"
 
-import { cn } from "@/lib/utils"
-import {
-	Button,
-	Command,
-	CommandEmpty,
-	CommandGroup,
-	CommandInput,
-	CommandItem,
-	CommandList,
-	Popover,
-	PopoverContent,
-	PopoverTrigger,
-} from "@/components/ui"
-
-import { useExtensionState } from "../../context/ExtensionStateContext"
-import { vscode } from "../../utils/vscode"
 import { normalizeApiConfiguration } from "./ApiOptions"
 import { ModelInfoView } from "./ModelInfoView"
-
-type ModelProvider = "glama" | "openRouter" | "unbound" | "requesty" | "openAi"
-
-type ModelKeys<T extends ModelProvider> = `${T}Models`
-type ConfigKeys<T extends ModelProvider> = `${T}ModelId`
-type InfoKeys<T extends ModelProvider> = `${T}ModelInfo`
-type RefreshMessageType<T extends ModelProvider> = `refresh${Capitalize<T>}Models`
-
-interface ModelPickerProps<T extends ModelProvider = ModelProvider> {
-	defaultModelId: string
-	modelsKey: ModelKeys<T>
-	configKey: ConfigKeys<T>
-	infoKey: InfoKeys<T>
-	refreshMessageType: RefreshMessageType<T>
-	refreshValues?: Record<string, any>
+import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api"
+import { Combobox, ComboboxContent, ComboboxEmpty, ComboboxInput, ComboboxItem } from "../ui/combobox"
+
+type ExtractType<T> = NonNullable<
+	{ [K in keyof ApiConfiguration]: Required<ApiConfiguration>[K] extends T ? K : never }[keyof ApiConfiguration]
+>
+
+type ModelIdKeys = NonNullable<
+	{ [K in keyof ApiConfiguration]: K extends `${string}ModelId` ? K : never }[keyof ApiConfiguration]
+>
+declare module "react" {
+	interface CSSProperties {
+		// Allow CSS variables
+		[key: `--${string}`]: string | number
+	}
+}
+interface ModelPickerProps {
+	defaultModelId?: string
+	models: Record<string, ModelInfo> | null
+	modelIdKey: ModelIdKeys
+	modelInfoKey: ExtractType<ModelInfo>
 	serviceName: string
 	serviceUrl: string
 	recommendedModel: string
-	allowCustomModel?: boolean
+	apiConfiguration: ApiConfiguration
+	setApiConfigurationField: <K extends keyof ApiConfiguration>(field: K, value: ApiConfiguration[K]) => void
+	defaultModelInfo?: ModelInfo
 }
 
 export const ModelPicker = ({
 	defaultModelId,
-	modelsKey,
-	configKey,
-	infoKey,
-	refreshMessageType,
-	refreshValues,
+	models,
+	modelIdKey,
+	modelInfoKey,
 	serviceName,
 	serviceUrl,
 	recommendedModel,
-	allowCustomModel = false,
+	apiConfiguration,
+	setApiConfigurationField,
+	defaultModelInfo,
 }: ModelPickerProps) => {
-	const [customModelId, setCustomModelId] = useState("")
-	const [isCustomModel, setIsCustomModel] = useState(false)
-	const [open, setOpen] = useState(false)
-	const [value, setValue] = useState(defaultModelId)
 	const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
-	const prevRefreshValuesRef = useRef<Record<string, any> | undefined>()
-
-	const { apiConfiguration, [modelsKey]: models, onUpdateApiConfig, setApiConfiguration } = useExtensionState()
 
-	const modelIds = useMemo(
-		() => (Array.isArray(models) ? models : Object.keys(models)).sort((a, b) => a.localeCompare(b)),
-		[models],
-	)
+	const modelIds = useMemo(() => Object.keys(models ?? {}).sort((a, b) => a.localeCompare(b)), [models])
 
 	const { selectedModelId, selectedModelInfo } = useMemo(
 		() => normalizeApiConfiguration(apiConfiguration),
 		[apiConfiguration],
 	)
-
-	const onSelectCustomModel = useCallback(
-		(modelId: string) => {
-			setCustomModelId(modelId)
-			const modelInfo = { id: modelId }
-			const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo }
-			setApiConfiguration(apiConfig)
-			onUpdateApiConfig(apiConfig)
-			setValue(modelId)
-			setOpen(false)
-			setIsCustomModel(false)
-		},
-		[apiConfiguration, configKey, infoKey, onUpdateApiConfig, setApiConfiguration],
-	)
-
 	const onSelect = useCallback(
 		(modelId: string) => {
-			const modelInfo = Array.isArray(models)
-				? { id: modelId } // For OpenAI models which are just strings
-				: models[modelId] // For other models that have full info objects
-			const apiConfig = { ...apiConfiguration, [configKey]: modelId, [infoKey]: modelInfo }
-			setApiConfiguration(apiConfig)
-			onUpdateApiConfig(apiConfig)
-			setValue(modelId)
-			setOpen(false)
+			const modelInfo = models?.[modelId]
+			setApiConfigurationField(modelIdKey, modelId)
+			setApiConfigurationField(modelInfoKey, modelInfo ?? defaultModelInfo)
 		},
-		[apiConfiguration, configKey, infoKey, models, onUpdateApiConfig, setApiConfiguration],
+		[modelIdKey, modelInfoKey, models, setApiConfigurationField, defaultModelInfo],
 	)
-
-	const debouncedRefreshModels = useMemo(() => {
-		return debounce(() => {
-			const message = refreshValues
-				? { type: refreshMessageType, values: refreshValues }
-				: { type: refreshMessageType }
-			vscode.postMessage(message)
-		}, 100)
-	}, [refreshMessageType, refreshValues])
-
-	useMount(() => {
-		debouncedRefreshModels()
-		return () => debouncedRefreshModels.clear()
-	})
-
 	useEffect(() => {
-		if (!refreshValues) {
-			prevRefreshValuesRef.current = undefined
-			return
-		}
-
-		// Check if all values in refreshValues are truthy
-		if (Object.values(refreshValues).some((value) => !value)) {
-			prevRefreshValuesRef.current = undefined
-			return
-		}
-
-		// Compare with previous values
-		const prevValues = prevRefreshValuesRef.current
-		if (prevValues && JSON.stringify(prevValues) === JSON.stringify(refreshValues)) {
-			return
+		if (apiConfiguration[modelIdKey] == null && defaultModelId) {
+			onSelect(defaultModelId)
 		}
-
-		prevRefreshValuesRef.current = refreshValues
-		debouncedRefreshModels()
-	}, [debouncedRefreshModels, refreshValues])
-
-	useEffect(() => setValue(selectedModelId), [selectedModelId])
+	}, [apiConfiguration, defaultModelId, modelIdKey, onSelect])
 
 	return (
 		<>
 			<div className="font-semibold">Model</div>
-			<Popover open={open} onOpenChange={setOpen}>
-				<PopoverTrigger asChild>
-					<Button variant="combobox" role="combobox" aria-expanded={open} className="w-full justify-between">
-						{value ?? "Select model..."}
-						<CaretSortIcon className="opacity-50" />
-					</Button>
-				</PopoverTrigger>
-				<PopoverContent align="start" className="p-0">
-					<Command>
-						<CommandInput placeholder="Search model..." className="h-9" />
-						<CommandList>
-							<CommandEmpty>No model found.</CommandEmpty>
-							<CommandGroup>
-								{modelIds.map((model) => (
-									<CommandItem key={model} value={model} onSelect={onSelect}>
-										{model}
-										<CheckIcon
-											className={cn("ml-auto", value === model ? "opacity-100" : "opacity-0")}
-										/>
-									</CommandItem>
-								))}
-							</CommandGroup>
-							{allowCustomModel && (
-								<CommandGroup heading="Custom">
-									<CommandItem
-										onSelect={() => {
-											setIsCustomModel(true)
-											setOpen(false)
-										}}>
-										+ Add custom model
-									</CommandItem>
-								</CommandGroup>
-							)}
-						</CommandList>
-					</Command>
-				</PopoverContent>
-			</Popover>
+			<Combobox type="single" inputValue={apiConfiguration[modelIdKey]} onInputValueChange={onSelect}>
+				<ComboboxInput
+					className="border-vscode-errorForeground tefat"
+					placeholder="Search model..."
+					data-testid="model-input"
+				/>
+				<ComboboxContent>
+					<ComboboxEmpty>No model found.</ComboboxEmpty>
+					{modelIds.map((model) => (
+						<ComboboxItem key={model} value={model}>
+							{model}
+						</ComboboxItem>
+					))}
+				</ComboboxContent>
+			</Combobox>
+
 			{selectedModelId && selectedModelInfo && (
 				<ModelInfoView
 					selectedModelId={selectedModelId}
@@ -197,28 +102,6 @@ export const ModelPicker = ({
 				<VSCodeLink onClick={() => onSelect(recommendedModel)}>{recommendedModel}.</VSCodeLink>
 				You can also try searching "free" for no-cost options currently available.
 			</p>
-			{allowCustomModel && isCustomModel && (
-				<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
-					<div className="bg-[var(--vscode-editor-background)] p-6 rounded-lg w-96">
-						<h3 className="text-lg font-semibold mb-4">Add Custom Model</h3>
-						<input
-							type="text"
-							className="w-full p-2 mb-4 bg-[var(--vscode-input-background)] text-[var(--vscode-input-foreground)] border border-[var(--vscode-input-border)] rounded"
-							placeholder="Enter model ID"
-							value={customModelId}
-							onChange={(e) => setCustomModelId(e.target.value)}
-						/>
-						<div className="flex justify-end gap-2">
-							<Button variant="secondary" onClick={() => setIsCustomModel(false)}>
-								Cancel
-							</Button>
-							<Button onClick={() => onSelectCustomModel(customModelId)} disabled={!customModelId.trim()}>
-								Add
-							</Button>
-						</div>
-					</div>
-				</div>
-			)}
 		</>
 	)
 }

+ 0 - 27
webview-ui/src/components/settings/OpenAiModelPicker.tsx

@@ -1,27 +0,0 @@
-import React from "react"
-import { useExtensionState } from "../../context/ExtensionStateContext"
-import { ModelPicker } from "./ModelPicker"
-
-const OpenAiModelPicker: React.FC = () => {
-	const { apiConfiguration } = useExtensionState()
-
-	return (
-		<ModelPicker
-			defaultModelId={apiConfiguration?.openAiModelId || ""}
-			modelsKey="openAiModels"
-			configKey="openAiModelId"
-			infoKey="openAiModelInfo"
-			refreshMessageType="refreshOpenAiModels"
-			refreshValues={{
-				baseUrl: apiConfiguration?.openAiBaseUrl,
-				apiKey: apiConfiguration?.openAiApiKey,
-			}}
-			serviceName="OpenAI"
-			serviceUrl="https://platform.openai.com"
-			recommendedModel="gpt-4-turbo-preview"
-			allowCustomModel={true}
-		/>
-	)
-}
-
-export default OpenAiModelPicker

+ 0 - 15
webview-ui/src/components/settings/OpenRouterModelPicker.tsx

@@ -1,15 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { openRouterDefaultModelId } from "../../../../src/shared/api"
-
-export const OpenRouterModelPicker = () => (
-	<ModelPicker
-		defaultModelId={openRouterDefaultModelId}
-		modelsKey="openRouterModels"
-		configKey="openRouterModelId"
-		infoKey="openRouterModelInfo"
-		refreshMessageType="refreshOpenRouterModels"
-		serviceName="OpenRouter"
-		serviceUrl="https://openrouter.ai/models"
-		recommendedModel="anthropic/claude-3.7-sonnet"
-	/>
-)

+ 0 - 22
webview-ui/src/components/settings/RequestyModelPicker.tsx

@@ -1,22 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { requestyDefaultModelId } from "../../../../src/shared/api"
-import { useExtensionState } from "@/context/ExtensionStateContext"
-
-export const RequestyModelPicker = () => {
-	const { apiConfiguration } = useExtensionState()
-	return (
-		<ModelPicker
-			defaultModelId={requestyDefaultModelId}
-			modelsKey="requestyModels"
-			configKey="requestyModelId"
-			infoKey="requestyModelInfo"
-			refreshMessageType="refreshRequestyModels"
-			refreshValues={{
-				apiKey: apiConfiguration?.requestyApiKey,
-			}}
-			serviceName="Requesty"
-			serviceUrl="https://requesty.ai"
-			recommendedModel="anthropic/claude-3-7-sonnet-latest"
-		/>
-	)
-}

+ 4 - 2
webview-ui/src/components/settings/SettingsView.tsx

@@ -1,4 +1,4 @@
-import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react"
+import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react"
 import { VSCodeButton, VSCodeCheckbox, VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
 import { Dropdown, type DropdownOption } from "vscrui"
 
@@ -45,7 +45,6 @@ const SettingsView = forwardRef<SettingsViewRef, SettingsViewProps>(({ onDone },
 	// TODO: Reduce WebviewMessage/ExtensionState complexity
 	const { currentApiConfigName } = extensionState
 	const {
-		apiConfiguration,
 		alwaysAllowReadOnly,
 		allowedCommands,
 		alwaysAllowBrowser,
@@ -69,6 +68,9 @@ const SettingsView = forwardRef<SettingsViewRef, SettingsViewProps>(({ onDone },
 		terminalOutputLineLimit,
 		writeDelayMs,
 	} = cachedState
+	
+	//Make sure apiConfiguration is initialized and managed by SettingsView
+	const apiConfiguration = useMemo(() => cachedState.apiConfiguration ?? {}, [cachedState.apiConfiguration])
 
 	useEffect(() => {
 		// Update only when currentApiConfigName is changed

+ 0 - 15
webview-ui/src/components/settings/UnboundModelPicker.tsx

@@ -1,15 +0,0 @@
-import { ModelPicker } from "./ModelPicker"
-import { unboundDefaultModelId } from "../../../../src/shared/api"
-
-export const UnboundModelPicker = () => (
-	<ModelPicker
-		defaultModelId={unboundDefaultModelId}
-		modelsKey="unboundModels"
-		configKey="unboundModelId"
-		infoKey="unboundModelInfo"
-		refreshMessageType="refreshUnboundModels"
-		serviceName="Unbound"
-		serviceUrl="https://api.getunbound.ai/models"
-		recommendedModel={unboundDefaultModelId}
-	/>
-)

+ 27 - 32
webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx

@@ -3,7 +3,6 @@
 import { screen, fireEvent, render } from "@testing-library/react"
 import { act } from "react"
 import { ModelPicker } from "../ModelPicker"
-import { useExtensionState } from "../../../context/ExtensionStateContext"
 
 jest.mock("../../../context/ExtensionStateContext", () => ({
 	useExtensionState: jest.fn(),
@@ -20,36 +19,40 @@ global.ResizeObserver = MockResizeObserver
 Element.prototype.scrollIntoView = jest.fn()
 
 describe("ModelPicker", () => {
-	const mockOnUpdateApiConfig = jest.fn()
-	const mockSetApiConfiguration = jest.fn()
-
+	const mockSetApiConfigurationField = jest.fn()
+	const modelInfo = {
+		maxTokens: 8192,
+		contextWindow: 200_000,
+		supportsImages: true,
+		supportsComputerUse: true,
+		supportsPromptCache: true,
+		inputPrice: 3.0,
+		outputPrice: 15.0,
+		cacheWritesPrice: 3.75,
+		cacheReadsPrice: 0.3,
+	}
+	const mockModels = {
+		model1: { name: "Model 1", description: "Test model 1", ...modelInfo },
+		model2: { name: "Model 2", description: "Test model 2", ...modelInfo },
+	}
 	const defaultProps = {
+		apiConfiguration: {},
 		defaultModelId: "model1",
-		modelsKey: "glamaModels" as const,
-		configKey: "glamaModelId" as const,
-		infoKey: "glamaModelInfo" as const,
-		refreshMessageType: "refreshGlamaModels" as const,
+		defaultModelInfo: modelInfo,
+		modelIdKey: "glamaModelId" as const,
+		modelInfoKey: "glamaModelInfo" as const,
 		serviceName: "Test Service",
 		serviceUrl: "https://test.service",
 		recommendedModel: "recommended-model",
-	}
-
-	const mockModels = {
-		model1: { name: "Model 1", description: "Test model 1" },
-		model2: { name: "Model 2", description: "Test model 2" },
+		models: mockModels,
+		setApiConfigurationField: mockSetApiConfigurationField,
 	}
 
 	beforeEach(() => {
 		jest.clearAllMocks()
-		;(useExtensionState as jest.Mock).mockReturnValue({
-			apiConfiguration: {},
-			setApiConfiguration: mockSetApiConfiguration,
-			glamaModels: mockModels,
-			onUpdateApiConfig: mockOnUpdateApiConfig,
-		})
 	})
 
-	it("calls onUpdateApiConfig when a model is selected", async () => {
+	it("calls setApiConfigurationField when a model is selected", async () => {
 		await act(async () => {
 			render(<ModelPicker {...defaultProps} />)
 		})
@@ -67,20 +70,12 @@ describe("ModelPicker", () => {
 
 		await act(async () => {
 			// Find and click the model item by its value.
-			const modelItem = screen.getByRole("option", { name: "model2" })
-			fireEvent.click(modelItem)
+			const modelItem = screen.getByTestId("model-input")
+			fireEvent.input(modelItem, { target: { value: "model2" } })
 		})
 
 		// Verify the API config was updated.
-		expect(mockSetApiConfiguration).toHaveBeenCalledWith({
-			glamaModelId: "model2",
-			glamaModelInfo: mockModels["model2"],
-		})
-
-		// Verify onUpdateApiConfig was called with the new config.
-		expect(mockOnUpdateApiConfig).toHaveBeenCalledWith({
-			glamaModelId: "model2",
-			glamaModelInfo: mockModels["model2"],
-		})
+		expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, "model2")
+		expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelInfoKey, mockModels.model2)
 	})
 })

+ 4 - 9
webview-ui/src/utils/validate.ts

@@ -1,9 +1,4 @@
-import {
-	ApiConfiguration,
-	glamaDefaultModelId,
-	openRouterDefaultModelId,
-	unboundDefaultModelId,
-} from "../../../src/shared/api"
+import { ApiConfiguration } from "../../../src/shared/api"
 import { ModelInfo } from "../../../src/shared/api"
 export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined {
 	if (apiConfiguration) {
@@ -86,7 +81,7 @@ export function validateModelId(
 	if (apiConfiguration) {
 		switch (apiConfiguration.apiProvider) {
 			case "glama":
-				const glamaModelId = apiConfiguration.glamaModelId || glamaDefaultModelId // in case the user hasn't changed the model id, it will be undefined by default
+				const glamaModelId = apiConfiguration.glamaModelId
 				if (!glamaModelId) {
 					return "You must provide a model ID."
 				}
@@ -96,7 +91,7 @@ export function validateModelId(
 				}
 				break
 			case "openrouter":
-				const modelId = apiConfiguration.openRouterModelId || openRouterDefaultModelId // in case the user hasn't changed the model id, it will be undefined by default
+				const modelId = apiConfiguration.openRouterModelId
 				if (!modelId) {
 					return "You must provide a model ID."
 				}
@@ -106,7 +101,7 @@ export function validateModelId(
 				}
 				break
 			case "unbound":
-				const unboundModelId = apiConfiguration.unboundModelId || unboundDefaultModelId
+				const unboundModelId = apiConfiguration.unboundModelId
 				if (!unboundModelId) {
 					return "You must provide a model ID."
 				}