provider.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { Plugin } from "../plugin"
  9. import { ModelsDev } from "./models"
  10. import { NamedError } from "../util/error"
  11. import { Auth } from "../auth"
  12. export namespace Provider {
  13. const log = Log.create({ service: "provider" })
  14. type CustomLoader = (
  15. provider: ModelsDev.Provider,
  16. api?: string,
  17. ) => Promise<{
  18. autoload: boolean
  19. getModel?: (sdk: any, modelID: string) => Promise<any>
  20. options?: Record<string, any>
  21. }>
  22. type Source = "env" | "config" | "custom" | "api"
  23. const CUSTOM_LOADERS: Record<string, CustomLoader> = {
  24. async anthropic() {
  25. return {
  26. autoload: false,
  27. options: {
  28. headers: {
  29. "anthropic-beta":
  30. "claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14",
  31. },
  32. },
  33. }
  34. },
  35. async opencode(input) {
  36. return {
  37. autoload: Object.keys(input.models).length > 0,
  38. options: {},
  39. }
  40. },
  41. openai: async () => {
  42. return {
  43. autoload: false,
  44. async getModel(sdk: any, modelID: string) {
  45. return sdk.responses(modelID)
  46. },
  47. options: {},
  48. }
  49. },
  50. azure: async () => {
  51. return {
  52. autoload: false,
  53. async getModel(sdk: any, modelID: string) {
  54. return sdk.responses(modelID)
  55. },
  56. options: {},
  57. }
  58. },
  59. "amazon-bedrock": async () => {
  60. if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"] && !process.env["AWS_BEARER_TOKEN_BEDROCK"])
  61. return { autoload: false }
  62. const region = process.env["AWS_REGION"] ?? "us-east-1"
  63. const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
  64. return {
  65. autoload: true,
  66. options: {
  67. region,
  68. credentialProvider: fromNodeProviderChain(),
  69. },
  70. async getModel(sdk: any, modelID: string) {
  71. let regionPrefix = region.split("-")[0]
  72. switch (regionPrefix) {
  73. case "us": {
  74. const modelRequiresPrefix = ["claude", "deepseek"].some((m) => modelID.includes(m))
  75. if (modelRequiresPrefix) {
  76. modelID = `${regionPrefix}.${modelID}`
  77. }
  78. break
  79. }
  80. case "eu": {
  81. const regionRequiresPrefix = [
  82. "eu-west-1",
  83. "eu-west-3",
  84. "eu-north-1",
  85. "eu-central-1",
  86. "eu-south-1",
  87. "eu-south-2",
  88. ].some((r) => region.includes(r))
  89. const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
  90. modelID.includes(m),
  91. )
  92. if (regionRequiresPrefix && modelRequiresPrefix) {
  93. modelID = `${regionPrefix}.${modelID}`
  94. }
  95. break
  96. }
  97. case "ap": {
  98. const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) =>
  99. modelID.includes(m),
  100. )
  101. if (modelRequiresPrefix) {
  102. regionPrefix = "apac"
  103. modelID = `${regionPrefix}.${modelID}`
  104. }
  105. break
  106. }
  107. }
  108. return sdk.languageModel(modelID)
  109. },
  110. }
  111. },
  112. openrouter: async () => {
  113. return {
  114. autoload: false,
  115. options: {
  116. headers: {
  117. "HTTP-Referer": "https://opencode.ai/",
  118. "X-Title": "opencode",
  119. },
  120. },
  121. }
  122. },
  123. vercel: async () => {
  124. return {
  125. autoload: false,
  126. options: {
  127. headers: {
  128. "http-referer": "https://opencode.ai/",
  129. "x-title": "opencode",
  130. },
  131. },
  132. }
  133. },
  134. }
  135. const state = App.state("provider", async () => {
  136. const config = await Config.get()
  137. const database = await ModelsDev.get()
  138. const providers: {
  139. [providerID: string]: {
  140. source: Source
  141. info: ModelsDev.Provider
  142. getModel?: (sdk: any, modelID: string) => Promise<any>
  143. options: Record<string, any>
  144. }
  145. } = {}
  146. const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
  147. const sdk = new Map<string, SDK>()
  148. log.info("init")
  149. function mergeProvider(
  150. id: string,
  151. options: Record<string, any>,
  152. source: Source,
  153. getModel?: (sdk: any, modelID: string) => Promise<any>,
  154. ) {
  155. const provider = providers[id]
  156. if (!provider) {
  157. const info = database[id]
  158. if (!info) return
  159. if (info.api && !options["baseURL"]) options["baseURL"] = info.api
  160. providers[id] = {
  161. source,
  162. info,
  163. options,
  164. getModel,
  165. }
  166. return
  167. }
  168. provider.options = mergeDeep(provider.options, options)
  169. provider.source = source
  170. provider.getModel = getModel ?? provider.getModel
  171. }
  172. const configProviders = Object.entries(config.provider ?? {})
  173. for (const [providerID, provider] of configProviders) {
  174. const existing = database[providerID]
  175. const parsed: ModelsDev.Provider = {
  176. id: providerID,
  177. npm: provider.npm ?? existing?.npm,
  178. name: provider.name ?? existing?.name ?? providerID,
  179. env: provider.env ?? existing?.env ?? [],
  180. api: provider.api ?? existing?.api,
  181. models: existing?.models ?? {},
  182. }
  183. for (const [modelID, model] of Object.entries(provider.models ?? {})) {
  184. const existing = parsed.models[modelID]
  185. const parsedModel: ModelsDev.Model = {
  186. id: modelID,
  187. name: model.name ?? existing?.name ?? modelID,
  188. release_date: model.release_date ?? existing?.release_date,
  189. attachment: model.attachment ?? existing?.attachment ?? false,
  190. reasoning: model.reasoning ?? existing?.reasoning ?? false,
  191. temperature: model.temperature ?? existing?.temperature ?? false,
  192. tool_call: model.tool_call ?? existing?.tool_call ?? true,
  193. cost:
  194. !model.cost && !existing?.cost
  195. ? {
  196. input: 0,
  197. output: 0,
  198. cache_read: 0,
  199. cache_write: 0,
  200. }
  201. : {
  202. cache_read: 0,
  203. cache_write: 0,
  204. ...existing?.cost,
  205. ...model.cost,
  206. },
  207. options: {
  208. ...existing?.options,
  209. ...model.options,
  210. },
  211. limit: model.limit ??
  212. existing?.limit ?? {
  213. context: 0,
  214. output: 0,
  215. },
  216. }
  217. parsed.models[modelID] = parsedModel
  218. }
  219. database[providerID] = parsed
  220. }
  221. const disabled = await Config.get().then((cfg) => new Set(cfg.disabled_providers ?? []))
  222. // load env
  223. for (const [providerID, provider] of Object.entries(database)) {
  224. if (disabled.has(providerID)) continue
  225. const apiKey = provider.env.map((item) => process.env[item]).at(0)
  226. if (!apiKey) continue
  227. mergeProvider(
  228. providerID,
  229. // only include apiKey if there's only one potential option
  230. provider.env.length === 1 ? { apiKey } : {},
  231. "env",
  232. )
  233. }
  234. // load apikeys
  235. for (const [providerID, provider] of Object.entries(await Auth.all())) {
  236. if (disabled.has(providerID)) continue
  237. if (provider.type === "api") {
  238. mergeProvider(providerID, { apiKey: provider.key }, "api")
  239. }
  240. }
  241. // load custom
  242. for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
  243. if (disabled.has(providerID)) continue
  244. const result = await fn(database[providerID])
  245. if (result && (result.autoload || providers[providerID])) {
  246. mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
  247. }
  248. }
  249. for (const plugin of await Plugin.list()) {
  250. if (!plugin.auth) continue
  251. const providerID = plugin.auth.provider
  252. if (disabled.has(providerID)) continue
  253. const auth = await Auth.get(providerID)
  254. if (!auth) continue
  255. if (!plugin.auth.loader) continue
  256. const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider])
  257. mergeProvider(plugin.auth.provider, options ?? {}, "custom")
  258. }
  259. // load config
  260. for (const [providerID, provider] of configProviders) {
  261. mergeProvider(providerID, provider.options ?? {}, "config")
  262. }
  263. for (const [providerID, provider] of Object.entries(providers)) {
  264. // Filter out blacklisted models
  265. const filteredModels = Object.fromEntries(
  266. Object.entries(provider.info.models).filter(
  267. ([modelID]) =>
  268. modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"),
  269. ),
  270. )
  271. provider.info.models = filteredModels
  272. if (Object.keys(provider.info.models).length === 0) {
  273. delete providers[providerID]
  274. continue
  275. }
  276. log.info("found", { providerID })
  277. }
  278. return {
  279. models,
  280. providers,
  281. sdk,
  282. }
  283. })
  284. export async function list() {
  285. return state().then((state) => state.providers)
  286. }
  287. async function getSDK(provider: ModelsDev.Provider) {
  288. return (async () => {
  289. using _ = log.time("getSDK", {
  290. providerID: provider.id,
  291. })
  292. const s = await state()
  293. const existing = s.sdk.get(provider.id)
  294. if (existing) return existing
  295. const pkg = provider.npm ?? provider.id
  296. const mod = await import(await BunProc.install(pkg, "latest"))
  297. const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
  298. let options = { ...s.providers[provider.id]?.options }
  299. if (options["timeout"] !== undefined) {
  300. // Only override fetch if user explicitly sets timeout
  301. options["fetch"] = async (input: any, init?: any) => {
  302. return await fetch(input, { ...init, timeout: options["timeout"] })
  303. }
  304. }
  305. const loaded = fn({
  306. name: provider.id,
  307. ...options,
  308. })
  309. s.sdk.set(provider.id, loaded)
  310. return loaded as SDK
  311. })().catch((e) => {
  312. throw new InitError({ providerID: provider.id }, { cause: e })
  313. })
  314. }
  315. export async function getProvider(providerID: string) {
  316. return state().then((s) => s.providers[providerID])
  317. }
  318. export async function getModel(providerID: string, modelID: string) {
  319. const key = `${providerID}/${modelID}`
  320. const s = await state()
  321. if (s.models.has(key)) return s.models.get(key)!
  322. log.info("getModel", {
  323. providerID,
  324. modelID,
  325. })
  326. const provider = s.providers[providerID]
  327. if (!provider) throw new ModelNotFoundError({ providerID, modelID })
  328. const info = provider.info.models[modelID]
  329. if (!info) throw new ModelNotFoundError({ providerID, modelID })
  330. const sdk = await getSDK(provider.info)
  331. try {
  332. const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
  333. log.info("found", { providerID, modelID })
  334. s.models.set(key, {
  335. info,
  336. language,
  337. })
  338. return {
  339. info,
  340. language,
  341. }
  342. } catch (e) {
  343. if (e instanceof NoSuchModelError)
  344. throw new ModelNotFoundError(
  345. {
  346. modelID: modelID,
  347. providerID,
  348. },
  349. { cause: e },
  350. )
  351. throw e
  352. }
  353. }
  354. export async function getSmallModel(providerID: string) {
  355. const cfg = await Config.get()
  356. if (cfg.small_model) {
  357. const parsed = parseModel(cfg.small_model)
  358. return getModel(parsed.providerID, parsed.modelID)
  359. }
  360. const provider = await state().then((state) => state.providers[providerID])
  361. if (!provider) return
  362. const priority = ["3-5-haiku", "3.5-haiku", "gemini-2.5-flash", "gpt-5-nano"]
  363. for (const item of priority) {
  364. for (const model of Object.keys(provider.info.models)) {
  365. if (model.includes(item)) return getModel(providerID, model)
  366. }
  367. }
  368. }
  369. const priority = ["gemini-2.5-pro-preview", "gpt-5", "claude-sonnet-4"]
  370. export function sort(models: ModelsDev.Model[]) {
  371. return sortBy(
  372. models,
  373. [(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],
  374. [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
  375. [(model) => model.id, "desc"],
  376. )
  377. }
  378. export async function defaultModel() {
  379. const cfg = await Config.get()
  380. if (cfg.model) return parseModel(cfg.model)
  381. const provider = await list()
  382. .then((val) => Object.values(val))
  383. .then((x) => x.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id)))
  384. if (!provider) throw new Error("no providers found")
  385. const [model] = sort(Object.values(provider.info.models))
  386. if (!model) throw new Error("no models found")
  387. return {
  388. providerID: provider.info.id,
  389. modelID: model.id,
  390. }
  391. }
  392. export function parseModel(model: string) {
  393. const [providerID, ...rest] = model.split("/")
  394. return {
  395. providerID: providerID,
  396. modelID: rest.join("/"),
  397. }
  398. }
  399. export const ModelNotFoundError = NamedError.create(
  400. "ProviderModelNotFoundError",
  401. z.object({
  402. providerID: z.string(),
  403. modelID: z.string(),
  404. }),
  405. )
  406. export const InitError = NamedError.create(
  407. "ProviderInitError",
  408. z.object({
  409. providerID: z.string(),
  410. }),
  411. )
  412. }