provider.ts 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619
  1. import z from "zod/v4"
  2. import path from "path"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { Plugin } from "../plugin"
  9. import { ModelsDev } from "./models"
  10. import { NamedError } from "../util/error"
  11. import { Auth } from "../auth"
  12. import { Instance } from "../project/instance"
  13. import { Global } from "../global"
  14. import { Flag } from "../flag/flag"
  15. export namespace Provider {
  /** Logger for this namespace, tagged with the `provider` service name. */
  const log = Log.create({ service: "provider" })

  /** Value a custom loader resolves to. */
  interface CustomLoaderResult {
    /** When true, the provider is merged even if no earlier source registered it. */
    autoload: boolean
    /** Optional override for turning an SDK instance plus a model ID into a language model. */
    getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
    /** Extra SDK options merged into the provider's options. */
    options?: Record<string, any>
  }

  /** Provider-specific hook that receives the provider's models database entry. */
  type CustomLoader = (provider: ModelsDev.Provider) => Promise<CustomLoaderResult>

  /** Which load phase last contributed to a registered provider. */
  type Source = "env" | "config" | "custom" | "api"
  /**
   * Prefixes that mark a Bedrock model ID as already being a cross-region
   * inference-profile ID. Such IDs must be passed through untouched.
   */
  const BEDROCK_PROFILE_PREFIXES = ["us.", "eu.", "apac.", "au.", "jp.", "global.", "us-gov."]

  /**
   * Computes the model ID to send to Bedrock for a given AWS region.
   *
   * Some model families are only reachable through a geography-scoped inference
   * profile, which is addressed by prepending `us.`, `eu.`, `apac.` or `au.` to
   * the base model ID. IDs that already carry a profile prefix are returned
   * unchanged so the prefix is never applied twice.
   *
   * @param region AWS region name, e.g. `us-east-1`.
   * @param modelID Base model ID or an already-prefixed inference-profile ID.
   * @returns The ID to request from the SDK.
   */
  function resolveBedrockModelID(region: string, modelID: string): string {
    if (BEDROCK_PROFILE_PREFIXES.some((prefix) => modelID.startsWith(prefix))) return modelID
    const mentions = (fragments: string[]) => fragments.some((fragment) => modelID.includes(fragment))

    switch (region.split("-")[0]) {
      case "us": {
        // GovCloud regions never get the `us.` prefix.
        const isGovCloud = region.startsWith("us-gov")
        const modelRequiresPrefix = mentions([
          "nova-micro",
          "nova-lite",
          "nova-pro",
          "nova-premier",
          "claude",
          "deepseek",
        ])
        return modelRequiresPrefix && !isGovCloud ? `us.${modelID}` : modelID
      }
      case "eu": {
        const regionRequiresPrefix = [
          "eu-west-1",
          "eu-west-3",
          "eu-north-1",
          "eu-central-1",
          "eu-south-1",
          "eu-south-2",
        ].some((r) => region.includes(r))
        const modelRequiresPrefix = mentions(["claude", "nova-lite", "nova-micro", "llama3", "pixtral"])
        return regionRequiresPrefix && modelRequiresPrefix ? `eu.${modelID}` : modelID
      }
      case "ap": {
        // The two Australian regions use a dedicated `au.` prefix for these model families.
        const isAustraliaRegion = ["ap-southeast-2", "ap-southeast-4"].includes(region)
        if (isAustraliaRegion && mentions(["anthropic.claude-sonnet-4-5", "anthropic.claude-haiku"])) {
          return `au.${modelID}`
        }
        return mentions(["claude", "nova-lite", "nova-micro", "nova-pro"]) ? `apac.${modelID}` : modelID
      }
      default:
        return modelID
    }
  }

  /**
   * Shared loader for `google-vertex` and `google-vertex-anthropic`.
   * Autoloads only when a project ID is present in the environment; the
   * location falls back to `us-east5`.
   */
  async function loadVertex() {
    const project = process.env["GOOGLE_CLOUD_PROJECT"] ?? process.env["GCP_PROJECT"] ?? process.env["GCLOUD_PROJECT"]
    const location = process.env["GOOGLE_CLOUD_LOCATION"] ?? process.env["VERTEX_LOCATION"] ?? "us-east5"
    if (!project) return { autoload: false }
    return {
      autoload: true,
      options: {
        project,
        location,
      },
      async getModel(sdk: any, modelID: string) {
        // Stray whitespace around the ID is stripped before the lookup.
        return sdk.languageModel(String(modelID).trim())
      },
    }
  }

  /**
   * Provider-specific loaders, keyed by provider ID. Each one may contribute
   * default SDK options, a custom `getModel`, and an `autoload` flag that
   * registers the provider without an explicit key or config entry.
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    async anthropic() {
      return {
        autoload: false,
        options: {
          headers: {
            "anthropic-beta":
              "claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14",
          },
        },
      }
    },
    async opencode(input) {
      // The models database may have no entry for this provider; nothing to load then.
      if (!input) return { autoload: false }
      const hasKey = input.env.some((item) => process.env[item]) || Boolean(await Auth.get(input.id))
      if (!hasKey) {
        // Without credentials only models with zero input cost stay available.
        for (const [key, value] of Object.entries(input.models)) {
          if (value.cost.input === 0) continue
          delete input.models[key]
        }
      }
      return {
        autoload: Object.keys(input.models).length > 0,
        options: {},
      }
    },
    async openai() {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    async azure() {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
          // `useCompletionUrls` selects the chat API; the responses API is the default.
          return options?.["useCompletionUrls"] ? sdk.chat(modelID) : sdk.responses(modelID)
        },
        options: {},
      }
    },
    async "amazon-bedrock"() {
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"] && !process.env["AWS_BEARER_TOKEN_BEDROCK"])
        return { autoload: false }
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
      return {
        autoload: true,
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
          return sdk.languageModel(resolveBedrockModelID(region, modelID))
        },
      }
    },
    async openrouter() {
      return {
        autoload: false,
        options: {
          headers: {
            "HTTP-Referer": "https://opencode.ai/",
            "X-Title": "opencode",
          },
        },
      }
    },
    async vercel() {
      return {
        autoload: false,
        options: {
          headers: {
            "http-referer": "https://opencode.ai/",
            "x-title": "opencode",
          },
        },
      }
    },
    "google-vertex": loadVertex,
    "google-vertex-anthropic": loadVertex,
  }
  /**
   * Provider registry built through `Instance.state` (presumably memoized per
   * project instance; confirm against `Instance`).
   *
   * Build order, later phases overriding earlier ones:
   *  1. overlay user-configured providers/models onto the models database
   *  2. environment variables
   *  3. stored API keys
   *  4. custom loaders and plugin auth loaders
   *  5. explicit provider options from config
   * Afterwards blacklisted/experimental models are removed and providers left
   * without any model are dropped.
   *
   * Resolves to the provider table plus the (initially empty) model cache, SDK
   * cache and alias map.
   */
  const state = Instance.state(async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()
    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<
      string,
      { providerID: string; modelID: string; info: ModelsDev.Model; language: LanguageModel; npm?: string }
    >()
    const sdk = new Map<number, SDK>()
    // Maps `${provider}/${key}` to the provider's actual model ID for custom aliases.
    const realIdByKey = new Map<string, string>()
    log.info("init")

    /**
     * Registers a provider or merges new options into an existing registration.
     * Providers unknown to the database are ignored. The passed `options`
     * object is never mutated.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        providers[id] = {
          source,
          info,
          // Default the base URL to the database's API endpoint unless the caller set one.
          options: info.api && !options["baseURL"] ? { ...options, baseURL: info.api } : options,
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Overlay config-defined providers and models on top of the database entries.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        api: provider.api ?? existing?.api,
        models: existing?.models ?? {},
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existing = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existing?.name ?? modelID,
          release_date: model.release_date ?? existing?.release_date,
          attachment: model.attachment ?? existing?.attachment ?? false,
          reasoning: model.reasoning ?? existing?.reasoning ?? false,
          temperature: model.temperature ?? existing?.temperature ?? false,
          tool_call: model.tool_call ?? existing?.tool_call ?? true,
          cost:
            !model.cost && !existing?.cost
              ? {
                  input: 0,
                  output: 0,
                  cache_read: 0,
                  cache_write: 0,
                }
              : {
                  cache_read: 0,
                  cache_write: 0,
                  ...existing?.cost,
                  ...model.cost,
                },
          options: {
            ...existing?.options,
            ...model.options,
          },
          limit: model.limit ??
            existing?.limit ?? {
              context: 0,
              output: 0,
            },
          modalities: model.modalities ??
            existing?.modalities ?? {
              input: ["text"],
              output: ["text"],
            },
          provider: model.provider ?? existing?.provider,
        }
        // The config key is an alias when the entry names a different real model ID.
        if (model.id && model.id !== modelID) {
          realIdByKey.set(`${providerID}/${modelID}`, model.id)
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      // Take the first env var that is actually set, not merely the first one listed.
      const apiKey = provider.env.map((item) => process.env[item]).find(Boolean)
      if (!apiKey) continue
      mergeProvider(
        providerID,
        // only include apiKey if there's only one potential option
        provider.env.length === 1 ? { apiKey } : {},
        "env",
      )
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const info = database[providerID]
      // Without a database entry the provider can never be registered, so skip the loader entirely.
      if (!info) continue
      const result = await fn(info)
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
      }
    }

    // load plugin-provided auth
    for (const plugin of await Plugin.list()) {
      if (!plugin.auth) continue
      const providerID = plugin.auth.provider
      if (disabled.has(providerID)) continue
      const auth = await Auth.get(providerID)
      if (!auth) continue
      if (!plugin.auth.loader) continue
      const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider])
      mergeProvider(plugin.auth.provider, options ?? {}, "custom")
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      // Honor disabled_providers here as well, like every other load phase.
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    for (const [providerID, provider] of Object.entries(providers)) {
      const filteredModels = Object.fromEntries(
        Object.entries(provider.info.models)
          // Filter out blacklisted models
          .filter(
            ([modelID]) =>
              modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"),
          )
          // Filter out experimental models
          .filter(
            ([, model]) =>
              (!model.experimental && model.status !== "alpha") || Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS,
          ),
      )
      provider.info.models = filteredModels
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
      realIdByKey,
    }
  })
  /** Returns the table of registered providers, keyed by provider ID. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Returns the SDK instance for a provider/model pair, installing and
   * instantiating the backing npm package on first use.
   *
   * Instances are cached in the state's `sdk` map under a hash of the package
   * name and the JSON-serializable options.
   *
   * @param provider Provider entry; supplies the default package and the SDK name.
   * @param model Model entry; may override the package via `model.provider.npm`.
   * @returns The created (or cached) SDK.
   * @throws InitError wrapping any failure during install, import or construction.
   */
  async function getSDK(provider: ModelsDev.Provider, model: ModelsDev.Model) {
    return (async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const pkg = model.provider?.npm ?? provider.npm ?? provider.id
      // Shallow copy so the tweaks below never leak into the shared provider options.
      const options = { ...s.providers[provider.id]?.options }
      if (pkg.includes("@ai-sdk/openai-compatible") && options["includeUsage"] === undefined) {
        options["includeUsage"] = true
      }
      const key = Bun.hash.xxHash32(JSON.stringify({ pkg, options }))
      const existing = s.sdk.get(key)
      if (existing) return existing
      const installedPath = await BunProc.install(pkg, "latest")
      // The `google-vertex-anthropic` provider points to the `@ai-sdk/google-vertex` package.
      // Ref: https://github.com/sst/models.dev/blob/0a87de42ab177bebad0620a889e2eb2b4a5dd4ab/providers/google-vertex-anthropic/provider.toml
      // However, the actual export is at the subpath `@ai-sdk/google-vertex/anthropic`.
      // Ref: https://ai-sdk.dev/providers/ai-sdk-providers/google-vertex#google-vertex-anthropic-provider-usage
      // In addition, Bun's dynamic import logic does not support subpath imports,
      // so we patch the import path to load directly from `dist`.
      const modPath =
        provider.id === "google-vertex-anthropic" ? `${installedPath}/dist/anthropic/index.mjs` : installedPath
      const mod = await import(modPath)

      const timeout = options["timeout"]
      if (timeout !== undefined && timeout !== null) {
        // Only wrap fetch if the user explicitly sets a timeout. Wrap whatever fetch is
        // already configured (e.g. one supplied by an auth plugin) instead of discarding it.
        const baseFetch = options["fetch"] ?? fetch
        options["fetch"] = async (input: any, init?: BunFetchRequestInit) => {
          const { signal, ...rest } = init ?? {}
          const timeoutSignal = AbortSignal.timeout(timeout)
          // Abort on whichever fires first: the caller's signal or the timeout.
          const combined = signal ? AbortSignal.any([signal, timeoutSignal]) : timeoutSignal
          return baseFetch(input, {
            ...rest,
            signal: combined,
            // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682
            timeout: false,
          })
        }
      }

      // The module's first export whose name starts with `create` is taken as the SDK factory.
      const factoryName = Object.keys(mod).find((name) => name.startsWith("create"))
      if (!factoryName) throw new Error(`package "${pkg}" does not export a create* factory`)
      const loaded = mod[factoryName]({
        name: provider.id,
        ...options,
      })
      s.sdk.set(key, loaded)
      return loaded as SDK
    })().catch((e) => {
      throw new InitError({ providerID: provider.id }, { cause: e })
    })
  }
  /** Looks up one registered provider by ID; resolves to `undefined` when it is not registered. */
  export async function getProvider(providerID: string) {
    const { providers } = await state()
    return providers[providerID]
  }
  /**
   * Resolves a language model for `providerID`/`modelID`, creating it on first
   * use and caching it in the state's `models` map.
   *
   * When the model key is a config alias, the real model ID recorded in
   * `realIdByKey` is what gets passed to the SDK.
   *
   * @param providerID ID of a registered provider.
   * @param modelID Model key as listed under that provider.
   * @returns The cached entry: IDs, model info, language model and npm package.
   * @throws ModelNotFoundError when the provider or model is unknown, or when
   *   the SDK raises `NoSuchModelError`.
   */
  export async function getModel(providerID: string, modelID: string) {
    const key = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(key)
    if (cached) return cached

    log.info("getModel", {
      providerID,
      modelID,
    })
    const provider = s.providers[providerID]
    if (!provider) throw new ModelNotFoundError({ providerID, modelID })
    const info = provider.info.models[modelID]
    if (!info) throw new ModelNotFoundError({ providerID, modelID })
    const sdk = await getSDK(provider.info, info)
    try {
      const realID = s.realIdByKey.get(key) ?? info.id
      // A provider-specific getModel (from a custom loader) takes precedence over the SDK default.
      const language = provider.getModel
        ? await provider.getModel(sdk, realID, provider.options)
        : sdk.languageModel(realID)
      log.info("found", { providerID, modelID })
      const entry = {
        providerID,
        modelID,
        info,
        language,
        npm: info.provider?.npm ?? provider.info.npm,
      }
      s.models.set(key, entry)
      return entry
    } catch (e) {
      if (e instanceof NoSuchModelError)
        throw new ModelNotFoundError(
          {
            modelID: modelID,
            providerID,
          },
          { cause: e },
        )
      throw e
    }
  }
  /**
   * Picks a small model for `providerID`.
   *
   * A configured `small_model` always wins. Otherwise the provider's model IDs
   * are scanned for the preferred fragments below, in order, and the first
   * match is resolved via `getModel`. Resolves to `undefined` when the provider
   * is not registered or none of its models match.
   */
  export async function getSmallModel(providerID: string) {
    const cfg = await Config.get()
    if (cfg.small_model) {
      const { providerID: configuredProvider, modelID: configuredModel } = parseModel(cfg.small_model)
      return getModel(configuredProvider, configuredModel)
    }

    const { providers } = await state()
    const provider = providers[providerID]
    if (!provider) return

    // Earlier fragments are preferred over later ones.
    const preferredFragments = [
      "claude-haiku-4-5",
      "claude-haiku-4.5",
      "3-5-haiku",
      "3.5-haiku",
      "gemini-2.5-flash",
      "gpt-5-nano",
    ]
    const candidates = Object.keys(provider.info.models)
    for (const fragment of preferredFragments) {
      const match = candidates.find((id) => id.includes(fragment))
      if (match !== undefined) return getModel(providerID, match)
    }
  }
  /** Model ID fragments used for ranking; a later position ranks higher. */
  const priority = ["gemini-2.5-pro-preview", "gpt-5", "claude-sonnet-4"]

  /**
   * Orders models from most to least preferred:
   *  1. by the index of the first `priority` fragment found in the ID, descending
   *     (IDs matching none get -1 and sort last),
   *  2. IDs containing "latest" before the others,
   *  3. by ID, descending.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) => priority.findIndex((fragment) => model.id.includes(fragment))
    const latestRank = (model: ModelsDev.Model) => (model.id.includes("latest") ? 0 : 1)
    const byId = (model: ModelsDev.Model) => model.id
    return sortBy(models, [priorityRank, "desc"], [latestRank, "asc"], [byId, "desc"])
  }
  /**
   * Determines the model to use when none was requested explicitly.
   *
   * Resolution order:
   *  1. `model` from config,
   *  2. the first entry of `recently_used_models` in the TUI state file,
   *  3. the best-sorted model of the first registered provider (restricted to
   *     providers named in config when a `provider` section exists).
   *
   * @throws Error when no provider, or no model on the chosen provider, is found.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    // this will be adjusted when migration to opentui is complete,
    // for now we just read the tui state toml file directly
    //
    // NOTE: cannot just import file as toml without cleaning due to lack of
    // support for date/time references in Bun toml parser: https://github.com/oven-sh/bun/issues/22426
    const pendingText = Bun.file(path.join(Global.Path.state, "tui")).text()
    let lastused: { providerID: string; modelID: string } | undefined
    try {
      const text = await pendingText
      // remove the date/time references since Bun toml parser doesn't support yet
      const keptLines: string[] = []
      for (const line of text.split("\n")) {
        if (line.trim().startsWith("last_used =")) continue
        keptLines.push(line)
      }
      const tuiState = Bun.TOML.parse(keptLines.join("\n")) as {
        recently_used_models?: {
          provider_id: string
          model_id: string
        }[]
      }
      const [recent] = tuiState?.recently_used_models ?? []
      if (recent) {
        lastused = {
          providerID: recent.provider_id,
          modelID: recent.model_id,
        }
      }
    } catch (error) {
      // A missing or unparsable state file only means there is no last-used model.
      log.error("failed to find last used model", {
        error,
      })
      lastused = undefined
    }
    if (lastused) return lastused

    const registered = Object.values(await list())
    const provider = registered.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id))
    if (!provider) throw new Error("no providers found")

    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.info.id,
      modelID: model.id,
    }
  }
  /**
   * Splits a `provider/model` reference at its first slash.
   *
   * Everything after the first slash belongs to the model ID, so model IDs may
   * themselves contain slashes. Without a slash the whole string is the
   * provider ID and the model ID is empty.
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) {
      return { providerID: model, modelID: "" }
    }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  /**
   * Error raised when a provider or one of its models cannot be resolved.
   * Its payload names the requested provider and model.
   */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({ providerID: z.string(), modelID: z.string() }),
  )
  /**
   * Error raised when a provider's SDK fails to install, import or construct.
   * Its payload names the provider.
   */
  export const InitError = NamedError.create("ProviderInitError", z.object({ providerID: z.string() }))
  573. }