provider.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { BashTool } from "../tool/bash"
  9. import { EditTool } from "../tool/edit"
  10. import { WebFetchTool } from "../tool/webfetch"
  11. import { GlobTool } from "../tool/glob"
  12. import { GrepTool } from "../tool/grep"
  13. import { ListTool } from "../tool/ls"
  14. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  15. import { LspHoverTool } from "../tool/lsp-hover"
  16. import { PatchTool } from "../tool/patch"
  17. import { ReadTool } from "../tool/read"
  18. import type { Tool } from "../tool/tool"
  19. import { WriteTool } from "../tool/write"
  20. import { TodoReadTool, TodoWriteTool } from "../tool/todo"
  21. import { AuthAnthropic } from "../auth/anthropic"
  22. import { AuthGithubCopilot } from "../auth/github-copilot"
  23. import { ModelsDev } from "./models"
  24. import { NamedError } from "../util/error"
  25. import { Auth } from "../auth"
  26. import { TaskTool } from "../tool/task"
  27. export namespace Provider {
  /** Logger for this module, tagged with service "provider". */
  const log = Log.create({ service: "provider" })

  /**
   * What a custom loader contributes when it activates: SDK constructor options, plus an
   * optional override for turning a model ID into a language model.
   */
  type CustomLoaderResult = {
    getModel?: (sdk: any, modelID: string) => Promise<any>
    options: Record<string, any>
  }

  /**
   * Provider-specific setup hook, called with the provider's catalog entry.
   * Resolving to `false` means the loader contributes nothing for this provider.
   */
  type CustomLoader = (
    provider: ModelsDev.Provider,
  ) => Promise<CustomLoaderResult | false>

  /** Which loading pass registered (or last updated) a provider. */
  type Source = "env" | "config" | "custom" | "api"
  /**
   * Provider-specific loaders, keyed by provider ID.
   *
   * During state initialisation each loader whose provider is not disabled is called with that
   * provider's catalog entry. It either returns `false` (contribute nothing) or SDK options plus
   * an optional `getModel` override, which are merged into the registry with source "custom".
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    /**
     * Anthropic via AuthAnthropic tokens. Active only when an access token is available.
     * Every request is re-signed with a freshly fetched bearer token, the `anthropic-beta`
     * header is set, and `x-api-key` is removed.
     */
    async anthropic(provider) {
      const access = await AuthAnthropic.access()
      if (!access) return false
      // Report zero per-token cost for this provider's models. Guarded like the
      // github-copilot loader: the catalog entry may be absent.
      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
      return {
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            const access = await AuthAnthropic.access()
            // Fail clearly instead of sending "Bearer undefined".
            if (!access) throw new Error("Anthropic authentication expired")
            const headers = {
              // `init` may be omitted by callers of fetch.
              ...init?.headers,
              authorization: `Bearer ${access}`,
              "anthropic-beta": "oauth-2025-04-20",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    /**
     * GitHub Copilot via AuthGithubCopilot tokens. Active only when a token is available.
     * Every request is re-signed with a fresh token plus Copilot client headers, and
     * `x-api-key` is removed.
     */
    "github-copilot": async (provider) => {
      const info = await AuthGithubCopilot.access()
      if (!info) return false
      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
      return {
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            const token = await AuthGithubCopilot.access()
            if (!token) throw new Error("GitHub Copilot authentication expired")
            const headers = {
              ...init?.headers,
              Authorization: `Bearer ${token}`,
              "User-Agent": "GithubCopilot/1.155.0",
              "Editor-Version": "vscode/1.85.1",
              "Editor-Plugin-Version": "copilot/1.155.0",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    /**
     * OpenAI: create models through the SDK's `responses` factory instead of `languageModel`.
     * NOTE(review): this loader always returns a result, so "openai" is registered by the
     * custom pass whenever it is in the catalog, even with no credentials — confirm intended.
     */
    openai: async () => {
      return {
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    /**
     * Amazon Bedrock. Active only when AWS_PROFILE is set. Credentials come from the AWS
     * node credential-provider chain; the region comes from AWS_REGION (default us-east-1).
     */
    "amazon-bedrock": async () => {
      if (!process.env["AWS_PROFILE"]) return false
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(
        await BunProc.install("@aws-sdk/credential-providers")
      )
      return {
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string) {
          if (modelID.includes("claude")) {
            // Claude IDs are prefixed with the cross-region inference geography taken from the
            // region ("us-east-1" -> "us."). AWS names the Asia-Pacific geography "apac", not
            // the "ap" region prefix.
            const geography = region.split("-")[0]
            const prefix = geography === "ap" ? "apac" : geography
            modelID = `${prefix}.${modelID}`
          }
          return sdk.languageModel(modelID)
        },
      }
    },
  }
  /**
   * Provider registry for the app, created through `App.state("provider", …)`
   * (presumably built once and cached per app — see App.state).
   *
   * The registry is built in these passes (later passes merge over earlier ones):
   *   1. config `provider` entries are folded into the models.dev catalog (`database`);
   *   2. "env"    – providers with at least one of their `env` variables set;
   *   3. "api"    – providers with a stored `Auth` entry of type "api" (its key becomes `apiKey`);
   *   4. "custom" – providers whose CUSTOM_LOADERS entry returns a result;
   *   5. "config" – `options` from config `provider` entries.
   *
   * Providers listed in `disabled_providers` are skipped by every pass. Providers that end up
   * with no models are dropped.
   *
   * Resolves to `{ models, providers, sdk }`. `models` and `sdk` start empty and are used as
   * caches by getModel() and getSDK().
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()

    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<
      string,
      { info: ModelsDev.Model; language: LanguageModel }
    >()
    const sdk = new Map<string, SDK>()

    log.info("init")

    /**
     * Registers provider `id`, or merges into its existing registration.
     *
     * A new registration requires a catalog entry (unknown IDs are ignored). It starts from a
     * copy of `options`, and `baseURL` defaults to the catalog's `api` when that is set.
     *
     * An existing registration deep-merges `options` over its current options, takes the new
     * `source`, and replaces `getModel` only when one is supplied.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        providers[id] = {
          source,
          info,
          // Copy instead of mutating the caller's object (it may be the user's config), and let
          // an explicit baseURL win over the catalog default, as it does in the merge path.
          options: info.api ? { baseURL: info.api, ...options } : { ...options },
          // Keep the loader's model factory even when this pass is the first to register it.
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Fold providers/models declared in config into the catalog so the passes below see them.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        // Keep the catalog endpoint; mergeProvider derives the default baseURL from it.
        api: existing?.api,
        // Copy so adding config models does not mutate the catalog's own models object.
        models: { ...existing?.models },
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existingModel = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existingModel?.name ?? modelID,
          attachment: model.attachment ?? existingModel?.attachment ?? false,
          reasoning: model.reasoning ?? existingModel?.reasoning ?? false,
          temperature: model.temperature ?? existingModel?.temperature ?? false,
          tool_call: model.tool_call ?? existingModel?.tool_call ?? true,
          // Zero is only the fallback: catalog prices, then configured prices, override it.
          cost: {
            input: 0,
            output: 0,
            cache_read: 0,
            cache_write: 0,
            ...existingModel?.cost,
            ...model.cost,
          },
          options: {
            ...existingModel?.options,
            ...model.options,
          },
          limit: model.limit ??
            existingModel?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // "env": the provider has at least one of its `env` variables set.
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      if (provider.env.some((item) => process.env[item])) {
        mergeProvider(providerID, {}, "env")
      }
    }

    // "api": a stored API key exists for the provider.
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // "custom": provider-specific loaders.
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const result = await fn(database[providerID])
      if (result) {
        mergeProvider(providerID, result.options, "custom", result.getModel)
      }
    }

    // "config": user-supplied options. Disabled providers stay disabled here too.
    for (const [providerID, provider] of configProviders) {
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // A provider without models cannot be used, so drop it.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })
  /** Returns the loaded provider registry, keyed by provider ID. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Returns the AI SDK provider instance for `provider`.
   *
   * The instance is built on first use and cached in `state().sdk` by provider ID. To build it,
   * the package `provider.npm` (falling back to the provider ID) is installed at "latest" via
   * BunProc and imported. Its first export whose name starts with "create" is then called with
   * the provider's registered options.
   *
   * @throws InitError, with the original failure as `cause`, when installing, importing or
   *   constructing the SDK fails, or when the package exports no `create*` factory.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    try {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "latest"))
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
      // Explicit message instead of an opaque "fn is not a function" TypeError.
      if (!factoryName) {
        throw new Error(`package "${pkg}" has no create* provider factory export`)
      }
      const loaded = mod[factoryName](s.providers[provider.id]?.options)
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    } catch (e) {
      throw new InitError({ providerID: provider.id }, { cause: e })
    }
  }
  /**
   * Resolves `providerID` / `modelID` to the model's catalog info and a language model instance.
   *
   * Results are cached in `state().models` under "providerID/modelID". The instance comes from
   * the provider's custom `getModel` when it has one, otherwise from `sdk.languageModel(modelID)`.
   *
   * @throws ModelNotFoundError when the provider is not loaded, does not list the model, or the
   *   SDK raises NoSuchModelError.
   * @throws InitError (from getSDK) when the provider SDK cannot be loaded.
   */
  export async function getModel(providerID: string, modelID: string) {
    const cacheKey = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(cacheKey)
    if (cached) return cached

    log.info("getModel", { providerID, modelID })

    const provider = s.providers[providerID]
    const info = provider?.info.models[modelID]
    if (!provider || !info) throw new ModelNotFoundError({ providerID, modelID })

    const sdk = await getSDK(provider.info)
    try {
      const language = provider.getModel
        ? await provider.getModel(sdk, modelID)
        : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      const entry = { info, language }
      s.models.set(cacheKey, entry)
      return { ...entry }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError({ modelID, providerID }, { cause: e })
    }
  }
  // Preferred model-ID substrings; an entry later in the list ranks higher in sort().
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Orders models best-first. Models matching a later `priority` entry come first
   * ("claude-sonnet-4" > "codex-mini" > "gemini-2.5-pro-preview" > no match). Ties go to IDs
   * containing "latest", then to IDs in descending string order.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) =>
      priority.findIndex((needle) => model.id.includes(needle))
    const latestFlag = (model: ModelsDev.Model) =>
      model.id.includes("latest") ? 0 : 1
    const byId = (model: ModelsDev.Model) => model.id
    return sortBy(
      models,
      [priorityRank, "desc"],
      [latestFlag, "asc"],
      [byId, "desc"],
    )
  }
  /**
   * Chooses the model to use when none is requested explicitly.
   *
   * If `cfg.model` is set, it is parsed as "provider/model". Otherwise the first loaded provider
   * is used (restricted to providers named in `cfg.provider` when that is set), with its top
   * model according to sort().
   *
   * @throws Error "no providers found" / "no models found" when nothing qualifies.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    const configuredIDs = cfg.provider ? Object.keys(cfg.provider) : undefined
    const loaded = Object.values(await list())
    const provider = loaded.find(
      (candidate) => !configuredIDs || configuredIDs.includes(candidate.info.id),
    )
    if (!provider) throw new Error("no providers found")

    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")

    return { providerID: provider.info.id, modelID: model.id }
  }
  /**
   * Splits "provider/model" at the first "/". Later slashes stay in the model ID
   * ("a/b/c" -> provider "a", model "b/c"). Without a slash, the model ID is "".
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) return { providerID: model, modelID: "" }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  /**
   * Default tool list. tools() returns it for providers without a TOOL_MAPPING entry, and
   * TOOL_MAPPING derives its per-provider variants from it.
   *
   * Keep each tool listed exactly once. tools() hands this array out as-is, so a repeated entry
   * would presumably reach the model as a duplicate tool definition (EditTool used to appear
   * twice).
   */
  const TOOLS = [
    BashTool,
    EditTool,
    WebFetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ReadTool,
    // MultiEditTool,
    WriteTool,
    TodoWriteTool,
    TaskTool,
    TodoReadTool,
  ]
  /** Returns a copy of `tool` whose parameter schema has optional fields made nullable. */
  function withNullableParameters(tool: (typeof TOOLS)[number]) {
    return { ...tool, parameters: optionalToNullable(tool.parameters) }
  }

  /**
   * Per-provider tool variants; tools() falls back to TOOLS for providers not listed here.
   * - anthropic: everything except the "patch" tool;
   * - openai / azure: parameter schemas rewritten by optionalToNullable;
   * - google: TOOLS unchanged.
   */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = {
    anthropic: TOOLS.filter((tool) => tool.id !== "patch"),
    openai: TOOLS.map(withNullableParameters),
    azure: TOOLS.map(withNullableParameters),
    google: TOOLS,
  }
  /**
   * Returns the tools to expose for `providerID`: its TOOL_MAPPING variant when one exists,
   * otherwise the default TOOLS list.
   */
  export async function tools(providerID: string) {
    // Config-driven per-provider tool selection, currently commented out:
    // const cfg = await Config.get()
    // if (cfg.tool?.provider?.[providerID])
    //   return cfg.tool.provider[providerID].map(
    //     (id) => TOOLS.find((t) => t.id === id)!,
    //   )
    const providerTools = TOOL_MAPPING[providerID]
    if (providerTools) return providerTools
    return TOOLS
  }
  /**
   * Returns `target` carrying `source`'s `.describe()` text, if `source` has any.
   * This keeps descriptions through the rewrites in optionalToNullable.
   */
  function keepDescription(target: z.ZodTypeAny, source: z.ZodTypeAny): z.ZodTypeAny {
    return source.description === undefined
      ? target
      : target.describe(source.description)
  }

  /**
   * Rewrites a zod schema so that optional object properties become required but nullable.
   * Used for the openai/azure entries of TOOL_MAPPING.
   *
   * The rewrite recurses through objects, arrays and unions, including the inner schema of an
   * optional property, so nested optionals are converted too. Descriptions attached to rewritten
   * nodes (for example `.optional().describe(...)`) are kept. Any other schema kind is returned
   * unchanged.
   *
   * @param schema the schema to transform; it is not mutated.
   * @returns a new schema for objects/arrays/unions, otherwise `schema` itself.
   */
  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
    if (schema instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {}
      for (const [key, value] of Object.entries(schema.shape)) {
        const field = value as z.ZodTypeAny
        newShape[key] =
          field instanceof z.ZodOptional
            ? keepDescription(optionalToNullable(field.unwrap()).nullable(), field)
            : optionalToNullable(field)
      }
      return keepDescription(z.object(newShape), schema)
    }
    if (schema instanceof z.ZodArray) {
      return keepDescription(z.array(optionalToNullable(schema.element)), schema)
    }
    if (schema instanceof z.ZodUnion) {
      const options = schema.options.map((option: z.ZodTypeAny) =>
        optionalToNullable(option),
      ) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]
      return keepDescription(z.union(options), schema)
    }
    return schema
  }
  // Error payload schemas.
  const modelNotFoundData = z.object({
    providerID: z.string(),
    modelID: z.string(),
  })
  const initFailureData = z.object({
    providerID: z.string(),
  })
  const authFailureData = z.object({
    providerID: z.string(),
    message: z.string(),
  })

  /** Thrown by getModel when the provider is not loaded or does not offer the model. */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    modelNotFoundData,
  )

  /** Thrown by getSDK when a provider's SDK cannot be installed, imported or constructed. */
  export const InitError = NamedError.create("ProviderInitError", initFailureData)

  /** Provider authentication failure with a message (not raised within this section). */
  export const AuthError = NamedError.create("ProviderAuthError", authFailureData)
  432. }