provider.ts 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { BashTool } from "../tool/bash"
  9. import { EditTool } from "../tool/edit"
  10. import { WebFetchTool } from "../tool/webfetch"
  11. import { GlobTool } from "../tool/glob"
  12. import { GrepTool } from "../tool/grep"
  13. import { ListTool } from "../tool/ls"
  14. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  15. import { LspHoverTool } from "../tool/lsp-hover"
  16. import { PatchTool } from "../tool/patch"
  17. import { ReadTool } from "../tool/read"
  18. import type { Tool } from "../tool/tool"
  19. import { WriteTool } from "../tool/write"
  20. import { TodoReadTool, TodoWriteTool } from "../tool/todo"
  21. import { AuthAnthropic } from "../auth/anthropic"
  22. import { ModelsDev } from "./models"
  23. import { NamedError } from "../util/error"
  24. import { Auth } from "../auth"
  25. import { TaskTool } from "../tool/task"
  26. export namespace Provider {
  27. const log = Log.create({ service: "provider" })
  28. type CustomLoader = (provider: ModelsDev.Provider) => Promise<
  29. | {
  30. getModel?: (sdk: any, modelID: string) => Promise<any>
  31. options: Record<string, any>
  32. }
  33. | false
  34. >
  35. type Source = "env" | "config" | "custom" | "api"
  36. const CUSTOM_LOADERS: Record<string, CustomLoader> = {
  37. async anthropic(provider) {
  38. const access = await AuthAnthropic.access()
  39. if (!access) return false
  40. for (const model of Object.values(provider.models)) {
  41. model.cost = {
  42. input: 0,
  43. output: 0,
  44. }
  45. }
  46. return {
  47. options: {
  48. apiKey: "",
  49. async fetch(input: any, init: any) {
  50. const access = await AuthAnthropic.access()
  51. const headers = {
  52. ...init.headers,
  53. authorization: `Bearer ${access}`,
  54. "anthropic-beta": "oauth-2025-04-20",
  55. }
  56. delete headers["x-api-key"]
  57. return fetch(input, {
  58. ...init,
  59. headers,
  60. })
  61. },
  62. },
  63. }
  64. },
  65. openai: async () => {
  66. return {
  67. async getModel(sdk: any, modelID: string) {
  68. return sdk.responses(modelID)
  69. },
  70. options: {},
  71. }
  72. },
  73. "amazon-bedrock": async () => {
  74. if (!process.env["AWS_PROFILE"]) false
  75. const region = process.env["AWS_REGION"] ?? "us-east-1"
  76. const { fromNodeProviderChain } = await import(
  77. await BunProc.install("@aws-sdk/credential-providers")
  78. )
  79. return {
  80. options: {
  81. region,
  82. credentialProvider: fromNodeProviderChain(),
  83. },
  84. async getModel(sdk: any, modelID: string) {
  85. if (modelID.includes("claude")) {
  86. const prefix = region.split("-")[0]
  87. modelID = `${prefix}.${modelID}`
  88. }
  89. return sdk.languageModel(modelID)
  90. },
  91. }
  92. },
  93. }
  94. const state = App.state("provider", async () => {
  95. const config = await Config.get()
  96. const database = await ModelsDev.get()
  97. const providers: {
  98. [providerID: string]: {
  99. source: Source
  100. info: ModelsDev.Provider
  101. getModel?: (sdk: any, modelID: string) => Promise<any>
  102. options: Record<string, any>
  103. }
  104. } = {}
  105. const models = new Map<
  106. string,
  107. { info: ModelsDev.Model; language: LanguageModel }
  108. >()
  109. const sdk = new Map<string, SDK>()
  110. log.info("init")
  111. function mergeProvider(
  112. id: string,
  113. options: Record<string, any>,
  114. source: Source,
  115. getModel?: (sdk: any, modelID: string) => Promise<any>,
  116. ) {
  117. const provider = providers[id]
  118. if (!provider) {
  119. const info = database[id]
  120. if (!info) return
  121. if (info.api) options["baseURL"] = info.api
  122. providers[id] = {
  123. source,
  124. info,
  125. options,
  126. }
  127. return
  128. }
  129. provider.options = mergeDeep(provider.options, options)
  130. provider.source = source
  131. provider.getModel = getModel ?? provider.getModel
  132. }
  133. const configProviders = Object.entries(config.provider ?? {})
  134. for (const [providerID, provider] of configProviders) {
  135. const existing = database[providerID]
  136. const parsed: ModelsDev.Provider = {
  137. id: providerID,
  138. npm: provider.npm ?? existing?.npm,
  139. name: provider.name ?? existing?.name ?? providerID,
  140. env: provider.env ?? existing?.env ?? [],
  141. models: existing?.models ?? {},
  142. }
  143. for (const [modelID, model] of Object.entries(provider.models ?? {})) {
  144. const existing = parsed.models[modelID]
  145. const parsedModel: ModelsDev.Model = {
  146. id: modelID,
  147. name: model.name ?? existing?.name ?? modelID,
  148. attachment: model.attachment ?? existing?.attachment ?? false,
  149. reasoning: model.reasoning ?? existing?.reasoning ?? false,
  150. temperature: model.temperature ?? existing?.temperature ?? false,
  151. cost: model.cost ??
  152. existing?.cost ?? {
  153. input: 0,
  154. output: 0,
  155. inputCached: 0,
  156. outputCached: 0,
  157. },
  158. limit: model.limit ??
  159. existing?.limit ?? {
  160. context: 0,
  161. output: 0,
  162. },
  163. }
  164. parsed.models[modelID] = parsedModel
  165. }
  166. database[providerID] = parsed
  167. }
  168. const disabled = await Config.get().then(
  169. (cfg) => new Set(cfg.disabled_providers ?? []),
  170. )
  171. // load env
  172. for (const [providerID, provider] of Object.entries(database)) {
  173. if (disabled.has(providerID)) continue
  174. if (provider.env.some((item) => process.env[item])) {
  175. mergeProvider(providerID, {}, "env")
  176. }
  177. }
  178. // load apikeys
  179. for (const [providerID, provider] of Object.entries(await Auth.all())) {
  180. if (disabled.has(providerID)) continue
  181. if (provider.type === "api") {
  182. mergeProvider(providerID, { apiKey: provider.key }, "api")
  183. }
  184. }
  185. // load custom
  186. for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
  187. if (disabled.has(providerID)) continue
  188. const result = await fn(database[providerID])
  189. if (result)
  190. mergeProvider(providerID, result.options, "custom", result.getModel)
  191. }
  192. // load config
  193. for (const [providerID, provider] of configProviders) {
  194. mergeProvider(providerID, provider.options ?? {}, "config")
  195. }
  196. for (const [providerID, provider] of Object.entries(providers)) {
  197. if (Object.keys(provider.info.models).length === 0) {
  198. delete providers[providerID]
  199. continue
  200. }
  201. log.info("found", { providerID })
  202. }
  203. return {
  204. models,
  205. providers,
  206. sdk,
  207. }
  208. })
  209. export async function list() {
  210. return state().then((state) => state.providers)
  211. }
  212. async function getSDK(provider: ModelsDev.Provider) {
  213. return (async () => {
  214. using _ = log.time("getSDK", {
  215. providerID: provider.id,
  216. })
  217. const s = await state()
  218. const existing = s.sdk.get(provider.id)
  219. if (existing) return existing
  220. const pkg = provider.npm ?? provider.id
  221. const mod = await import(await BunProc.install(pkg, "latest"))
  222. const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
  223. const loaded = fn(s.providers[provider.id]?.options)
  224. s.sdk.set(provider.id, loaded)
  225. return loaded as SDK
  226. })().catch((e) => {
  227. throw new InitError({ providerID: provider.id }, { cause: e })
  228. })
  229. }
  230. export async function getModel(providerID: string, modelID: string) {
  231. const key = `${providerID}/${modelID}`
  232. const s = await state()
  233. if (s.models.has(key)) return s.models.get(key)!
  234. log.info("getModel", {
  235. providerID,
  236. modelID,
  237. })
  238. const provider = s.providers[providerID]
  239. if (!provider) throw new ModelNotFoundError({ providerID, modelID })
  240. const info = provider.info.models[modelID]
  241. if (!info) throw new ModelNotFoundError({ providerID, modelID })
  242. const sdk = await getSDK(provider.info)
  243. try {
  244. const language = provider.getModel
  245. ? await provider.getModel(sdk, modelID)
  246. : sdk.languageModel(modelID)
  247. log.info("found", { providerID, modelID })
  248. s.models.set(key, {
  249. info,
  250. language,
  251. })
  252. return {
  253. info,
  254. language,
  255. }
  256. } catch (e) {
  257. if (e instanceof NoSuchModelError)
  258. throw new ModelNotFoundError(
  259. {
  260. modelID: modelID,
  261. providerID,
  262. },
  263. { cause: e },
  264. )
  265. throw e
  266. }
  267. }
  268. const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]
  269. export function sort(models: ModelsDev.Model[]) {
  270. return sortBy(
  271. models,
  272. [
  273. (model) => priority.findIndex((filter) => model.id.includes(filter)),
  274. "desc",
  275. ],
  276. [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
  277. [(model) => model.id, "desc"],
  278. )
  279. }
  280. export async function defaultModel() {
  281. const cfg = await Config.get()
  282. if (cfg.model) return parseModel(cfg.model)
  283. const provider = await list()
  284. .then((val) => Object.values(val))
  285. .then((x) =>
  286. x.find(
  287. (p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id),
  288. ),
  289. )
  290. if (!provider) throw new Error("no providers found")
  291. const [model] = sort(Object.values(provider.info.models))
  292. if (!model) throw new Error("no models found")
  293. return {
  294. providerID: provider.info.id,
  295. modelID: model.id,
  296. }
  297. }
  298. export function parseModel(model: string) {
  299. const [providerID, ...rest] = model.split("/")
  300. return {
  301. providerID: providerID,
  302. modelID: rest.join("/"),
  303. }
  304. }
  305. const TOOLS = [
  306. BashTool,
  307. EditTool,
  308. WebFetchTool,
  309. GlobTool,
  310. GrepTool,
  311. ListTool,
  312. LspDiagnosticTool,
  313. LspHoverTool,
  314. PatchTool,
  315. ReadTool,
  316. EditTool,
  317. // MultiEditTool,
  318. WriteTool,
  319. TodoWriteTool,
  320. TaskTool,
  321. TodoReadTool,
  322. ]
  323. const TOOL_MAPPING: Record<string, Tool.Info[]> = {
  324. anthropic: TOOLS.filter((t) => t.id !== "patch"),
  325. openai: TOOLS.map((t) => ({
  326. ...t,
  327. parameters: optionalToNullable(t.parameters),
  328. })),
  329. azure: TOOLS.map((t) => ({
  330. ...t,
  331. parameters: optionalToNullable(t.parameters),
  332. })),
  333. google: TOOLS,
  334. }
  335. export async function tools(providerID: string) {
  336. /*
  337. const cfg = await Config.get()
  338. if (cfg.tool?.provider?.[providerID])
  339. return cfg.tool.provider[providerID].map(
  340. (id) => TOOLS.find((t) => t.id === id)!,
  341. )
  342. */
  343. return TOOL_MAPPING[providerID] ?? TOOLS
  344. }
  345. function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
  346. if (schema instanceof z.ZodObject) {
  347. const shape = schema.shape
  348. const newShape: Record<string, z.ZodTypeAny> = {}
  349. for (const [key, value] of Object.entries(shape)) {
  350. const zodValue = value as z.ZodTypeAny
  351. if (zodValue instanceof z.ZodOptional) {
  352. newShape[key] = zodValue.unwrap().nullable()
  353. } else {
  354. newShape[key] = optionalToNullable(zodValue)
  355. }
  356. }
  357. return z.object(newShape)
  358. }
  359. if (schema instanceof z.ZodArray) {
  360. return z.array(optionalToNullable(schema.element))
  361. }
  362. if (schema instanceof z.ZodUnion) {
  363. return z.union(
  364. schema.options.map((option: z.ZodTypeAny) =>
  365. optionalToNullable(option),
  366. ) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]],
  367. )
  368. }
  369. return schema
  370. }
  371. export const ModelNotFoundError = NamedError.create(
  372. "ProviderModelNotFoundError",
  373. z.object({
  374. providerID: z.string(),
  375. modelID: z.string(),
  376. }),
  377. )
  378. export const InitError = NamedError.create(
  379. "ProviderInitError",
  380. z.object({
  381. providerID: z.string(),
  382. }),
  383. )
  384. export const AuthError = NamedError.create(
  385. "ProviderAuthError",
  386. z.object({
  387. providerID: z.string(),
  388. message: z.string(),
  389. }),
  390. )
  391. }