| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import z from "zod"
- import { App } from "../app/app"
- import { Config } from "../config/config"
- import { PROVIDER_DATABASE } from "./database"
- import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
- import { Log } from "../util/log"
- import path from "path"
- import { Global } from "../global"
- import { BunProc } from "../bun"
- import { BashTool } from "../tool/bash"
- import { EditTool } from "../tool/edit"
- import { FetchTool } from "../tool/fetch"
- import { GlobTool } from "../tool/glob"
- import { GrepTool } from "../tool/grep"
- import { ListTool } from "../tool/ls"
- import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
- import { LspHoverTool } from "../tool/lsp-hover"
- import { PatchTool } from "../tool/patch"
- import { ViewTool } from "../tool/view"
- import type { Tool } from "../tool/tool"
export namespace Provider {
  const log = Log.create({ service: "provider" })

  /** Schema for a single model offered by a provider (pricing, limits, capabilities). */
  export const Model = z
    .object({
      id: z.string(),
      name: z.string().optional(),
      cost: z.object({
        input: z.number(),
        inputCached: z.number(),
        output: z.number(),
        outputCached: z.number(),
      }),
      contextWindow: z.number(),
      maxOutputTokens: z.number().optional(),
      attachment: z.boolean(),
      reasoning: z.boolean().optional(),
    })
    .openapi({
      ref: "Provider.Model",
    })
  export type Model = z.output<typeof Model>

  /** Schema for a provider: its id, display name, optional SDK options, and models. */
  export const Info = z
    .object({
      id: z.string(),
      name: z.string(),
      options: z.record(z.string(), z.any()).optional(),
      models: Model.array(),
    })
    .openapi({
      ref: "Provider.Info",
    })
  export type Info = z.output<typeof Info>

  /**
   * Environment variables that enable a built-in provider from PROVIDER_DATABASE.
   * A provider is enabled when any listed variable is set to a non-empty value.
   * Database entries without a mapping here are never auto-enabled.
   */
  const AUTODETECT: Record<string, string[]> = {
    anthropic: ["ANTHROPIC_API_KEY"],
    openai: ["OPENAI_API_KEY"],
    google: ["GOOGLE_GENERATIVE_AI_API_KEY", "GEMINI_API_KEY"],
  }

  /**
   * Lazily built provider state, created through `App.state`.
   * - `providers`: enabled providers keyed by id.
   * - `models`: cache of resolved models keyed by "<providerID>/<modelID>".
   * - `sdk`: cache of instantiated SDKs keyed by provider id.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const providers = new Map<string, Info>()
    const models = new Map<string, { info: Model; language: LanguageModel }>()
    const sdk = new Map<string, SDK>()

    for (const item of PROVIDER_DATABASE) {
      // Fall back to [] so an entry with no AUTODETECT mapping is skipped
      // instead of throwing a TypeError while indexing.
      const envVars = AUTODETECT[item.id] ?? []
      if (!envVars.some((env) => process.env[env])) continue
      providers.set(item.id, item)
    }

    // Providers from config are always enabled. They replace any
    // auto-detected entry that has the same id.
    for (const item of config.provider ?? []) {
      providers.set(item.id, item)
    }

    return {
      models,
      providers,
      sdk,
    }
  })

  /** Returns the enabled providers, keyed by provider id. */
  export async function active() {
    const s = await state()
    return s.providers
  }

  /**
   * Returns the SDK instance for `providerID`, creating and caching it on first use.
   *
   * If `@ai-sdk/<providerID>` is not present in the global cache's node_modules,
   * it is installed there with `bun add` before being imported. The package's
   * first export whose name starts with "create" is called with the provider's
   * configured `options`.
   *
   * @throws Error if the package exports no `create*` factory.
   */
  async function getSDK(providerID: string): Promise<SDK> {
    const s = await state()
    const cached = s.sdk.get(providerID)
    if (cached) return cached

    const dir = path.join(Global.cache(), "node_modules", "@ai-sdk", providerID)
    if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
      log.info("installing", {
        providerID,
      })
      // Await the install: otherwise the dynamic import below races it and
      // fails the first time a provider is used.
      await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
        cwd: Global.cache(),
      })
    }

    const mod = await import(dir)
    const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
    if (!factoryName) {
      throw new Error(`@ai-sdk/${providerID} does not export a create* factory`)
    }
    const loaded: SDK = mod[factoryName](s.providers.get(providerID)?.options)
    s.sdk.set(providerID, loaded)
    return loaded
  }

  /**
   * Resolves a model to its metadata and SDK language model. Results are cached
   * per "<providerID>/<modelID>".
   *
   * @throws ModelNotFoundError if the provider is not enabled, the model is not
   *   listed for it, or the SDK reports `NoSuchModelError`.
   */
  export async function getModel(providerID: string, modelID: string) {
    const key = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(key)
    if (cached) return cached

    log.info("loading", {
      providerID,
      modelID,
    })
    const provider = s.providers.get(providerID)
    if (!provider) throw new ModelNotFoundError(modelID, providerID)
    const info = provider.models.find((m) => m.id === modelID)
    if (!info) throw new ModelNotFoundError(modelID, providerID)
    const sdk = await getSDK(providerID)
    if (!sdk) throw new ModelNotFoundError(modelID, providerID)

    try {
      const language = sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      const result = {
        info,
        language,
      }
      s.models.set(key, result)
      return result
    } catch (e) {
      // Translate the SDK's own "unknown model" error into this module's error type.
      if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID, providerID)
      throw e
    }
  }

  /**
   * Picks the first enabled provider (in Map insertion order) and its first
   * listed model.
   *
   * @throws Error if no provider is enabled or that provider lists no models.
   */
  export async function defaultModel() {
    const providers = await active()
    const [provider] = [...providers.values()]
    if (!provider) throw new Error("no providers found")
    const model = provider.models[0]
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.id,
      modelID: model.id,
    }
  }

  /** Every tool available to providers; each tool appears exactly once. */
  const TOOLS = [
    BashTool,
    EditTool,
    FetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ViewTool,
  ]

  /** Default tool set per provider id; ids not listed fall back to TOOLS. */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = {
    anthropic: TOOLS,
    openai: TOOLS,
    google: TOOLS,
  }

  /**
   * Returns the tools to expose for `providerID`. If config lists tool ids for
   * the provider (`tool.provider[providerID]`), exactly those tools are returned,
   * in config order. Otherwise TOOL_MAPPING is used, falling back to all TOOLS.
   *
   * @throws Error if config names a tool id that is not in TOOLS.
   */
  export async function tools(providerID: string) {
    const cfg = await Config.get()
    const configured = cfg.tool?.provider?.[providerID]
    if (configured) {
      return configured.map((id) => {
        const tool = TOOLS.find((t) => t.id === id)
        // Fail loudly here rather than passing an undefined entry downstream.
        if (!tool) {
          throw new Error(`unknown tool "${id}" configured for provider "${providerID}"`)
        }
        return tool
      })
    }
    return TOOL_MAPPING[providerID] ?? TOOLS
  }

  /** Raised when a requested model cannot be resolved for a provider. */
  class ModelNotFoundError extends Error {
    constructor(
      public readonly model: string,
      public readonly providerID?: string,
    ) {
      super(providerID ? `model not found: ${providerID}/${model}` : `model not found: ${model}`)
      this.name = "ModelNotFoundError"
    }
  }
}
|