provider.ts 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. import z from "zod"
  2. import fuzzysort from "fuzzysort"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { Plugin } from "../plugin"
  9. import { ModelsDev } from "./models"
  10. import { NamedError } from "@opencode-ai/util/error"
  11. import { Auth } from "../auth"
  12. import { Instance } from "../project/instance"
  13. import { Flag } from "../flag/flag"
  14. import { iife } from "@/util/iife"
  15. // Direct imports for bundled providers
  16. import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"
  17. import { createAnthropic } from "@ai-sdk/anthropic"
  18. import { createAzure } from "@ai-sdk/azure"
  19. import { createGoogleGenerativeAI } from "@ai-sdk/google"
  20. import { createVertex } from "@ai-sdk/google-vertex"
  21. import { createVertexAnthropic } from "@ai-sdk/google-vertex/anthropic"
  22. import { createOpenAI } from "@ai-sdk/openai"
  23. import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
  24. import { createOpenRouter } from "@openrouter/ai-sdk-provider"
  25. export namespace Provider {
  const log = Log.create({ service: "provider" })

  /**
   * Provider factories that are imported statically by this module, keyed by npm package name.
   *
   * `getSDK` looks a package up here first and only falls back to a runtime install when the
   * package is absent. The `@ai-sdk/google-vertex/anthropic` entry is a subpath key that
   * `getSDK` selects explicitly for the `google-vertex-anthropic` provider.
   */
  const BUNDLED_PROVIDERS: { [pkg: string]: (options: any) => SDK } = {
    "@ai-sdk/amazon-bedrock": createAmazonBedrock,
    "@ai-sdk/anthropic": createAnthropic,
    "@ai-sdk/azure": createAzure,
    "@ai-sdk/google": createGoogleGenerativeAI,
    "@ai-sdk/google-vertex": createVertex,
    "@ai-sdk/google-vertex/anthropic": createVertexAnthropic,
    "@ai-sdk/openai": createOpenAI,
    "@ai-sdk/openai-compatible": createOpenAICompatible,
    "@openrouter/ai-sdk-provider": createOpenRouter,
  }
  /** Value produced by an entry of `CUSTOM_LOADERS`. */
  interface CustomLoaderResult {
    /**
     * When true, the provider is registered by the custom-loader step even if no earlier
     * step (environment variable or stored API key) registered it.
     */
    autoload: boolean
    /** Optional override that turns an SDK instance and model ID into a language model. */
    getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
    /** Options merged into the provider's SDK options. */
    options?: Record<string, any>
  }

  /** Provider-specific hook that receives the models.dev entry for its provider. */
  type CustomLoader = (provider: ModelsDev.Provider) => Promise<CustomLoaderResult>

  /** The discovery step that most recently contributed to a provider's registration. */
  type Source = "api" | "config" | "custom" | "env"
  /**
   * Cross-region inference-profile prefixes a Bedrock model ID may already carry.
   * Such IDs are passed through untouched; prefixing them again would produce an
   * invalid ID such as `us.us.anthropic.claude-…`.
   */
  const BEDROCK_INFERENCE_PREFIXES = ["global.", "us.", "us-gov.", "eu.", "apac.", "au.", "jp."]

  /**
   * Map a bare Bedrock model ID to the cross-region inference-profile ID used in `region`.
   *
   * @param region  AWS region name such as `us-east-1`.
   * @param modelID Model ID as configured; returned unchanged if it is already prefixed.
   * @returns `modelID`, possibly prefixed with `us.`, `eu.`, `au.` or `apac.`.
   */
  function bedrockModelID(region: string, modelID: string): string {
    if (BEDROCK_INFERENCE_PREFIXES.some((prefix) => modelID.startsWith(prefix))) return modelID
    const regionPrefix = region.split("-")[0]
    switch (regionPrefix) {
      case "us": {
        const modelRequiresPrefix = ["nova-micro", "nova-lite", "nova-pro", "nova-premier", "claude", "deepseek"].some(
          (m) => modelID.includes(m),
        )
        // GovCloud regions also start with "us" but are not prefixed.
        const isGovCloud = region.startsWith("us-gov")
        return modelRequiresPrefix && !isGovCloud ? `${regionPrefix}.${modelID}` : modelID
      }
      case "eu": {
        const regionRequiresPrefix = [
          "eu-west-1",
          "eu-west-2",
          "eu-west-3",
          "eu-north-1",
          "eu-central-1",
          "eu-south-1",
          "eu-south-2",
        ].some((r) => region.includes(r))
        const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
          modelID.includes(m),
        )
        return regionRequiresPrefix && modelRequiresPrefix ? `${regionPrefix}.${modelID}` : modelID
      }
      case "ap": {
        const isAustraliaRegion = ["ap-southeast-2", "ap-southeast-4"].includes(region)
        if (
          isAustraliaRegion &&
          ["anthropic.claude-sonnet-4-5", "anthropic.claude-haiku"].some((m) => modelID.includes(m))
        ) {
          return `au.${modelID}`
        }
        const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) => modelID.includes(m))
        return modelRequiresPrefix ? `apac.${modelID}` : modelID
      }
      default:
        return modelID
    }
  }

  /**
   * Shared `getModel` for the Azure-backed loaders. It uses the Responses API by default, and
   * Chat Completions when the provider option `useCompletionUrls` is truthy.
   */
  async function azureLanguageModel(sdk: any, modelID: string, options?: Record<string, any>) {
    return options?.["useCompletionUrls"] ? sdk.chat(modelID) : sdk.responses(modelID)
  }

  /**
   * Provider-specific hooks, keyed by provider ID.
   *
   * Each loader may contribute SDK options, a `getModel` override, and an `autoload` flag.
   * The flag registers the provider even when no earlier discovery step did.
   */
  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    async anthropic() {
      return {
        autoload: false,
        options: {
          headers: {
            "anthropic-beta":
              "claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14",
          },
        },
      }
    },
    async opencode(input) {
      const hasKey = await (async () => {
        if (input.env.some((item) => process.env[item])) return true
        if (await Auth.get(input.id)) return true
        return false
      })()
      // Without a key, only free (zero input cost) models remain usable via the public key.
      if (!hasKey) {
        for (const [key, value] of Object.entries(input.models)) {
          if (value.cost.input === 0) continue
          delete input.models[key]
        }
      }
      return {
        autoload: Object.keys(input.models).length > 0,
        options: hasKey ? {} : { apiKey: "public" },
      }
    },
    openai: async () => {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    azure: async () => {
      return {
        autoload: false,
        getModel: azureLanguageModel,
        options: {},
      }
    },
    "azure-cognitive-services": async () => {
      const resourceName = process.env["AZURE_COGNITIVE_SERVICES_RESOURCE_NAME"]
      return {
        autoload: false,
        getModel: azureLanguageModel,
        options: {
          baseURL: resourceName ? `https://${resourceName}.cognitiveservices.azure.com/openai` : undefined,
        },
      }
    },
    "amazon-bedrock": async () => {
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"] && !process.env["AWS_BEARER_TOKEN_BEDROCK"])
        return { autoload: false }
      const region = process.env["AWS_REGION"] ?? "us-east-1"
      const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
      return {
        autoload: true,
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
          return sdk.languageModel(bedrockModelID(region, modelID))
        },
      }
    },
    openrouter: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "HTTP-Referer": "https://opencode.ai/",
            "X-Title": "opencode",
          },
        },
      }
    },
    vercel: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "http-referer": "https://opencode.ai/",
            "x-title": "opencode",
          },
        },
      }
    },
    "google-vertex": async () => {
      const project = process.env["GOOGLE_CLOUD_PROJECT"] ?? process.env["GCP_PROJECT"] ?? process.env["GCLOUD_PROJECT"]
      const location = process.env["GOOGLE_CLOUD_LOCATION"] ?? process.env["VERTEX_LOCATION"] ?? "us-east5"
      const autoload = Boolean(project)
      if (!autoload) return { autoload: false }
      return {
        autoload: true,
        options: {
          project,
          location,
        },
        async getModel(sdk: any, modelID: string) {
          const id = String(modelID).trim()
          return sdk.languageModel(id)
        },
      }
    },
    "google-vertex-anthropic": async () => {
      const project = process.env["GOOGLE_CLOUD_PROJECT"] ?? process.env["GCP_PROJECT"] ?? process.env["GCLOUD_PROJECT"]
      const location = process.env["GOOGLE_CLOUD_LOCATION"] ?? process.env["VERTEX_LOCATION"] ?? "global"
      const autoload = Boolean(project)
      if (!autoload) return { autoload: false }
      return {
        autoload: true,
        options: {
          project,
          location,
        },
        async getModel(sdk: any, modelID: string) {
          const id = String(modelID).trim()
          return sdk.languageModel(id)
        },
      }
    },
    zenmux: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "HTTP-Referer": "https://opencode.ai/",
            "X-Title": "opencode",
          },
        },
      }
    },
  }
  /**
   * Per-instance provider registry, built on first access through `Instance.state`.
   *
   * Providers are discovered in a fixed order. Each step registers a provider or layers
   * its options on top of earlier ones via `mergeProvider`:
   *   1. environment variables listed in each models.dev provider's `env`
   *   2. API keys stored through `Auth`
   *   3. `CUSTOM_LOADERS`
   *   4. plugin auth loaders
   *   5. `provider.<id>.options` from the user config
   * Afterwards, providers rejected by `enabled_providers`/`disabled_providers` are dropped.
   * Each remaining provider's models are filtered by the hard-coded blacklist,
   * experimental/alpha/deprecated status, and the config whitelist/blacklist. Providers left
   * with no models are removed.
   *
   * `models` and `sdk` start empty and are filled later by `getModel` and `getSDK`.
   */
  const state = Instance.state(async () => {
    using _ = log.time("state")
    const config = await Config.get()
    const database = await ModelsDev.get()
    const disabled = new Set(config.disabled_providers ?? [])
    const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null

    // `enabled_providers`, when set, is an allow-list; `disabled_providers` always excludes.
    function isProviderAllowed(providerID: string): boolean {
      if (enabled && !enabled.has(providerID)) return false
      if (disabled.has(providerID)) return false
      return true
    }

    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    // Cache of resolved language models keyed by `${providerID}/${modelID}`; filled by `getModel`.
    const models = new Map<
      string,
      {
        providerID: string
        modelID: string
        info: ModelsDev.Model
        language: LanguageModel
        npm?: string
      }
    >()
    // Cache of SDK instances keyed by a hash computed in `getSDK`.
    const sdk = new Map<number, SDK>()
    // Maps `${provider}/${key}` to the provider's actual model ID for custom aliases.
    const realIdByKey = new Map<string, string>()
    log.info("init")

    /**
     * Register provider `id`, or layer `options` onto an existing registration.
     *
     * A new registration needs a models.dev entry (otherwise this is a no-op). It seeds
     * `baseURL` from that entry's `api` field when the options do not supply one. Later
     * calls deep-merge the options, record the new `source`, and replace `getModel` only
     * when one is given.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        // Copy so the caller's object (for the config step, the user's config) is not mutated.
        const initial = { ...options }
        if (info.api && !initial["baseURL"]) initial["baseURL"] = info.api
        providers[id] = {
          source,
          info,
          options: initial,
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    const configProviders = Object.entries(config.provider ?? {})

    // Add GitHub Copilot Enterprise provider that inherits from GitHub Copilot
    if (database["github-copilot"]) {
      const githubCopilot = database["github-copilot"]
      database["github-copilot-enterprise"] = {
        ...githubCopilot,
        id: "github-copilot-enterprise",
        name: "GitHub Copilot Enterprise",
        // Enterprise uses a different API endpoint - will be set dynamically based on auth
        api: undefined,
      }
    }

    // Fold config-declared providers and models into the database, overriding models.dev fields.
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        api: provider.api ?? existing?.api,
        models: existing?.models ?? {},
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        // A config model with `id` aliases an existing model under a new key.
        const existing = parsed.models[model.id ?? modelID]
        const name = iife(() => {
          if (model.name) return model.name
          if (model.id && model.id !== modelID) return modelID
          return existing?.name ?? modelID
        })
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name,
          release_date: model.release_date ?? existing?.release_date,
          attachment: model.attachment ?? existing?.attachment ?? false,
          reasoning: model.reasoning ?? existing?.reasoning ?? false,
          temperature: model.temperature ?? existing?.temperature ?? false,
          tool_call: model.tool_call ?? existing?.tool_call ?? true,
          cost:
            !model.cost && !existing?.cost
              ? {
                  input: 0,
                  output: 0,
                  cache_read: 0,
                  cache_write: 0,
                }
              : {
                  cache_read: 0,
                  cache_write: 0,
                  ...existing?.cost,
                  ...model.cost,
                },
          options: {
            ...existing?.options,
            ...model.options,
          },
          limit: model.limit ??
            existing?.limit ?? {
              context: 0,
              output: 0,
            },
          modalities: model.modalities ??
            existing?.modalities ?? {
              input: ["text"],
              output: ["text"],
            },
          headers: model.headers,
          provider: model.provider ?? existing?.provider,
        }
        if (model.id && model.id !== modelID) {
          realIdByKey.set(`${providerID}/${modelID}`, model.id)
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      // Take the first variable that is actually set. Looking only at `env[0]` would skip
      // providers whose credential lives in a later variable.
      const apiKey = provider.env.map((item) => process.env[item]).find((value) => value)
      if (!apiKey) continue
      mergeProvider(
        providerID,
        // only include apiKey if there's only one potential option
        provider.env.length === 1 ? { apiKey } : {},
        "env",
      )
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      // Loaders receive the models.dev entry and some dereference it (e.g. `opencode`).
      // Without an entry, mergeProvider could not register the provider anyway.
      const info = database[providerID]
      if (!info) continue
      const result = await fn(info)
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
      }
    }

    // load plugins
    for (const plugin of await Plugin.list()) {
      if (!plugin.auth) continue
      const providerID = plugin.auth.provider
      if (disabled.has(providerID)) continue
      // For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
      let hasAuth = false
      const auth = await Auth.get(providerID)
      if (auth) hasAuth = true
      // Special handling for github-copilot: also check for enterprise auth
      if (providerID === "github-copilot" && !hasAuth) {
        const enterpriseAuth = await Auth.get("github-copilot-enterprise")
        if (enterpriseAuth) hasAuth = true
      }
      if (!hasAuth) continue
      if (!plugin.auth.loader) continue
      // Load for the main provider if auth exists
      if (auth) {
        const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider])
        mergeProvider(plugin.auth.provider, options ?? {}, "custom")
      }
      // If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
      if (providerID === "github-copilot") {
        const enterpriseProviderID = "github-copilot-enterprise"
        if (!disabled.has(enterpriseProviderID)) {
          const enterpriseAuth = await Auth.get(enterpriseProviderID)
          if (enterpriseAuth) {
            const enterpriseOptions = await plugin.auth.loader(
              () => Auth.get(enterpriseProviderID) as any,
              database[enterpriseProviderID],
            )
            mergeProvider(enterpriseProviderID, enterpriseOptions ?? {}, "custom")
          }
        }
      }
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // Drop disallowed providers and filter each provider's model list.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (!isProviderAllowed(providerID)) {
        delete providers[providerID]
        continue
      }
      const configProvider = config.provider?.[providerID]
      const filteredModels = Object.fromEntries(
        Object.entries(provider.info.models)
          // Filter out blacklisted models
          .filter(
            ([modelID]) =>
              modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"),
          )
          // Filter out experimental models
          .filter(
            ([, model]) =>
              ((!model.experimental && model.status !== "alpha") || Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) &&
              model.status !== "deprecated",
          )
          // Filter by provider's whitelist/blacklist from config
          .filter(([modelID]) => {
            if (!configProvider) return true
            return (
              (!configProvider.blacklist || !configProvider.blacklist.includes(modelID)) &&
              (!configProvider.whitelist || configProvider.whitelist.includes(modelID))
            )
          }),
      )
      provider.info.models = filteredModels
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID, npm: provider.info.npm })
    }

    return {
      models,
      providers,
      sdk,
      realIdByKey,
    }
  })
  /** Resolve the registry of usable providers, keyed by provider ID. */
  export async function list() {
    const { providers } = await state()
    return providers
  }
  /**
   * Build, or fetch from the per-instance cache, the AI SDK provider instance for `model`.
   *
   * The package is the model's own `provider.npm` override, else the provider's `npm`, else
   * the provider ID. Bundled factories are used directly. Other packages are installed with
   * `BunProc.install` or, for `file://` paths, imported as-is. Every instance gets a `fetch`
   * wrapper that honours the provider's `timeout` option.
   *
   * @throws InitError wrapping any failure while resolving or constructing the SDK.
   */
  async function getSDK(provider: ModelsDev.Provider, model: ModelsDev.Model) {
    return (async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const pkg = model.provider?.npm ?? provider.npm ?? provider.id
      const options = { ...s.providers[provider.id]?.options }
      if (pkg.includes("@ai-sdk/openai-compatible") && options["includeUsage"] === undefined) {
        options["includeUsage"] = true
      }
      // The provider ID is part of the cache key. It is passed to the factory as `name` and
      // picks the factory for google-vertex-anthropic. Options that JSON.stringify drops (such
      // as a plugin-supplied `fetch`) must not let two providers share one instance.
      const key = Bun.hash.xxHash32(JSON.stringify({ providerID: provider.id, pkg, options }))
      const existing = s.sdk.get(key)
      if (existing) return existing

      const customFetch = options["fetch"]
      options["fetch"] = async (input: any, init?: BunFetchRequestInit) => {
        // Preserve custom fetch if it exists, wrap it with timeout logic
        const fetchFn = customFetch ?? fetch
        // Copy so the caller's init object is not mutated when its signal is replaced.
        const opts: BunFetchRequestInit = { ...init }
        if (options["timeout"] !== undefined && options["timeout"] !== null) {
          const signals: AbortSignal[] = []
          if (opts.signal) signals.push(opts.signal)
          // `timeout: false` disables the timeout while keeping the caller's signal.
          if (options["timeout"] !== false) signals.push(AbortSignal.timeout(options["timeout"]))
          const combined = signals.length > 1 ? AbortSignal.any(signals) : signals[0]
          opts.signal = combined
        }
        return fetchFn(input, {
          ...opts,
          // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682
          timeout: false,
        })
      }

      // Special case: google-vertex-anthropic uses a subpath import
      const bundledKey = provider.id === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : pkg
      const bundledFn = BUNDLED_PROVIDERS[bundledKey]
      if (bundledFn) {
        log.info("using bundled provider", { providerID: provider.id, pkg: bundledKey })
        const loaded = bundledFn({
          name: provider.id,
          ...options,
        })
        s.sdk.set(key, loaded)
        return loaded as SDK
      }

      let installedPath: string
      if (!pkg.startsWith("file://")) {
        installedPath = await BunProc.install(pkg, "latest")
      } else {
        log.info("loading local provider", { pkg })
        installedPath = pkg
      }
      const mod = await import(installedPath)
      // Provider packages expose their factory as an export named `create*`.
      const factoryName = Object.keys(mod).find((name) => name.startsWith("create"))
      if (!factoryName) {
        throw new Error(`provider package "${pkg}" does not export a create* factory`)
      }
      const fn = mod[factoryName]
      const loaded = fn({
        name: provider.id,
        ...options,
      })
      s.sdk.set(key, loaded)
      return loaded as SDK
    })().catch((e) => {
      throw new InitError({ providerID: provider.id }, { cause: e })
    })
  }
  /** Look up a registered provider by ID; resolves to `undefined` when it is not registered. */
  export async function getProvider(providerID: string) {
    const { providers } = await state()
    return providers[providerID]
  }
  /**
   * Resolve `providerID`/`modelID` to a language model, caching the result per instance.
   *
   * A config alias (a model declared with a different `id`) is translated to the real ID
   * before being handed to the provider's `getModel` override or to `sdk.languageModel`.
   *
   * @throws ModelNotFoundError when the provider or model is not registered (with up to
   *   three fuzzy-matched suggestions), or when the SDK raises `NoSuchModelError`.
   * @throws InitError when the SDK cannot be constructed.
   */
  export async function getModel(providerID: string, modelID: string) {
    const cacheKey = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(cacheKey)
    if (cached) return cached

    log.info("getModel", {
      providerID,
      modelID,
    })

    const closest = (needle: string, haystack: string[]) =>
      fuzzysort.go(needle, haystack, { limit: 3, threshold: -10000 }).map((m) => m.target)

    const provider = s.providers[providerID]
    if (!provider) {
      const suggestions = closest(providerID, Object.keys(s.providers))
      throw new ModelNotFoundError({ providerID, modelID, suggestions })
    }

    const info = provider.info.models[modelID]
    if (!info) {
      const suggestions = closest(modelID, Object.keys(provider.info.models))
      throw new ModelNotFoundError({ providerID, modelID, suggestions })
    }

    const sdk = await getSDK(provider.info, info)
    try {
      const realID = s.realIdByKey.get(cacheKey) ?? info.id
      let language
      if (provider.getModel) {
        language = await provider.getModel(sdk, realID, provider.options)
      } else {
        language = sdk.languageModel(realID)
      }
      log.info("found", { providerID, modelID })
      const npm = info.provider?.npm ?? provider.info.npm
      s.models.set(cacheKey, { providerID, modelID, info, language, npm })
      return { modelID, providerID, info, language, npm }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError({ modelID: modelID, providerID }, { cause: e })
    }
  }
  /**
   * Pick a cheap model suitable for auxiliary work.
   *
   * Order of preference:
   * 1. `small_model` from the config, if set.
   * 2. The first model of `providerID` whose ID contains a candidate substring. Candidates
   *    are tried in order, so earlier candidates win. GitHub Copilot skips
   *    `claude-haiku-4.5`, and `opencode*` providers only consider `gpt-5-nano`.
   * 3. `opencode/gpt-5-nano`, when that provider and model are registered.
   *
   * @returns The resolved model, or `undefined` when nothing matches.
   */
  export async function getSmallModel(providerID: string) {
    const cfg = await Config.get()
    if (cfg.small_model) {
      const configured = parseModel(cfg.small_model)
      return getModel(configured.providerID, configured.modelID)
    }

    const { providers } = await state()
    const provider = providers[providerID]
    if (provider) {
      const candidates = iife(() => {
        if (providerID.startsWith("opencode")) return ["gpt-5-nano"]
        const defaults = ["claude-haiku-4-5", "claude-haiku-4.5", "3-5-haiku", "3.5-haiku", "gemini-2.5-flash", "gpt-5-nano"]
        // claude-haiku-4.5 is considered a premium model in github copilot, we shouldn't use premium requests for title gen
        return providerID === "github-copilot" ? defaults.filter((m) => m !== "claude-haiku-4.5") : defaults
      })
      const available = Object.keys(provider.info.models)
      for (const candidate of candidates) {
        const match = available.find((id) => id.includes(candidate))
        if (match) return getModel(providerID, match)
      }
    }

    // Fall back to the opencode provider's nano model when it is available.
    if (providers["opencode"]?.info.models["gpt-5-nano"]) {
      return getModel("opencode", "gpt-5-nano")
    }
    return undefined
  }
  // Substrings used to rank model IDs in `sort`.
  const priority = ["gpt-5", "claude-sonnet-4", "big-pickle", "gemini-3-pro"]

  /**
   * Order models for default selection, best first.
   *
   * Keys, in precedence order:
   * 1. index in `priority` of the first substring the ID contains, sorted descending. IDs
   *    matching nothing get -1 and sort last. NOTE(review): descending means later entries
   *    in `priority` outrank earlier ones — confirm this is intended.
   * 2. IDs containing "latest" before the rest.
   * 3. the ID itself, descending.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) => priority.findIndex((needle) => model.id.includes(needle))
    const latestFirst = (model: ModelsDev.Model) => (model.id.includes("latest") ? 0 : 1)
    const byId = (model: ModelsDev.Model) => model.id
    return sortBy(models, [priorityRank, "desc"], [latestFirst, "asc"], [byId, "desc"])
  }
  /**
   * Determine the model to use when the caller did not choose one.
   *
   * Uses `model` from the config when set. Otherwise it takes the first registered provider.
   * When the config declares `provider` entries, that must be the first registered provider
   * that also appears there. The top model of that provider by `sort` is returned.
   *
   * @throws Error "no providers found" / "no models found" when nothing qualifies.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    const configuredIDs = cfg.provider ? Object.keys(cfg.provider) : undefined
    const registered = Object.values(await list())
    const provider = registered.find((p) => configuredIDs === undefined || configuredIDs.includes(p.info.id))
    if (!provider) throw new Error("no providers found")

    const best = sort(Object.values(provider.info.models))[0]
    if (!best) throw new Error("no models found")

    return {
      providerID: provider.info.id,
      modelID: best.id,
    }
  }
  /**
   * Split a `provider/model` reference at its first slash.
   *
   * Everything after the first slash is the model ID, so model IDs may themselves
   * contain slashes (e.g. `openrouter/openai/gpt-5`). Without a slash, the whole string
   * is the provider ID and the model ID is empty.
   */
  export function parseModel(model: string) {
    const slash = model.indexOf("/")
    if (slash === -1) {
      return { providerID: model, modelID: "" }
    }
    return {
      providerID: model.slice(0, slash),
      modelID: model.slice(slash + 1),
    }
  }
  // Payload of ModelNotFoundError: the requested IDs plus optional fuzzy-matched alternatives.
  const ModelNotFoundData = z.object({
    providerID: z.string(),
    modelID: z.string(),
    suggestions: z.array(z.string()).optional(),
  })

  /** Raised when a provider or model cannot be resolved. */
  export const ModelNotFoundError = NamedError.create("ProviderModelNotFoundError", ModelNotFoundData)

  // Payload of InitError: the provider whose SDK failed to initialise.
  const InitErrorData = z.object({
    providerID: z.string(),
  })

  /** Raised (with the underlying failure as `cause`) when a provider SDK cannot be created. */
  export const InitError = NamedError.create("ProviderInitError", InitErrorData)
  698. }