provider.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { AuthAnthropic } from "../auth/anthropic"
  9. import { AuthCopilot } from "../auth/copilot"
  10. import { ModelsDev } from "./models"
  11. import { NamedError } from "../util/error"
  12. import { Auth } from "../auth"
  13. export namespace Provider {
  const log = Log.create({ service: "provider" })

  /** Where a provider's settings were most recently merged from (see `mergeProvider` in `state`). */
  type Source = "env" | "config" | "custom" | "api"

  /** What a custom loader hands back to the registry. */
  interface LoaderResult {
    /** When true, register the provider even if no env var, stored key, or config entry found it. */
    autoload: boolean
    /** Optional factory that turns an instantiated SDK into a language model for `modelID`. */
    getModel?: (sdk: any, modelID: string) => Promise<any>
    /** Extra SDK options merged into the provider's options. */
    options?: Record<string, any>
  }

  /**
   * Provider-specific hook run while the provider registry is built.
   *
   * @param provider - the models.dev entry for this provider. NOTE(review): the registry passes
   *   `database[providerID]`, which can be undefined at runtime despite this signature.
   * @param api - optional extra argument; none of the loaders in this file read it.
   */
  type CustomLoader = (provider: ModelsDev.Provider, api?: string) => Promise<LoaderResult>
  /**
   * Set the per-token cost of every model in `provider` to zero, in place.
   * Applied by the OAuth-backed loaders (anthropic, github-copilot). Tolerates a missing
   * provider entry, since loaders receive `database[providerID]`, which may be undefined.
   */
  function markModelsFree(provider: ModelsDev.Provider | undefined) {
    for (const model of Object.values(provider?.models ?? {})) {
      model.cost = {
        input: 0,
        output: 0,
      }
    }
  }

  /**
   * Prefix a Bedrock model ID with a geography code for the region/model combinations below:
   * `us.` for us-* regions, `eu.` for the listed eu-* regions, `apac.` for ap-* regions.
   * Everything else is returned unchanged.
   * Presumably these are Bedrock cross-region inference profile IDs.
   * NOTE(review): confirm the region and model lists against current AWS documentation.
   */
  function bedrockModelID(region: string, modelID: string): string {
    const regionPrefix = region.split("-")[0]
    switch (regionPrefix) {
      case "us": {
        const modelRequiresPrefix = ["claude", "deepseek"].some((m) => modelID.includes(m))
        return modelRequiresPrefix ? `${regionPrefix}.${modelID}` : modelID
      }
      case "eu": {
        const regionRequiresPrefix = [
          "eu-west-1",
          "eu-west-3",
          "eu-north-1",
          "eu-central-1",
          "eu-south-1",
          "eu-south-2",
        ].some((r) => region.includes(r))
        const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
          modelID.includes(m),
        )
        return regionRequiresPrefix && modelRequiresPrefix ? `${regionPrefix}.${modelID}` : modelID
      }
      case "ap": {
        const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) =>
          modelID.includes(m),
        )
        return modelRequiresPrefix ? `apac.${modelID}` : modelID
      }
      default:
        return modelID
    }
  }

  /**
   * Provider-specific loaders, keyed by provider ID. Each is called once while the registry is built.
   * It may:
   * - autoload the provider (e.g. when an OAuth login or AWS credentials exist),
   * - inject SDK options (custom `fetch`, headers, region), and
   * - override how a model is obtained from the SDK (`getModel`).
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    async anthropic(provider) {
      const access = await AuthAnthropic.access()
      if (!access) return { autoload: false }
      markModelsFree(provider)
      return {
        autoload: true,
        options: {
          // Empty key: authentication comes from the bearer token set in `fetch`, which also
          // strips any x-api-key header.
          apiKey: "",
          async fetch(input: any, init: any) {
            // Token is re-read on every request.
            // NOTE(review): presumably AuthAnthropic.access() refreshes expired tokens — confirm.
            const access = await AuthAnthropic.access()
            const headers = {
              ...init?.headers,
              authorization: `Bearer ${access}`,
              "anthropic-beta": "oauth-2025-04-20",
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    "github-copilot": async (provider) => {
      const copilot = await AuthCopilot()
      if (!copilot) return { autoload: false }
      const info = await Auth.get("github-copilot")
      if (!info || info.type !== "oauth") return { autoload: false }
      markModelsFree(provider)
      return {
        autoload: true,
        options: {
          apiKey: "",
          async fetch(input: any, init: any) {
            const info = await Auth.get("github-copilot")
            // `fetch` must resolve to a Response. If the login disappeared mid-session, fail loudly
            // instead of resolving `undefined` for the SDK to trip over.
            if (!info || info.type !== "oauth") throw new Error("GitHub Copilot authentication missing")
            // Refresh and persist the access token when it is absent or expired.
            if (!info.access || info.expires < Date.now()) {
              const tokens = await copilot.access(info.refresh)
              if (!tokens) throw new Error("GitHub Copilot authentication expired")
              await Auth.set("github-copilot", {
                type: "oauth",
                ...tokens,
              })
              info.access = tokens.access
            }
            let isAgentCall = false
            let isVisionRequest = false
            try {
              const body = typeof init?.body === "string" ? JSON.parse(init.body) : init?.body
              if (body?.messages) {
                // A history containing tool/assistant turns is tagged X-Initiator: agent; otherwise user.
                isAgentCall = body.messages.some((msg: any) => msg.role && ["tool", "assistant"].includes(msg.role))
                // Any image_url content part marks the request as a vision request.
                isVisionRequest = body.messages.some(
                  (msg: any) =>
                    Array.isArray(msg.content) && msg.content.some((part: any) => part.type === "image_url"),
                )
              }
            } catch {
              // Best effort: an unparseable or unexpected body leaves both flags false.
            }
            const headers: Record<string, string> = {
              ...init?.headers,
              ...copilot.HEADERS,
              Authorization: `Bearer ${info.access}`,
              "Openai-Intent": "conversation-edits",
              "X-Initiator": isAgentCall ? "agent" : "user",
            }
            if (isVisionRequest) {
              headers["Copilot-Vision-Request"] = "true"
            }
            delete headers["x-api-key"]
            return fetch(input, {
              ...init,
              headers,
            })
          },
        },
      }
    },
    // openai and azure: never autoloaded; models come from the SDK's `responses` factory.
    openai: async () => {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    azure: async () => {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    "amazon-bedrock": async () => {
      // Autoload only when some AWS credential source is configured in the environment.
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"] && !process.env["AWS_BEARER_TOKEN_BEDROCK"])
        return { autoload: false }
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
      return {
        autoload: true,
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string) {
          return sdk.languageModel(bedrockModelID(region, modelID))
        },
      }
    },
    // openrouter and vercel: attribution headers only; never autoloaded.
    openrouter: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "HTTP-Referer": "https://opencode.ai/",
            "X-Title": "opencode",
          },
        },
      }
    },
    vercel: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "http-referer": "https://opencode.ai/",
            "x-title": "opencode",
          },
        },
      }
    },
  }
  /**
   * Provider registry, built through `App.state` under the key "provider".
   * NOTE(review): App.state presumably caches the result per app instance — confirm.
   *
   * Build steps:
   * 1. User-configured providers and models (`config.provider`) are overlaid onto the models.dev
   *    database.
   * 2. Providers are discovered from, in order: environment variables, stored API keys (Auth),
   *    CUSTOM_LOADERS, and the config's provider options. Each later source is merged over the
   *    earlier ones by `mergeProvider`.
   * 3. Providers listed in `disabled_providers` are skipped by every source.
   * 4. Providers with an empty model catalog are dropped.
   *
   * Also holds the per-process caches for SDK instances (`sdk`) and resolved models (`models`).
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()
    const disabled = new Set(config.disabled_providers ?? [])
    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
    const sdk = new Map<string, SDK>()

    log.info("init")

    // Register provider `id`, or deep-merge `options` into it if already registered.
    // First registration requires a database entry (models.dev or config); unknown ids are ignored.
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        // Copy so the caller's object (e.g. the user's config options) is never mutated.
        const initial: Record<string, any> = { ...options }
        if (info.api && !initial["baseURL"]) initial["baseURL"] = info.api
        providers[id] = {
          source,
          info,
          options: initial,
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Overlay config-declared providers/models on the database so every source below can see them.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existingProvider = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existingProvider?.npm,
        name: provider.name ?? existingProvider?.name ?? providerID,
        env: provider.env ?? existingProvider?.env ?? [],
        api: provider.api ?? existingProvider?.api,
        models: existingProvider?.models ?? {},
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existingModel = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existingModel?.name ?? modelID,
          release_date: model.release_date ?? existingModel?.release_date,
          attachment: model.attachment ?? existingModel?.attachment ?? false,
          reasoning: model.reasoning ?? existingModel?.reasoning ?? false,
          temperature: model.temperature ?? existingModel?.temperature ?? false,
          tool_call: model.tool_call ?? existingModel?.tool_call ?? true,
          // No cost information anywhere -> free; otherwise fill missing cache costs with 0.
          cost:
            !model.cost && !existingModel?.cost
              ? {
                  input: 0,
                  output: 0,
                  cache_read: 0,
                  cache_write: 0,
                }
              : {
                  cache_read: 0,
                  cache_write: 0,
                  ...existingModel?.cost,
                  ...model.cost,
                },
          options: {
            ...existingModel?.options,
            ...model.options,
          },
          limit: model.limit ??
            existingModel?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    // load env: a provider counts as configured when ANY of its env vars holds a non-empty value
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      const apiKey = provider.env.map((item) => process.env[item]).find((value) => !!value)
      if (!apiKey) continue
      mergeProvider(
        providerID,
        // only include apiKey if there's only one potential option
        provider.env.length === 1 ? { apiKey } : {},
        "env",
      )
    }

    // load apikeys stored via Auth
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom: a loader may autoload its provider or augment one discovered above
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const result = await fn(database[providerID])
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
      }
    }

    // load config: same disabled check as every other source
    for (const [providerID, provider] of configProviders) {
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // Drop providers whose model catalog is empty.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })
  /** Every registered provider, keyed by provider ID. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Return the SDK instance for `provider`, creating it on first use.
   *
   * Creation steps:
   * - Install the npm package `provider.npm` (falling back to the provider ID) via BunProc with the
   *   "beta" tag, then import it.
   * - Call its first export whose name starts with "create", passing `name` plus the provider's
   *   merged options.
   * - Cache the instance in `state().sdk`.
   *
   * @throws InitError wrapping any failure: install, import, missing factory export, or factory error.
   */
  async function getSDK(provider: ModelsDev.Provider): Promise<SDK> {
    try {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "beta"))
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
      // Give a clear cause instead of an opaque "fn is not a function".
      if (!factoryName) throw new Error(`package "${pkg}" has no create* export`)
      const loaded = mod[factoryName]({
        name: provider.id,
        ...s.providers[provider.id]?.options,
      })
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    } catch (e) {
      throw new InitError({ providerID: provider.id }, { cause: e })
    }
  }
  /** The registered provider with this ID, or `undefined` when it was not discovered. */
  export async function getProvider(providerID: string) {
    const { providers } = await state()
    return providers[providerID]
  }
  /**
   * Resolve `providerID`/`modelID` to its catalog metadata and an instantiated language model.
   * Results are cached in `state().models` under the key "providerID/modelID".
   *
   * @throws ModelNotFoundError when the provider is not registered, the model is absent from its
   *   catalog, or the SDK raises `NoSuchModelError`.
   * @throws InitError (via getSDK) when the provider SDK cannot be loaded.
   */
  export async function getModel(providerID: string, modelID: string) {
    const s = await state()
    const key = `${providerID}/${modelID}`
    const cached = s.models.get(key)
    if (cached !== undefined) return cached

    log.info("getModel", {
      providerID,
      modelID,
    })
    const provider = s.providers[providerID]
    const info = provider?.info.models[modelID]
    if (!provider || !info) throw new ModelNotFoundError({ providerID, modelID })

    const sdk = await getSDK(provider.info)
    try {
      // A loader-supplied factory wins over the SDK's generic languageModel().
      const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      s.models.set(key, { info, language })
      return { info, language }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError({ modelID, providerID }, { cause: e })
    }
  }
  /**
   * Pick a lightweight model.
   *
   * Uses the configured `small_model` when set. Otherwise returns the first model of `providerID`
   * whose ID contains one of the preferred substrings, tried in the order listed.
   * Resolves to `undefined` when the provider is unknown or no model matches.
   */
  export async function getSmallModel(providerID: string) {
    const cfg = await Config.get()
    if (cfg.small_model) {
      const { providerID: configuredProviderID, modelID: configuredModelID } = parseModel(cfg.small_model)
      return getModel(configuredProviderID, configuredModelID)
    }
    const provider = (await state()).providers[providerID]
    if (!provider) return
    const modelIDs = Object.keys(provider.info.models)
    for (const preferred of ["3-5-haiku", "3.5-haiku", "gemini-2.5-flash"]) {
      const match = modelIDs.find((id) => id.includes(preferred))
      if (match !== undefined) return getModel(providerID, match)
    }
  }
  // Model-ID substrings that sort() floats to the top. The match index is sorted DESCENDING, so
  // later entries rank higher, and IDs matching none (index -1) come after all of them.
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Order models for default selection. Sort keys, in order:
   * 1. `priority` match — a later entry sorts first.
   * 2. IDs containing "latest" before the rest.
   * 3. Model ID, descending.
   *
   * Uses remeda `sortBy`, which returns a new array.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) => priority.findIndex((filter) => model.id.includes(filter))
    const latestFirst = (model: ModelsDev.Model) => (model.id.includes("latest") ? 0 : 1)
    const byID = (model: ModelsDev.Model) => model.id
    return sortBy(models, [priorityRank, "desc"], [latestFirst, "asc"], [byID, "desc"])
  }
  /**
   * The model to use when none is specified.
   *
   * Returns `cfg.model` when set. Otherwise takes the first registered provider (restricted to
   * providers named under `cfg.provider`, when that section exists) and returns its top model
   * according to `sort()`.
   *
   * @throws Error "no providers found" or "no models found" when nothing qualifies.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)
    const configuredIDs = cfg.provider ? Object.keys(cfg.provider) : undefined
    const registered = Object.values(await list())
    const provider = registered.find((p) => configuredIDs === undefined || configuredIDs.includes(p.info.id))
    if (!provider) throw new Error("no providers found")
    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.info.id,
      modelID: model.id,
    }
  }
  /**
   * Split a "provider/model" reference at its FIRST slash.
   *
   * The model part keeps any further slashes: "openrouter/anthropic/x" gives provider "openrouter"
   * and model "anthropic/x". With no slash, the whole string is the provider and the model ID is "".
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) return { providerID: model, modelID: "" }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  /** Thrown when a provider is not registered or does not offer the requested model. */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({ providerID: z.string(), modelID: z.string() }),
  )

  /** Thrown when a provider's SDK cannot be installed, imported, or instantiated. */
  export const InitError = NamedError.create("ProviderInitError", z.object({ providerID: z.string() }))
  460. }