provider.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { BashTool } from "../tool/bash"
  9. import { EditTool } from "../tool/edit"
  10. import { WebFetchTool } from "../tool/webfetch"
  11. import { GlobTool } from "../tool/glob"
  12. import { GrepTool } from "../tool/grep"
  13. import { ListTool } from "../tool/ls"
  14. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  15. import { LspHoverTool } from "../tool/lsp-hover"
  16. import { PatchTool } from "../tool/patch"
  17. import { ReadTool } from "../tool/read"
  18. import type { Tool } from "../tool/tool"
  19. import { WriteTool } from "../tool/write"
  20. import { TodoReadTool, TodoWriteTool } from "../tool/todo"
  21. import { AuthAnthropic } from "../auth/anthropic"
  22. import { AuthCopilot } from "../auth/copilot"
  23. import { ModelsDev } from "./models"
  24. import { NamedError } from "../util/error"
  25. import { Auth } from "../auth"
  26. import { TaskTool } from "../tool/task"
  27. export namespace Provider {
  /** Logger tagged with the "provider" service name. */
  const log = Log.create({ service: "provider" })

  /** Resolves a language model from an instantiated SDK and a model id. */
  type ModelResolver = (sdk: any, modelID: string) => Promise<any>

  /** What a custom loader reports back for its provider. */
  interface CustomLoaderResult {
    /** When true the provider is registered even if nothing else registered it. */
    autoload: boolean
    /** Optional override for how a model is obtained from the SDK. */
    getModel?: ModelResolver
    /** Options merged into the provider's SDK options. */
    options?: Record<string, any>
  }

  /**
   * Provider-specific bootstrap hook. It receives the catalogue entry for the
   * provider and resolves to the registration details above.
   */
  type CustomLoader = (
    provider: ModelsDev.Provider,
    api?: string,
  ) => Promise<CustomLoaderResult>

  /** Records which loading stage last touched a provider entry. */
  type Source = "env" | "config" | "custom" | "api"
  /**
   * Provider-specific loaders, keyed by provider id.
   *
   * Each loader decides whether its provider should be registered
   * automatically (`autoload`) and may contribute SDK options (for example a
   * custom `fetch` that injects credentials) and a custom `getModel`.
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    /** Anthropic via OAuth: only active when an access token is available. */
    async anthropic(provider) {
      const access = await AuthAnthropic.access()
      if (!access) return { autoload: false }
      // The loader is invoked with a catalogue lookup that may be missing, so
      // tolerate an absent provider entry (same guard as github-copilot).
      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
      return {
        autoload: true,
        options: {
          apiKey: "",
          // Re-read the token on every request and swap the API-key header
          // for a bearer token.
          async fetch(input: any, init: any) {
            const access = await AuthAnthropic.access()
            const headers = {
              ...init?.headers,
              authorization: `Bearer ${access}`,
              "anthropic-beta": "oauth-2025-04-20",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },

    /** GitHub Copilot: requires the copilot auth module and stored OAuth info. */
    "github-copilot": async (provider) => {
      const copilot = await AuthCopilot()
      if (!copilot) return { autoload: false }
      const info = await Auth.get("github-copilot")
      if (!info || info.type !== "oauth") return { autoload: false }

      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }

      return {
        autoload: true,
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            const info = await Auth.get("github-copilot")
            // Returning undefined here would surface later as an unrelated
            // TypeError in the caller; fail with an explicit error instead.
            if (!info || info.type !== "oauth")
              throw new Error("GitHub Copilot authentication missing")
            // Refresh the access token when it is absent or past its expiry.
            if (!info.access || info.expires < Date.now()) {
              const tokens = await copilot.access(info.refresh)
              if (!tokens)
                throw new Error("GitHub Copilot authentication expired")
              await Auth.set("github-copilot", {
                type: "oauth",
                ...tokens,
              })
              info.access = tokens.access
            }
            const headers = {
              ...init?.headers,
              ...copilot.HEADERS,
              Authorization: `Bearer ${info.access}`,
              "Openai-Intent": "conversation-edits",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },

    /** OpenAI: never autoloaded; models are obtained through `sdk.responses`. */
    openai: async () => {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },

    /** Amazon Bedrock: active when AWS_PROFILE or AWS_ACCESS_KEY_ID is set. */
    "amazon-bedrock": async () => {
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"])
        return { autoload: false }

      const region = process.env["AWS_REGION"] ?? "us-east-1"

      const { fromNodeProviderChain } = await import(
        await BunProc.install("@aws-sdk/credential-providers")
      )
      return {
        autoload: true,
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string) {
          // Model ids containing "claude" are prefixed with the first segment
          // of the region (e.g. "us" for "us-east-1").
          // NOTE(review): confirm this prefix is valid for every region, e.g.
          // "ap-*" regions may expect a different prefix than "ap".
          if (modelID.includes("claude")) {
            const prefix = region.split("-")[0]
            modelID = `${prefix}.${modelID}`
          }
          return sdk.languageModel(modelID)
        },
      }
    },
  }
  /**
   * Lazily built provider state: the registered providers plus caches for
   * resolved models and instantiated SDKs.
   *
   * Providers are registered in stages, later stages merging over earlier
   * ones: environment variables, stored API keys, custom loaders, then
   * config. Providers listed in `disabled_providers` are skipped by the first
   * three stages. Providers that end up with no models are removed.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()

    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<
      string,
      { info: ModelsDev.Model; language: LanguageModel }
    >()
    const sdk = new Map<string, SDK>()

    log.info("init")

    /**
     * Registers provider `id`, or merges `options` into the existing entry.
     * Unknown ids (not present in the catalogue) are ignored.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        // Copy so the caller's object (e.g. the config) is not mutated, and
        // only fall back to the catalogue endpoint when none was supplied.
        const initial: Record<string, any> = { ...options }
        if (info.api && !initial["baseURL"]) initial["baseURL"] = info.api
        providers[id] = {
          source,
          info,
          options: initial,
          // Keep the custom resolver on first registration as well.
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Overlay providers/models declared in config onto the catalogue.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        // Preserve the catalogue endpoint when config overrides a provider.
        api: existing?.api,
        models: existing?.models ?? {},
      }

      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existing = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existing?.name ?? modelID,
          attachment: model.attachment ?? existing?.attachment ?? false,
          reasoning: model.reasoning ?? existing?.reasoning ?? false,
          temperature: model.temperature ?? existing?.temperature ?? false,
          tool_call: model.tool_call ?? existing?.tool_call ?? true,
          cost: {
            // Zeros are defaults only: catalogue values, then config values,
            // take precedence over them.
            input: 0,
            output: 0,
            cache_read: 0,
            cache_write: 0,
            ...existing?.cost,
            ...model.cost,
          },
          options: {
            ...existing?.options,
            ...model.options,
          },
          limit: model.limit ??
            existing?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      if (provider.env.some((item) => process.env[item])) {
        mergeProvider(providerID, {}, "env")
      }
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      // Without a catalogue entry the provider can never be registered, so
      // there is no point running its loader.
      const info = database[providerID]
      if (!info) continue
      const result = await fn(info)
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(
          providerID,
          result.options ?? {},
          "custom",
          result.getModel,
        )
      }
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // Drop providers that have no models to offer.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })
  /** Returns the registered providers, keyed by provider id. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Returns the SDK instance for `provider`, creating and caching it on first
   * use.
   *
   * The package named by `provider.npm` (falling back to the provider id) is
   * installed and imported, and its first exported function whose name starts
   * with "create" is called with the provider's registered options.
   *
   * @throws InitError wrapping any failure, including a package that exports
   *   no such factory.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    return (async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "latest"))
      // Look the factory up explicitly so a missing export is reported with
      // the package name rather than as "undefined is not a function".
      const factory = Object.keys(mod).find(
        (key) => key.startsWith("create") && typeof mod[key] === "function",
      )
      if (!factory)
        throw new Error(`package "${pkg}" does not export a create* function`)
      const loaded = mod[factory](s.providers[provider.id]?.options)
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    })().catch((e) => {
      throw new InitError({ providerID: provider.id }, { cause: e })
    })
  }
  /**
   * Resolves a model for the given provider, caching the result under
   * "providerID/modelID".
   *
   * @returns the model's catalogue info together with its language model.
   * @throws ModelNotFoundError when the provider or model is not registered,
   *   or when the SDK reports NoSuchModelError.
   */
  export async function getModel(providerID: string, modelID: string) {
    const cacheKey = `${providerID}/${modelID}`
    const s = await state()

    const cached = s.models.get(cacheKey)
    if (cached) return cached

    log.info("getModel", {
      providerID,
      modelID,
    })

    const provider = s.providers[providerID]
    const info = provider?.info.models[modelID]
    if (!provider || !info)
      throw new ModelNotFoundError({ providerID, modelID })

    const sdk = await getSDK(provider.info)

    try {
      // A provider-specific resolver wins over the SDK's default lookup.
      const language = provider.getModel
        ? await provider.getModel(sdk, modelID)
        : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      s.models.set(cacheKey, {
        info,
        language,
      })
      return {
        info,
        language,
      }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError(
        {
          modelID: modelID,
          providerID,
        },
        { cause: e },
      )
    }
  }
  /**
   * Id fragments of preferred models. A model's rank is the index of the
   * first fragment its id contains (-1 when none), and ranks sort descending,
   * so later entries are preferred over earlier ones.
   */
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Orders models by preference: priority rank (descending), then ids
   * containing "latest" first, then id descending.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) =>
      priority.findIndex((filter) => model.id.includes(filter))
    const latestRank = (model: ModelsDev.Model) =>
      model.id.includes("latest") ? 0 : 1
    const byId = (model: ModelsDev.Model) => model.id

    return sortBy(
      models,
      [priorityRank, "desc"],
      [latestRank, "asc"],
      [byId, "desc"],
    )
  }
  /**
   * Picks the model to use when none is requested explicitly.
   *
   * A `model` set in config wins. Otherwise the first registered provider is
   * used (restricted to providers named in config when config declares any),
   * and its best model according to `sort` is chosen.
   *
   * @throws Error when no provider, or no model, is available.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    const registered = Object.values(await list())
    const provider = registered.find((candidate) => {
      if (!cfg.provider) return true
      return Object.keys(cfg.provider).includes(candidate.info.id)
    })
    if (!provider) throw new Error("no providers found")

    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")

    return {
      providerID: provider.info.id,
      modelID: model.id,
    }
  }
  /**
   * Splits a "provider/model" string at the first slash.
   *
   * Everything before the first "/" is the provider id; everything after it
   * (further slashes included) is the model id. Without a slash the whole
   * string is the provider id and the model id is empty.
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) {
      return {
        providerID: model,
        modelID: "",
      }
    }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  /**
   * Every tool offered to models by default. Each tool appears exactly once
   * (EditTool was previously listed twice).
   */
  const TOOLS = [
    BashTool,
    EditTool,
    WebFetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ReadTool,
    // MultiEditTool,
    WriteTool,
    TodoWriteTool,
    TaskTool,
    TodoReadTool,
  ]
  /**
   * Per-provider tool lists. Anthropic gets everything except the "patch"
   * tool; OpenAI and Azure get every tool with its parameter schema passed
   * through `optionalToNullable`; Google gets the list unchanged.
   */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = (() => {
    // Builds a fresh list each time so openai and azure do not share an array.
    const withNullableParameters = () =>
      TOOLS.map((tool) => ({
        ...tool,
        parameters: optionalToNullable(tool.parameters),
      }))

    return {
      anthropic: TOOLS.filter((tool) => tool.id !== "patch"),
      openai: withNullableParameters(),
      azure: withNullableParameters(),
      google: TOOLS,
    }
  })()
  /**
   * Returns the tools for a provider: its entry in TOOL_MAPPING when one
   * exists, otherwise the full default list.
   */
  export async function tools(providerID: string) {
    /*
    const cfg = await Config.get()
    if (cfg.tool?.provider?.[providerID])
      return cfg.tool.provider[providerID].map(
        (id) => TOOLS.find((t) => t.id === id)!,
      )
    */
    const mapped = TOOL_MAPPING[providerID]
    return mapped ?? TOOLS
  }
  /**
   * Returns a copy of `schema` in which every optional object field is
   * replaced by a nullable one.
   *
   * Objects, arrays and unions are rebuilt recursively, and the type wrapped
   * by an optional is converted too, so optional fields nested inside an
   * optional object are also rewritten. Descriptions set on a schema that
   * gets rebuilt are carried over to its replacement. Any other schema type
   * is returned unchanged.
   */
  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
    // Rebuilding creates a brand-new schema object; re-attach the description
    // of the schema being replaced when it had one.
    const keepDescription = (next: z.ZodTypeAny, from: z.ZodTypeAny) =>
      from.description === undefined ? next : next.describe(from.description)

    if (schema instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {}
      for (const [key, value] of Object.entries(schema.shape)) {
        const field = value as z.ZodTypeAny
        newShape[key] =
          field instanceof z.ZodOptional
            ? keepDescription(
                optionalToNullable(field.unwrap()).nullable(),
                field,
              )
            : optionalToNullable(field)
      }
      return keepDescription(z.object(newShape), schema)
    }

    if (schema instanceof z.ZodArray) {
      return keepDescription(
        z.array(optionalToNullable(schema.element)),
        schema,
      )
    }

    if (schema instanceof z.ZodUnion) {
      return keepDescription(
        z.union(
          schema.options.map((option: z.ZodTypeAny) =>
            optionalToNullable(option),
          ) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]],
        ),
        schema,
      )
    }

    return schema
  }
  /**
   * Error raised by `getModel` when the requested provider or model cannot be
   * resolved. Carries the provider id and model id that were requested.
   */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({ providerID: z.string(), modelID: z.string() }),
  )
  /**
   * Error raised by `getSDK` when a provider's SDK cannot be installed,
   * imported or constructed. Carries the id of the failing provider.
   */
  export const InitError = NamedError.create(
    "ProviderInitError",
    z.object({ providerID: z.string() }),
  )
  /**
   * Error type for provider authentication failures. Carries the provider id
   * and a message describing the failure.
   */
  export const AuthError = NamedError.create(
    "ProviderAuthError",
    z.object({ providerID: z.string(), message: z.string() }),
  )
  454. }