import z from "zod"
import { App } from "../app/app"
import { Config } from "../config/config"
import { PROVIDER_DATABASE } from "./database"
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
import { Log } from "../util/log"
import path from "path"
import { Global } from "../global"
import { BunProc } from "../bun"
import { BashTool } from "../tool/bash"
import { EditTool } from "../tool/edit"
import { FetchTool } from "../tool/fetch"
import { GlobTool } from "../tool/glob"
import { GrepTool } from "../tool/grep"
import { ListTool } from "../tool/ls"
import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
import { LspHoverTool } from "../tool/lsp-hover"
import { PatchTool } from "../tool/patch"
import { ViewTool } from "../tool/view"
import type { Tool } from "../tool/tool"

/**
 * Discovery of LLM providers, lazy loading of their `@ai-sdk/*` SDK packages,
 * model lookup, and per-provider tool selection.
 */
export namespace Provider {
  const log = Log.create({ service: "provider" })

  /** Schema describing one model offered by a provider. */
  export const Model = z
    .object({
      id: z.string(),
      name: z.string().optional(),
      cost: z.object({
        input: z.number(),
        inputCached: z.number(),
        output: z.number(),
        outputCached: z.number(),
      }),
      contextWindow: z.number(),
      maxOutputTokens: z.number().optional(),
      attachment: z.boolean(),
      reasoning: z.boolean().optional(),
    })
    .openapi({
      ref: "Provider.Model",
    })
  export type Model = z.output<typeof Model>

  /** Schema describing a provider: its id, display name, SDK options and models. */
  export const Info = z
    .object({
      id: z.string(),
      name: z.string(),
      options: z.record(z.string(), z.any()).optional(),
      models: Model.array(),
    })
    .openapi({
      ref: "Provider.Info",
    })
  export type Info = z.output<typeof Info>

  /**
   * Environment variables that enable a built-in provider from PROVIDER_DATABASE.
   * A provider is auto-enabled when at least one of its variables is set.
   */
  const AUTODETECT: Record<string, string[]> = {
    anthropic: ["ANTHROPIC_API_KEY"],
    openai: ["OPENAI_API_KEY"],
    google: ["GOOGLE_GENERATIVE_AI_API_KEY", "GEMINI_API_KEY"],
  }

  /**
   * Provider state registered with App.state under the key "provider".
   * Holds the enabled providers plus caches of loaded models and SDK instances.
   * NOTE(review): presumably computed once per app instance — confirm App.state semantics.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const providers = new Map<string, Info>()
    const models = new Map<string, { info: Model; language: LanguageModel }>()
    const sdk = new Map<string, SDK>()

    for (const item of PROVIDER_DATABASE) {
      // Providers with no AUTODETECT entry cannot be auto-enabled; skip them
      // instead of crashing on `undefined.some(...)`.
      const envs = AUTODETECT[item.id]
      if (!envs?.some((env) => process.env[env])) continue
      providers.set(item.id, item)
    }

    // Explicitly configured providers are always enabled and replace any
    // autodetected entry with the same id.
    for (const item of config.provider ?? []) {
      providers.set(item.id, item)
    }

    return {
      models,
      providers,
      sdk,
    }
  })

  /** Returns the map of enabled providers, keyed by provider id. */
  export async function active() {
    return state().then((state) => state.providers)
  }

  /**
   * Returns (and caches) the AI SDK provider instance for `providerID`.
   * If `@ai-sdk/<providerID>` is not present in the global cache directory it is
   * installed there first, then the module's first export whose name starts
   * with "create" is invoked with the provider's configured options.
   *
   * @throws Error if the installed module exposes no `create*` factory.
   */
  async function getSDK(providerID: string) {
    const s = await state()
    const cached = s.sdk.get(providerID)
    if (cached) return cached

    const dir = path.join(Global.cache(), `node_modules`, `@ai-sdk`, providerID)
    if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
      log.info("installing", {
        providerID,
      })
      // Awaited so the package is on disk before the dynamic import below.
      // (If BunProc.run is synchronous, the await is a harmless no-op.)
      await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
        cwd: Global.cache(),
      })
    }

    const mod = await import(path.join(dir))
    const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
    if (!factoryName) {
      throw new Error(`no create* factory exported by @ai-sdk/${providerID}`)
    }
    const loaded = mod[factoryName](s.providers.get(providerID)?.options)
    s.sdk.set(providerID, loaded)
    return loaded as SDK
  }

  /**
   * Resolves a model by provider and model id, returning its metadata and the
   * SDK language model. Results are cached under "<providerID>/<modelID>".
   *
   * @throws ModelNotFoundError if the provider is not enabled, does not list the
   *   model, or the SDK reports NoSuchModelError. Other SDK errors are rethrown.
   */
  export async function getModel(providerID: string, modelID: string) {
    const key = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(key)
    if (cached) return cached

    log.info("loading", {
      providerID,
      modelID,
    })
    const provider = s.providers.get(providerID)
    if (!provider) throw new ModelNotFoundError(modelID)
    const info = provider.models.find((m) => m.id === modelID)
    if (!info) throw new ModelNotFoundError(modelID)

    const sdk = await getSDK(providerID)
    if (!sdk) throw new ModelNotFoundError(modelID)

    try {
      const language = sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      const result = {
        info,
        language,
      }
      s.models.set(key, result)
      return result
    } catch (e) {
      if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID)
      throw e
    }
  }

  /**
   * Picks the first enabled provider (in Map insertion order) and its first model.
   *
   * @throws Error if no provider is enabled or the chosen provider lists no models.
   */
  export async function defaultModel() {
    const providers = await active()
    const [provider] = providers.values()
    if (!provider) throw new Error("no providers found")
    const model = provider.models[0]
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.id,
      modelID: model.id,
    }
  }

  /** Every tool available to providers; each tool is listed exactly once. */
  const TOOLS = [
    BashTool,
    EditTool,
    FetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ViewTool,
  ]

  /** Default tool set per provider id; unknown providers fall back to TOOLS. */
  const TOOL_MAPPING: Record<string, typeof TOOLS> = {
    anthropic: TOOLS,
    openai: TOOLS,
    google: TOOLS,
  }

  /**
   * Returns the tools to expose for `providerID`. If the config has a tool list
   * for this provider (`tool.provider[providerID]`), those ids are resolved
   * against TOOLS. Ids that match no known tool are logged and skipped, so no
   * `undefined` entries are returned. Otherwise the per-provider default applies.
   */
  export async function tools(providerID: string) {
    const cfg = await Config.get()
    const configured = cfg.tool?.provider?.[providerID]
    if (configured) {
      const selected: typeof TOOLS = []
      for (const id of configured) {
        const tool = TOOLS.find((t) => t.id === id)
        if (!tool) {
          log.info("unknown tool", { providerID, id })
          continue
        }
        selected.push(tool)
      }
      return selected
    }
    return TOOL_MAPPING[providerID] ?? TOOLS
  }

  /** Thrown when a requested model cannot be resolved for a provider. */
  class ModelNotFoundError extends Error {
    constructor(public readonly model: string) {
      super(`model not found: ${model}`)
      this.name = "ModelNotFoundError"
    }
  }
}