| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474 |
- import z from "zod"
- import { App } from "../app/app"
- import { Config } from "../config/config"
- import { mergeDeep, sortBy } from "remeda"
- import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
- import { Log } from "../util/log"
- import { BunProc } from "../bun"
- import { BashTool } from "../tool/bash"
- import { EditTool } from "../tool/edit"
- import { WebFetchTool } from "../tool/webfetch"
- import { GlobTool } from "../tool/glob"
- import { GrepTool } from "../tool/grep"
- import { ListTool } from "../tool/ls"
- import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
- import { LspHoverTool } from "../tool/lsp-hover"
- import { PatchTool } from "../tool/patch"
- import { ReadTool } from "../tool/read"
- import type { Tool } from "../tool/tool"
- import { WriteTool } from "../tool/write"
- import { TodoReadTool, TodoWriteTool } from "../tool/todo"
- import { AuthAnthropic } from "../auth/anthropic"
- import { AuthGithubCopilot } from "../auth/github-copilot"
- import { ModelsDev } from "./models"
- import { NamedError } from "../util/error"
- import { Auth } from "../auth"
- import { TaskTool } from "../tool/task"
export namespace Provider {
  const log = Log.create({ service: "provider" })

  /**
   * Provider-specific initialisation hook.
   *
   * Called with `database[providerID]`, which can be `undefined` at runtime when
   * the id is not in the database. Returns `false` to leave the provider alone,
   * or SDK `options` to merge plus an optional `getModel` that replaces the
   * default `sdk.languageModel(modelID)` lookup.
   */
  type CustomLoader = (provider: ModelsDev.Provider) => Promise<
    | {
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    | false
  >

  /** Which loading step last contributed a provider's options. */
  type Source = "env" | "config" | "custom" | "api"

  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    /**
     * Active when AuthAnthropic yields an access token. Model costs are zeroed
     * and each request sends a bearer token plus the OAuth beta header instead
     * of `x-api-key`.
     */
    async anthropic(provider) {
      const access = await AuthAnthropic.access()
      if (!access) return false
      // Same guard as the github-copilot loader: the database entry may be absent.
      if (provider?.models) {
        for (const model of Object.values(provider.models)) {
          model.cost = {
            input: 0,
            output: 0,
          }
        }
      }
      return {
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            // Token is re-read per request (presumably so a refreshed token is used).
            const access = await AuthAnthropic.access()
            const headers = {
              ...init.headers,
              authorization: `Bearer ${access}`,
              "anthropic-beta": "oauth-2025-04-20",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    /**
     * Active when AuthGithubCopilot yields a token. Model costs are zeroed and
     * requests carry a bearer token plus Copilot editor identification headers.
     */
    "github-copilot": async (provider) => {
      const info = await AuthGithubCopilot.access()
      if (!info) return false
      if (provider && provider.models) {
        for (const model of Object.values(provider.models)) {
          model.cost = {
            input: 0,
            output: 0,
          }
        }
      }
      return {
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            const token = await AuthGithubCopilot.access()
            if (!token) throw new Error("GitHub Copilot authentication expired")
            const headers = {
              ...init.headers,
              Authorization: `Bearer ${token}`,
              "User-Agent": "GithubCopilot/1.155.0",
              "Editor-Version": "vscode/1.85.1",
              "Editor-Plugin-Version": "copilot/1.155.0",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    /** Resolves OpenAI models through `sdk.responses(modelID)`. */
    openai: async () => {
      return {
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    /**
     * Active only when AWS_PROFILE is set. Credentials come from the AWS node
     * provider chain; the region defaults to us-east-1.
     */
    "amazon-bedrock": async () => {
      if (!process.env["AWS_PROFILE"]) return false
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(
        await BunProc.install("@aws-sdk/credential-providers")
      )
      return {
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string) {
          let id = modelID
          // Claude models are addressed via cross-region inference profiles
          // whose ids carry a geography prefix ("us.", "eu.", "apac.").
          if (id.includes("claude")) {
            const area = region.split("-")[0]
            // Asia-Pacific regions are named "ap-*" but the profile prefix is "apac".
            const prefix = area === "ap" ? "apac" : area
            id = `${prefix}.${id}`
          }
          return sdk.languageModel(id)
        },
      }
    },
  }

  /**
   * Provider registry, created through `App.state("provider", …)`.
   *
   * Providers are merged in this order: environment variables, stored API keys
   * (`Auth.all()`), CUSTOM_LOADERS, then the config's `provider` section. Later
   * steps deep-merge their options over earlier ones. Ids listed in
   * `disabled_providers` are skipped by every step, and providers that end up
   * with no models are dropped.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()
    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<
      string,
      { info: ModelsDev.Model; language: LanguageModel }
    >()
    const sdk = new Map<string, SDK>()
    log.info("init")

    /**
     * Creates or updates `providers[id]`. A new entry is only created when the
     * id exists in `database`; its `baseURL` defaults to the database `api`,
     * but explicitly passed options take precedence. An existing entry gets the
     * options deep-merged and its source/getModel updated.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        providers[id] = {
          source,
          info,
          // Build a new object: never mutate caller-owned options (e.g. config's).
          options: info.api ? { baseURL: info.api, ...options } : options,
          // Keep the loader's override even when it is the first source seen.
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Fold config-defined providers/models into the database entries.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        // Preserve the database endpoint; mergeProvider uses it as default baseURL.
        api: existing?.api,
        models: existing?.models ?? {},
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existing = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existing?.name ?? modelID,
          attachment: model.attachment ?? existing?.attachment ?? false,
          reasoning: model.reasoning ?? existing?.reasoning ?? false,
          temperature: model.temperature ?? existing?.temperature ?? false,
          tool_call: model.tool_call ?? existing?.tool_call ?? true,
          cost: {
            // Zero defaults first so database and configured prices win.
            input: 0,
            output: 0,
            cache_read: 0,
            cache_write: 0,
            ...existing?.cost,
            ...model.cost,
          },
          options: {
            ...existing?.options,
            ...model.options,
          },
          limit: model.limit ??
            existing?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      if (provider.env.some((item) => process.env[item])) {
        mergeProvider(providerID, {}, "env")
      }
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const result = await fn(database[providerID])
      if (result) {
        mergeProvider(providerID, result.options, "custom", result.getModel)
      }
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })

  /** Returns the resolved providers, keyed by provider id. */
  export async function list() {
    return state().then((state) => state.providers)
  }

  /**
   * Loads (and caches per provider id) the SDK instance for `provider`.
   *
   * Installs `provider.npm ?? provider.id` at "latest", imports it, and calls
   * its first export whose name starts with "create" with the provider's
   * merged options. Any failure is rethrown as `InitError` with the original
   * error as `cause`.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    return (async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "latest"))
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
      if (!factoryName) throw new Error(`package ${pkg} has no create* export`)
      const loaded = mod[factoryName](s.providers[provider.id]?.options)
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    })().catch((e) => {
      throw new InitError({ providerID: provider.id }, { cause: e })
    })
  }

  /**
   * Resolves a language model, caching the result under "providerID/modelID".
   *
   * Uses the provider's custom `getModel` when present, otherwise
   * `sdk.languageModel(modelID)`.
   *
   * @throws ModelNotFoundError when the provider or model is unknown, or the
   *   SDK raises `NoSuchModelError`.
   * @throws InitError when the provider SDK cannot be loaded.
   */
  export async function getModel(providerID: string, modelID: string) {
    const key = `${providerID}/${modelID}`
    const s = await state()
    if (s.models.has(key)) return s.models.get(key)!
    log.info("getModel", {
      providerID,
      modelID,
    })
    const provider = s.providers[providerID]
    if (!provider) throw new ModelNotFoundError({ providerID, modelID })
    const info = provider.info.models[modelID]
    if (!info) throw new ModelNotFoundError({ providerID, modelID })
    const sdk = await getSDK(provider.info)
    try {
      const language = provider.getModel
        ? await provider.getModel(sdk, modelID)
        : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      s.models.set(key, {
        info,
        language,
      })
      return {
        info,
        language,
      }
    } catch (e) {
      if (e instanceof NoSuchModelError)
        throw new ModelNotFoundError(
          {
            modelID: modelID,
            providerID,
          },
          { cause: e },
        )
      throw e
    }
  }

  /** Model-id substrings to prefer; later entries rank higher in `sort`. */
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Orders models for default selection: ids matching a later `priority`
   * entry first (non-matching ids last), then ids containing "latest", then
   * by id descending. Returns a new array.
   */
  export function sort(models: ModelsDev.Model[]) {
    return sortBy(
      models,
      [
        (model) => priority.findIndex((filter) => model.id.includes(filter)),
        "desc",
      ],
      [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
      [(model) => model.id, "desc"],
    )
  }

  /**
   * Picks the default model: `config.model` when set; otherwise the top
   * `sort`ed model of the first resolved provider that is listed in the
   * config's `provider` section (any provider if that section is absent).
   *
   * @throws Error "no providers found" / "no models found".
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)
    const provider = await list()
      .then((val) => Object.values(val))
      .then((x) =>
        x.find(
          (p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id),
        ),
      )
    if (!provider) throw new Error("no providers found")
    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.info.id,
      modelID: model.id,
    }
  }

  /**
   * Splits "provider/model" at the first "/"; the rest (which may itself
   * contain "/") is the model id.
   */
  export function parseModel(model: string) {
    const [providerID, ...rest] = model.split("/")
    return {
      providerID: providerID,
      modelID: rest.join("/"),
    }
  }

  /** Tools offered to providers; each tool is listed exactly once. */
  const TOOLS = [
    BashTool,
    EditTool,
    WebFetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ReadTool,
    // MultiEditTool,
    WriteTool,
    TodoWriteTool,
    TaskTool,
    TodoReadTool,
  ]

  /**
   * Per-provider tool sets. Anthropic omits the patch tool; OpenAI and Azure
   * get parameter schemas with optional fields rewritten as nullable
   * (presumably for their strict function-calling schema rules).
   */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = {
    anthropic: TOOLS.filter((t) => t.id !== "patch"),
    openai: TOOLS.map((t) => ({
      ...t,
      parameters: optionalToNullable(t.parameters),
    })),
    azure: TOOLS.map((t) => ({
      ...t,
      parameters: optionalToNullable(t.parameters),
    })),
    google: TOOLS,
  }

  /** Returns the tool set for `providerID`, falling back to every tool. */
  export async function tools(providerID: string) {
    /*
    const cfg = await Config.get()
    if (cfg.tool?.provider?.[providerID])
      return cfg.tool.provider[providerID].map(
        (id) => TOOLS.find((t) => t.id === id)!,
      )
    */
    return TOOL_MAPPING[providerID] ?? TOOLS
  }

  /** Copies `source`'s description onto `target` if `target` has none. */
  function keepDescription(
    target: z.ZodTypeAny,
    source: z.ZodTypeAny,
  ): z.ZodTypeAny {
    return source.description && !target.description
      ? target.describe(source.description)
      : target
  }

  /**
   * Recursively rewrites `.optional()` object fields as `.nullable()` inside
   * objects, arrays and unions, keeping schema descriptions (the text models
   * see). Other schema kinds are returned unchanged.
   */
  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
    if (schema instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {}
      for (const [key, value] of Object.entries(schema.shape)) {
        const field = value as z.ZodTypeAny
        newShape[key] =
          field instanceof z.ZodOptional
            ? keepDescription(optionalToNullable(field.unwrap()).nullable(), field)
            : optionalToNullable(field)
      }
      return keepDescription(z.object(newShape), schema)
    }
    if (schema instanceof z.ZodArray) {
      return keepDescription(z.array(optionalToNullable(schema.element)), schema)
    }
    if (schema instanceof z.ZodUnion) {
      return keepDescription(
        z.union(
          schema.options.map((option: z.ZodTypeAny) =>
            optionalToNullable(option),
          ) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]],
        ),
        schema,
      )
    }
    return schema
  }

  /** Raised when a provider or model id cannot be resolved. */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({
      providerID: z.string(),
      modelID: z.string(),
    }),
  )

  /** Raised when a provider's SDK package cannot be installed or created. */
  export const InitError = NamedError.create(
    "ProviderInitError",
    z.object({
      providerID: z.string(),
    }),
  )

  export const AuthError = NamedError.create(
    "ProviderAuthError",
    z.object({
      providerID: z.string(),
      message: z.string(),
    }),
  )
}
|