provider.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { BashTool } from "../tool/bash"
  9. import { EditTool } from "../tool/edit"
  10. import { WebFetchTool } from "../tool/webfetch"
  11. import { GlobTool } from "../tool/glob"
  12. import { GrepTool } from "../tool/grep"
  13. import { ListTool } from "../tool/ls"
  14. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  15. import { LspHoverTool } from "../tool/lsp-hover"
  16. import { PatchTool } from "../tool/patch"
  17. import { ReadTool } from "../tool/read"
  18. import type { Tool } from "../tool/tool"
  19. import { WriteTool } from "../tool/write"
  20. import { TodoReadTool, TodoWriteTool } from "../tool/todo"
  21. import { AuthAnthropic } from "../auth/anthropic"
  22. import { AuthCopilot } from "../auth/copilot"
  23. import { ModelsDev } from "./models"
  24. import { NamedError } from "../util/error"
  25. import { Auth } from "../auth"
  26. import { TaskTool } from "../tool/task"
  27. export namespace Provider {
  const log = Log.create({ service: "provider" })

  /** Successful outcome of a custom loader: SDK options plus an optional model factory override. */
  type CustomLoaderResult = {
    getModel?: (sdk: any, modelID: string) => Promise<any>
    options: Record<string, any>
  }

  /**
   * Provider-specific activation hook. Resolves to `false` when this loader
   * declines to enable the provider.
   */
  type CustomLoader = (provider: ModelsDev.Provider) => Promise<CustomLoaderResult | false>

  /** Which kind of source last registered or updated a provider's settings. */
  type Source = "env" | "config" | "custom" | "api"
  /**
   * Provider-specific loaders, keyed by provider id. Each one may enable its
   * provider (returning SDK options and optionally a custom `getModel`) or
   * decline by resolving to `false`.
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    async anthropic(provider) {
      const access = await AuthAnthropic.access()
      if (!access) return false
      // Costs are zeroed when OAuth access is present.
      // NOTE(review): presumably usage is covered by the account plan — confirm.
      // `provider` is looked up from the catalog by id and may be absent.
      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
      return {
        options: {
          apiKey: "",
          // Swap API-key auth for a bearer token that is re-read on every request.
          async fetch(input: any, init: any) {
            const access = await AuthAnthropic.access()
            const headers = {
              ...init?.headers,
              authorization: `Bearer ${access}`,
              "anthropic-beta": "oauth-2025-04-20",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    "github-copilot": async (provider) => {
      const copilot = await AuthCopilot()
      if (!copilot) return false
      const stored = await Auth.get("github-copilot")
      if (!stored || stored.type !== "oauth") return false
      for (const model of Object.values(provider?.models ?? {})) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
      return {
        options: {
          apiKey: "",
          // Re-reads stored credentials per request, refreshing and persisting the
          // access token when it is missing or past its expiry.
          async fetch(input: any, init: any) {
            let info = await Auth.get("github-copilot")
            // A fetch implementation must resolve to a Response; fail loudly rather
            // than handing `undefined` back to the SDK.
            if (!info || info.type !== "oauth")
              throw new Error("GitHub Copilot authentication missing")
            if (!info.access || info.expires < Date.now()) {
              const tokens = await copilot.access(info.refresh)
              if (!tokens)
                throw new Error("GitHub Copilot authentication expired")
              info = {
                type: "oauth",
                ...tokens,
              }
              await Auth.set("github-copilot", info)
            }
            const headers = {
              ...init?.headers,
              ...copilot.HEADERS,
              Authorization: `Bearer ${info.access}`,
              "Openai-Intent": "conversation-edits",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    // Always enabled; routes model creation through the SDK's `responses` factory.
    openai: async () => {
      return {
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    "amazon-bedrock": async () => {
      // Only activate when AWS credentials are configured via a profile or access key.
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"])
        return false
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(
        await BunProc.install("@aws-sdk/credential-providers")
      )
      return {
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string) {
          // Claude ids get the first segment of the region name as a prefix.
          // NOTE(review): confirm this matches Bedrock's inference-profile prefix
          // for every region in use (it may differ for some region families).
          if (modelID.includes("claude")) {
            const prefix = region.split("-")[0]
            modelID = `${prefix}.${modelID}`
          }
          return sdk.languageModel(modelID)
        },
      }
    },
  }
  /**
   * Lazily built provider registry, stored under the "provider" app-state key.
   *
   * Sources are applied in a fixed order: env vars, stored API keys, custom
   * loaders, then user config. Options from later sources are deep-merged over
   * earlier ones. Providers left with no models are dropped.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()
    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<
      string,
      { info: ModelsDev.Model; language: LanguageModel }
    >()
    const sdk = new Map<string, SDK>()
    log.info("init")

    /**
     * Registers provider `id`, or folds more options into an already-registered
     * one. Ids missing from `database` are ignored.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        providers[id] = {
          source,
          info,
          // The catalog endpoint is only a default: explicitly supplied options
          // (e.g. a user-configured baseURL) win. Building a fresh object also
          // avoids mutating the caller's options (possibly the user's config).
          options: info.api ? { baseURL: info.api, ...options } : options,
          // Keep the loader's model factory even when it is the first source.
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Overlay providers/models declared in config onto the catalog.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        // Keep the catalog endpoint so mergeProvider can still default baseURL.
        api: existing?.api,
        // Copy so adding config models does not mutate the catalog's model map.
        models: { ...existing?.models },
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existingModel = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existingModel?.name ?? modelID,
          attachment: model.attachment ?? existingModel?.attachment ?? false,
          reasoning: model.reasoning ?? existingModel?.reasoning ?? false,
          temperature: model.temperature ?? existingModel?.temperature ?? false,
          tool_call: model.tool_call ?? existingModel?.tool_call ?? true,
          // Zero is only the fallback: catalog prices, then configured prices,
          // take precedence (they were previously overwritten unconditionally).
          cost: {
            input: 0,
            output: 0,
            cache_read: 0,
            cache_write: 0,
            ...existingModel?.cost,
            ...model.cost,
          },
          options: {
            ...existingModel?.options,
            ...model.options,
          },
          limit: model.limit ??
            existingModel?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env: enabled when any of the provider's env vars is set
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      if (provider.env.some((item) => process.env[item])) {
        mergeProvider(providerID, {}, "env")
      }
    }
    // load stored api keys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }
    // load custom loaders
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const result = await fn(database[providerID])
      if (result) {
        mergeProvider(providerID, result.options, "custom", result.getModel)
      }
    }
    // load config (disabled_providers applies here too, like every other source)
    for (const [providerID, provider] of configProviders) {
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }
    return {
      models,
      providers,
      sdk,
    }
  })
  /** Returns the registry of enabled providers, keyed by provider id. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Returns the SDK instance for `provider`, creating and caching it on first use.
   *
   * Installs the provider's npm package (falling back to the provider id as the
   * package name), then calls its first export whose name starts with `create`,
   * passing the provider's merged options.
   *
   * @throws InitError wrapping any failure (install, import, missing factory,
   *   factory call) as its `cause`.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    try {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "latest"))
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
      // Guard explicitly instead of `!` so a package without a factory yields a
      // readable cause rather than "fn is not a function".
      if (!factoryName)
        throw new Error(`package "${pkg}" has no export starting with "create"`)
      const loaded = mod[factoryName](s.providers[provider.id]?.options)
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    } catch (e) {
      throw new InitError({ providerID: provider.id }, { cause: e })
    }
  }
  /**
   * Resolves a provider/model pair to its catalog info and a language model
   * instance, memoized per `providerID/modelID` key.
   *
   * @throws ModelNotFoundError when the provider is not registered, the model
   *   is not in its catalog, or the SDK raises NoSuchModelError.
   * @throws InitError (from getSDK) when the provider SDK cannot be loaded.
   */
  export async function getModel(providerID: string, modelID: string) {
    const s = await state()
    const key = `${providerID}/${modelID}`
    const cached = s.models.get(key)
    if (cached) return cached

    log.info("getModel", { providerID, modelID })
    const provider = s.providers[providerID]
    const info = provider?.info.models[modelID]
    if (!provider || !info) throw new ModelNotFoundError({ providerID, modelID })

    const sdk = await getSDK(provider.info)
    try {
      // Loaders may supply their own factory (e.g. a different SDK entry point).
      const language = provider.getModel
        ? await provider.getModel(sdk, modelID)
        : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      s.models.set(key, { info, language })
      return { info, language }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError({ modelID, providerID }, { cause: e })
    }
  }
  // Substrings of preferred model ids; a LATER entry ranks HIGHER in `sort`.
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Orders models for default selection without mutating the input: by the
   * index of the matching `priority` entry (descending, non-matches last), then
   * ids containing "latest" first, then by id descending.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) =>
      priority.findIndex((filter) => model.id.includes(filter))
    const latestRank = (model: ModelsDev.Model) =>
      model.id.includes("latest") ? 0 : 1
    const idOf = (model: ModelsDev.Model) => model.id
    return sortBy(
      models,
      [priorityRank, "desc"],
      [latestRank, "asc"],
      [idOf, "desc"],
    )
  }
  /**
   * Picks the model to use when none is specified: the configured `model` if set,
   * otherwise the top `sort`ed model of the first enabled provider (restricted to
   * providers named in config when config declares any).
   *
   * @throws Error "no providers found" / "no models found" when nothing qualifies.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    const configured = cfg.provider ? Object.keys(cfg.provider) : undefined
    const enabled = Object.values(await list())
    const provider = enabled.find(
      (candidate) =>
        configured === undefined || configured.includes(candidate.info.id),
    )
    if (!provider) throw new Error("no providers found")

    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")
    return { providerID: provider.info.id, modelID: model.id }
  }
  /**
   * Splits "provider/model" at the FIRST slash. Everything after it (including
   * further slashes) is the model id; without a slash the model id is "".
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) return { providerID: model, modelID: "" }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  /**
   * Default tool set, used for any provider without an entry in TOOL_MAPPING.
   * Each tool is listed exactly once; EditTool previously appeared twice.
   * NOTE(review): tool ids are presumably used as unique names downstream,
   * which is why a duplicate is at best redundant.
   */
  const TOOLS = [
    BashTool,
    EditTool,
    WebFetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    LspDiagnosticTool,
    LspHoverTool,
    PatchTool,
    ReadTool,
    // MultiEditTool,
    WriteTool,
    TodoWriteTool,
    TaskTool,
    TodoReadTool,
  ]
  /** Builds a fresh copy of TOOLS whose parameter schemas use nullable instead of optional fields. */
  const withNullableParameters = () =>
    TOOLS.map((tool) => ({
      ...tool,
      parameters: optionalToNullable(tool.parameters),
    }))

  /**
   * Per-provider tool variants: anthropic omits the "patch" tool; openai and
   * azure get nullable-parameter copies; google gets TOOLS as-is.
   */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = {
    anthropic: TOOLS.filter((tool) => tool.id !== "patch"),
    openai: withNullableParameters(),
    azure: withNullableParameters(),
    google: TOOLS,
  }
  /**
   * Returns the tools exposed to `providerID`: its TOOL_MAPPING variant when one
   * exists, otherwise the default TOOLS list.
   */
  export async function tools(providerID: string) {
    // Config-driven per-provider tool selection is currently disabled:
    // const cfg = await Config.get()
    // if (cfg.tool?.provider?.[providerID])
    //   return cfg.tool.provider[providerID].map(
    //     (id) => TOOLS.find((t) => t.id === id)!,
    //   )
    const mapped = TOOL_MAPPING[providerID]
    return mapped ?? TOOLS
  }
  /**
   * Rewrites a zod schema so every optional object field becomes a required but
   * nullable field, recursing through objects, arrays, unions and the inner
   * schema of optional fields. Other schema kinds are returned unchanged.
   *
   * Descriptions are carried over onto the rebuilt schemas. They were previously
   * lost whenever `.describe()` was applied to an optional wrapper, object, array
   * or union, and they document tool parameters for the model.
   */
  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
    if (schema instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {}
      for (const [key, value] of Object.entries(schema.shape)) {
        const zodValue = value as z.ZodTypeAny
        newShape[key] =
          zodValue instanceof z.ZodOptional
            ? preserveDescription(
                zodValue,
                // Recurse so optionals nested inside an optional field convert too.
                optionalToNullable(zodValue.unwrap()).nullable(),
              )
            : optionalToNullable(zodValue)
      }
      return preserveDescription(schema, z.object(newShape))
    }
    if (schema instanceof z.ZodArray) {
      return preserveDescription(
        schema,
        z.array(optionalToNullable(schema.element)),
      )
    }
    if (schema instanceof z.ZodUnion) {
      const options = schema.options.map((option: z.ZodTypeAny) =>
        optionalToNullable(option),
      )
      return preserveDescription(
        schema,
        z.union(options as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]),
      )
    }
    return schema
  }

  /** Copies `source`'s description (if any) onto `target`, returning the described schema. */
  function preserveDescription(
    source: z.ZodTypeAny,
    target: z.ZodTypeAny,
  ): z.ZodTypeAny {
    return source.description ? target.describe(source.description) : target
  }
  /** Thrown by getModel when a provider or one of its models cannot be resolved. */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({ providerID: z.string(), modelID: z.string() }),
  )

  /** Thrown by getSDK when a provider's SDK cannot be installed, imported, or created. */
  export const InitError = NamedError.create(
    "ProviderInitError",
    z.object({ providerID: z.string() }),
  )

  /** Authentication failure for a provider, with a human-readable message; not raised within this file. */
  export const AuthError = NamedError.create(
    "ProviderAuthError",
    z.object({ providerID: z.string(), message: z.string() }),
  )
  444. }