Browse Source

Merge pull request #848 from websentry-ai/pm/unbound-fetch-models

Fetches Unbound models from api
Matt Rubens 1 year ago
parent
commit
fffc9f7311

+ 21 - 2
src/api/providers/__tests__/unbound.test.ts

@@ -73,6 +73,15 @@ describe("UnboundHandler", () => {
 		mockOptions = {
 			apiModelId: "anthropic/claude-3-5-sonnet-20241022",
 			unboundApiKey: "test-api-key",
+			unboundModelId: "anthropic/claude-3-5-sonnet-20241022",
+			unboundModelInfo: {
+				description: "Anthropic's Claude 3 Sonnet model",
+				maxTokens: 8192,
+				contextWindow: 200000,
+				supportsPromptCache: true,
+				inputPrice: 0.01,
+				outputPrice: 0.02,
+			},
 		}
 		handler = new UnboundHandler(mockOptions)
 		mockCreate.mockClear()
@@ -205,6 +214,15 @@ describe("UnboundHandler", () => {
 			const nonAnthropicOptions = {
 				apiModelId: "openai/gpt-4o",
 				unboundApiKey: "test-key",
+				unboundModelId: "openai/gpt-4o",
+				unboundModelInfo: {
+					description: "OpenAI's GPT-4",
+					maxTokens: undefined,
+					contextWindow: 128000,
+					supportsPromptCache: true,
+					inputPrice: 0.01,
+					outputPrice: 0.03,
+				},
 			}
 			const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions)
 
@@ -230,10 +248,11 @@ describe("UnboundHandler", () => {
 		it("should return default model when invalid model provided", () => {
 			const handlerWithInvalidModel = new UnboundHandler({
 				...mockOptions,
-				apiModelId: "invalid/model",
+				unboundModelId: "invalid/model",
+				unboundModelInfo: undefined,
 			})
 			const modelInfo = handlerWithInvalidModel.getModel()
-			expect(modelInfo.id).toBe("openai/gpt-4o") // Default model
+			expect(modelInfo.id).toBe("anthropic/claude-3-5-sonnet-20241022") // Default model
 			expect(modelInfo.info).toBeDefined()
 		})
 	})

+ 7 - 7
src/api/providers/unbound.ts

@@ -1,7 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 import { ApiHandler, SingleCompletionHandler } from "../"
-import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
+import { ApiHandlerOptions, ModelInfo, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 
@@ -129,15 +129,15 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): { id: UnboundModelId; info: ModelInfo } {
-		const modelId = this.options.apiModelId
-		if (modelId && modelId in unboundModels) {
-			const id = modelId as UnboundModelId
-			return { id, info: unboundModels[id] }
+	getModel(): { id: string; info: ModelInfo } {
+		const modelId = this.options.unboundModelId
+		const modelInfo = this.options.unboundModelInfo
+		if (modelId && modelInfo) {
+			return { id: modelId, info: modelInfo }
 		}
 		return {
 			id: unboundDefaultModelId,
-			info: unboundModels[unboundDefaultModelId],
+			info: unboundDefaultModelInfo,
 		}
 	}
 

+ 77 - 14
src/core/webview/ClineProvider.ts

@@ -122,6 +122,7 @@ type GlobalStateKey =
 	| "autoApprovalEnabled"
 	| "customModes" // Array of custom modes
 	| "unboundModelId"
+	| "unboundModelInfo"
 
 export const GlobalFileNames = {
 	apiConversationHistory: "api_conversation_history.json",
@@ -129,6 +130,7 @@ export const GlobalFileNames = {
 	glamaModels: "glama_models.json",
 	openRouterModels: "openrouter_models.json",
 	mcpSettings: "cline_mcp_settings.json",
+	unboundModels: "unbound_models.json",
 }
 
 export class ClineProvider implements vscode.WebviewViewProvider {
@@ -665,6 +667,24 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							}
 						})
 
+						this.readUnboundModels().then((unboundModels) => {
+							if (unboundModels) {
+								this.postMessageToWebview({ type: "unboundModels", unboundModels })
+							}
+						})
+						this.refreshUnboundModels().then(async (unboundModels) => {
+							if (unboundModels) {
+								const { apiConfiguration } = await this.getState()
+								if (apiConfiguration?.unboundModelId) {
+									await this.updateGlobalState(
+										"unboundModelInfo",
+										unboundModels[apiConfiguration.unboundModelId],
+									)
+									await this.postStateToWebview()
+								}
+							}
+						})
+
 						this.configManager
 							.listConfig()
 							.then(async (listApiConfig) => {
@@ -824,6 +844,9 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 							this.postMessageToWebview({ type: "openAiModels", openAiModels })
 						}
 						break
+					case "refreshUnboundModels":
+						await this.refreshUnboundModels()
+						break
 					case "openImage":
 						openImage(message.text!)
 						break
@@ -1563,6 +1586,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			mistralApiKey,
 			unboundApiKey,
 			unboundModelId,
+			unboundModelInfo,
 		} = apiConfiguration
 		await this.updateGlobalState("apiProvider", apiProvider)
 		await this.updateGlobalState("apiModelId", apiModelId)
@@ -1603,6 +1627,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		await this.storeSecret("mistralApiKey", mistralApiKey)
 		await this.storeSecret("unboundApiKey", unboundApiKey)
 		await this.updateGlobalState("unboundModelId", unboundModelId)
+		await this.updateGlobalState("unboundModelInfo", unboundModelInfo)
 		if (this.cline) {
 			this.cline.api = buildApiHandler(apiConfiguration)
 		}
@@ -1808,16 +1833,20 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		// await this.postMessageToWebview({ type: "action", action: "settingsButtonClicked" }) // bad ux if user is on welcome
 	}
 
-	async readGlamaModels(): Promise<Record<string, ModelInfo> | undefined> {
-		const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
-		const fileExists = await fileExistsAtPath(glamaModelsFilePath)
+	private async readModelsFromCache(filename: string): Promise<Record<string, ModelInfo> | undefined> {
+		const filePath = path.join(await this.ensureCacheDirectoryExists(), filename)
+		const fileExists = await fileExistsAtPath(filePath)
 		if (fileExists) {
-			const fileContents = await fs.readFile(glamaModelsFilePath, "utf8")
+			const fileContents = await fs.readFile(filePath, "utf8")
 			return JSON.parse(fileContents)
 		}
 		return undefined
 	}
 
+	async readGlamaModels(): Promise<Record<string, ModelInfo> | undefined> {
+		return this.readModelsFromCache(GlobalFileNames.glamaModels)
+	}
+
 	async refreshGlamaModels() {
 		const glamaModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.glamaModels)
 
@@ -1893,16 +1922,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 	}
 
 	async readOpenRouterModels(): Promise<Record<string, ModelInfo> | undefined> {
-		const openRouterModelsFilePath = path.join(
-			await this.ensureCacheDirectoryExists(),
-			GlobalFileNames.openRouterModels,
-		)
-		const fileExists = await fileExistsAtPath(openRouterModelsFilePath)
-		if (fileExists) {
-			const fileContents = await fs.readFile(openRouterModelsFilePath, "utf8")
-			return JSON.parse(fileContents)
-		}
-		return undefined
+		return this.readModelsFromCache(GlobalFileNames.openRouterModels)
 	}
 
 	async refreshOpenRouterModels() {
@@ -2017,6 +2037,46 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		return models
 	}
 
+	async readUnboundModels(): Promise<Record<string, ModelInfo> | undefined> {
+		return this.readModelsFromCache(GlobalFileNames.unboundModels)
+	}
+
+	async refreshUnboundModels() {
+		const unboundModelsFilePath = path.join(await this.ensureCacheDirectoryExists(), GlobalFileNames.unboundModels)
+
+		const models: Record<string, ModelInfo> = {}
+		try {
+			const response = await axios.get("https://api.getunbound.ai/models")
+
+			if (response.data) {
+				const rawModels: Record<string, any> = response.data
+
+				for (const [modelId, model] of Object.entries(rawModels)) {
+					models[modelId] = {
+						maxTokens: model.maxTokens ? parseInt(model.maxTokens) : undefined,
+						contextWindow: model.contextWindow ? parseInt(model.contextWindow) : 0,
+						supportsImages: model.supportsImages ?? false,
+						supportsPromptCache: model.supportsPromptCaching ?? false,
+						supportsComputerUse: model.supportsComputerUse ?? false,
+						inputPrice: model.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined,
+						outputPrice: model.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined,
+						cacheWritesPrice: model.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
+						cacheReadsPrice: model.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined,
+					}
+				}
+			}
+			await fs.writeFile(unboundModelsFilePath, JSON.stringify(models))
+			this.outputChannel.appendLine(`Unbound models fetched and saved: ${JSON.stringify(models, null, 2)}`)
+		} catch (error) {
+			this.outputChannel.appendLine(
+				`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`,
+			)
+		}
+
+		await this.postMessageToWebview({ type: "unboundModels", unboundModels: models })
+		return models
+	}
+
 	// Task history
 
 	async getTaskWithId(id: string): Promise<{
@@ -2330,6 +2390,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			experiments,
 			unboundApiKey,
 			unboundModelId,
+			unboundModelInfo,
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,
 			this.getGlobalState("apiModelId") as Promise<string | undefined>,
@@ -2405,6 +2466,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("experiments") as Promise<Record<ExperimentId, boolean> | undefined>,
 			this.getSecret("unboundApiKey") as Promise<string | undefined>,
 			this.getGlobalState("unboundModelId") as Promise<string | undefined>,
+			this.getGlobalState("unboundModelInfo") as Promise<ModelInfo | undefined>,
 		])
 
 		let apiProvider: ApiProvider
@@ -2462,6 +2524,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				vsCodeLmModelSelector,
 				unboundApiKey,
 				unboundModelId,
+				unboundModelInfo,
 			},
 			lastShownAnnouncementId,
 			customInstructions,

+ 3 - 0
src/shared/ExtensionMessage.ts

@@ -42,6 +42,8 @@ export interface ExtensionMessage {
 		| "autoApprovalEnabled"
 		| "updateCustomMode"
 		| "deleteCustomMode"
+		| "unboundModels"
+		| "refreshUnboundModels"
 		| "currentCheckpointUpdated"
 	text?: string
 	action?:
@@ -67,6 +69,7 @@ export interface ExtensionMessage {
 	glamaModels?: Record<string, ModelInfo>
 	openRouterModels?: Record<string, ModelInfo>
 	openAiModels?: string[]
+	unboundModels?: Record<string, ModelInfo>
 	mcpServers?: McpServer[]
 	commits?: GitCommit[]
 	listApiConfig?: ApiConfigMeta[]

+ 1 - 0
src/shared/WebviewMessage.ts

@@ -42,6 +42,7 @@ export interface WebviewMessage {
 		| "refreshGlamaModels"
 		| "refreshOpenRouterModels"
 		| "refreshOpenAiModels"
+		| "refreshUnboundModels"
 		| "alwaysAllowBrowser"
 		| "alwaysAllowMcp"
 		| "alwaysAllowModeSwitch"

+ 12 - 9
src/shared/api.ts

@@ -60,6 +60,7 @@ export interface ApiHandlerOptions {
 	includeMaxTokens?: boolean
 	unboundApiKey?: string
 	unboundModelId?: string
+	unboundModelInfo?: ModelInfo
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -650,12 +651,14 @@ export const mistralModels = {
 } as const satisfies Record<string, ModelInfo>
 
 // Unbound Security
-export type UnboundModelId = keyof typeof unboundModels
-export const unboundDefaultModelId = "openai/gpt-4o"
-export const unboundModels = {
-	"anthropic/claude-3-5-sonnet-20241022": anthropicModels["claude-3-5-sonnet-20241022"],
-	"openai/gpt-4o": openAiNativeModels["gpt-4o"],
-	"deepseek/deepseek-chat": deepSeekModels["deepseek-chat"],
-	"deepseek/deepseek-reasoner": deepSeekModels["deepseek-reasoner"],
-	"mistral/codestral-latest": mistralModels["codestral-latest"],
-} as const satisfies Record<string, ModelInfo>
+export const unboundDefaultModelId = "anthropic/claude-3-5-sonnet-20241022"
+export const unboundDefaultModelInfo: ModelInfo = {
+	maxTokens: 8192,
+	contextWindow: 200_000,
+	supportsImages: true,
+	supportsPromptCache: true,
+	inputPrice: 3.0,
+	outputPrice: 15.0,
+	cacheWritesPrice: 3.75,
+	cacheReadsPrice: 0.3,
+}

+ 11 - 4
webview-ui/src/components/settings/ApiOptions.tsx

@@ -28,7 +28,7 @@ import {
 	vertexDefaultModelId,
 	vertexModels,
 	unboundDefaultModelId,
-	unboundModels,
+	unboundDefaultModelInfo,
 } from "../../../../src/shared/api"
 import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"
 import { useExtensionState } from "../../context/ExtensionStateContext"
@@ -37,9 +37,11 @@ import VSCodeButtonLink from "../common/VSCodeButtonLink"
 import { OpenRouterModelPicker } from "./OpenRouterModelPicker"
 import OpenAiModelPicker from "./OpenAiModelPicker"
 import { GlamaModelPicker } from "./GlamaModelPicker"
+import { UnboundModelPicker } from "./UnboundModelPicker"
 import { ModelInfoView } from "./ModelInfoView"
 import { DROPDOWN_Z_INDEX } from "./styles"
 
+
 interface ApiOptionsProps {
 	apiErrorMessage?: string
 	modelIdErrorMessage?: string
@@ -1314,6 +1316,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 						}}>
 						This key is stored locally and only used to make API requests from this extension.
 					</p>
+					<UnboundModelPicker />
 				</div>
 			)}
 
@@ -1336,7 +1339,8 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 				selectedProvider !== "openrouter" &&
 				selectedProvider !== "openai" &&
 				selectedProvider !== "ollama" &&
-				selectedProvider !== "lmstudio" && (
+				selectedProvider !== "lmstudio" &&
+				selectedProvider !== "unbound" && (
 					<>
 						<div className="dropdown-container">
 							<label htmlFor="model-id">
@@ -1349,7 +1353,6 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage }: ApiOptionsProps) =
 							{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
 							{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
 							{selectedProvider === "mistral" && createDropdown(mistralModels)}
-							{selectedProvider === "unbound" && createDropdown(unboundModels)}
 						</div>
 
 						<ModelInfoView
@@ -1458,7 +1461,11 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 				},
 			}
 		case "unbound":
-			return getProviderData(unboundModels, unboundDefaultModelId)
+			return {
+				selectedProvider: provider,
+				selectedModelId: apiConfiguration?.unboundModelId || unboundDefaultModelId,
+				selectedModelInfo: apiConfiguration?.unboundModelInfo || unboundDefaultModelInfo,
+			}
 		default:
 			return getProviderData(anthropicModels, anthropicDefaultModelId)
 	}

+ 4 - 4
webview-ui/src/components/settings/ModelPicker.tsx

@@ -25,10 +25,10 @@ import { ModelInfoView } from "./ModelInfoView"
 
 interface ModelPickerProps {
 	defaultModelId: string
-	modelsKey: "glamaModels" | "openRouterModels"
-	configKey: "glamaModelId" | "openRouterModelId"
-	infoKey: "glamaModelInfo" | "openRouterModelInfo"
-	refreshMessageType: "refreshGlamaModels" | "refreshOpenRouterModels"
+	modelsKey: "glamaModels" | "openRouterModels" | "unboundModels"
+	configKey: "glamaModelId" | "openRouterModelId" | "unboundModelId"
+	infoKey: "glamaModelInfo" | "openRouterModelInfo" | "unboundModelInfo"
+	refreshMessageType: "refreshGlamaModels" | "refreshOpenRouterModels" | "refreshUnboundModels"
 	serviceName: string
 	serviceUrl: string
 	recommendedModel: string

+ 15 - 0
webview-ui/src/components/settings/UnboundModelPicker.tsx

@@ -0,0 +1,15 @@
+import { ModelPicker } from "./ModelPicker"
+import { unboundDefaultModelId } from "../../../../src/shared/api"
+
+export const UnboundModelPicker = () => (
+	<ModelPicker
+		defaultModelId={unboundDefaultModelId}
+		modelsKey="unboundModels"
+		configKey="unboundModelId"
+		infoKey="unboundModelInfo"
+		refreshMessageType="refreshUnboundModels"
+		serviceName="Unbound"
+		serviceUrl="https://api.getunbound.ai/models"
+		recommendedModel={unboundDefaultModelId}
+	/>
+)

+ 12 - 0
webview-ui/src/context/ExtensionStateContext.tsx

@@ -8,6 +8,8 @@ import {
 	glamaDefaultModelInfo,
 	openRouterDefaultModelId,
 	openRouterDefaultModelInfo,
+	unboundDefaultModelId,
+	unboundDefaultModelInfo,
 } from "../../../src/shared/api"
 import { vscode } from "../utils/vscode"
 import { convertTextMateToHljs } from "../utils/textMateToHljs"
@@ -24,6 +26,7 @@ export interface ExtensionStateContextType extends ExtensionState {
 	theme: any
 	glamaModels: Record<string, ModelInfo>
 	openRouterModels: Record<string, ModelInfo>
+	unboundModels: Record<string, ModelInfo>
 	openAiModels: string[]
 	mcpServers: McpServer[]
 	currentCheckpoint?: string
@@ -124,6 +127,9 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 	const [openRouterModels, setOpenRouterModels] = useState<Record<string, ModelInfo>>({
 		[openRouterDefaultModelId]: openRouterDefaultModelInfo,
 	})
+	const [unboundModels, setUnboundModels] = useState<Record<string, ModelInfo>>({
+		[unboundDefaultModelId]: unboundDefaultModelInfo,
+	})
 
 	const [openAiModels, setOpenAiModels] = useState<string[]>([])
 	const [mcpServers, setMcpServers] = useState<McpServer[]>([])
@@ -239,6 +245,11 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 					setOpenAiModels(updatedModels)
 					break
 				}
+				case "unboundModels": {
+					const updatedModels = message.unboundModels ?? {}
+					setUnboundModels(updatedModels)
+					break
+				}
 				case "mcpServers": {
 					setMcpServers(message.mcpServers ?? [])
 					break
@@ -270,6 +281,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 		glamaModels,
 		openRouterModels,
 		openAiModels,
+		unboundModels,
 		mcpServers,
 		currentCheckpoint,
 		filePaths,

+ 12 - 1
webview-ui/src/utils/validate.ts

@@ -1,4 +1,4 @@
-import { ApiConfiguration, glamaDefaultModelId, openRouterDefaultModelId } from "../../../src/shared/api"
+import { ApiConfiguration, glamaDefaultModelId, openRouterDefaultModelId, unboundDefaultModelId } from "../../../src/shared/api"
 import { ModelInfo } from "../../../src/shared/api"
 export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined {
 	if (apiConfiguration) {
@@ -76,6 +76,7 @@ export function validateModelId(
 	apiConfiguration?: ApiConfiguration,
 	glamaModels?: Record<string, ModelInfo>,
 	openRouterModels?: Record<string, ModelInfo>,
+	unboundModels?: Record<string, ModelInfo>,
 ): string | undefined {
 	if (apiConfiguration) {
 		switch (apiConfiguration.apiProvider) {
@@ -99,6 +100,16 @@ export function validateModelId(
 					return "The model ID you provided is not available. Please choose a different model."
 				}
 				break
+			case "unbound":
+				const unboundModelId = apiConfiguration.unboundModelId || unboundDefaultModelId
+				if (!unboundModelId) {
+					return "You must provide a model ID."
+				}
+				if (unboundModels && !Object.keys(unboundModels).includes(unboundModelId)) {
+					// even if the model list endpoint failed, extensionstatecontext will always have the default model info
+					return "The model ID you provided is not available. Please choose a different model."
+				}
+				break
 		}
 	}
 	return undefined