provider.ts 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { mergeDeep, sortBy } from "remeda"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import { BunProc } from "../bun"
  8. import { BashTool } from "../tool/bash"
  9. import { EditTool } from "../tool/edit"
  10. import { WebFetchTool } from "../tool/webfetch"
  11. import { GlobTool } from "../tool/glob"
  12. import { GrepTool } from "../tool/grep"
  13. import { ListTool } from "../tool/ls"
  14. import { PatchTool } from "../tool/patch"
  15. import { ReadTool } from "../tool/read"
  16. import type { Tool } from "../tool/tool"
  17. import { WriteTool } from "../tool/write"
  18. import { TodoReadTool, TodoWriteTool } from "../tool/todo"
  19. import { AuthAnthropic } from "../auth/anthropic"
  20. import { AuthCopilot } from "../auth/copilot"
  21. import { ModelsDev } from "./models"
  22. import { NamedError } from "../util/error"
  23. import { Auth } from "../auth"
  24. import { TaskTool } from "../tool/task"
  25. export namespace Provider {
  /** Logger tagged with the "provider" service name. */
  const log = Log.create({ service: "provider" })

  /** Which loading stage last contributed settings to a provider entry. */
  type Source = "env" | "config" | "custom" | "api"

  /** What a custom loader resolves to. */
  type CustomLoaderResult = {
    /** When true the provider is registered even if no earlier stage added it. */
    autoload: boolean
    /** Optional override for turning a model ID into a language model. */
    getModel?: (sdk: any, modelID: string) => Promise<any>
    /** Extra options merged into the provider's options. */
    options?: Record<string, any>
  }

  /** Provider-specific hook that inspects the environment and decides how to load a provider. */
  type CustomLoader = (provider: ModelsDev.Provider, api?: string) => Promise<CustomLoaderResult>
  36. const CUSTOM_LOADERS: Record<string, CustomLoader> = {
  /**
   * Anthropic loader: only autoloads when `AuthAnthropic.access()` yields a token.
   * In that case every known model's cost is zeroed and requests are sent with a
   * bearer token instead of the `x-api-key` header.
   */
  async anthropic(provider) {
    const access = await AuthAnthropic.access()
    if (!access) return { autoload: false }
    // The caller passes `database[providerID]`, which may be missing; tolerate
    // that the same way the github-copilot loader does.
    for (const model of Object.values(provider?.models ?? {})) {
      model.cost = {
        input: 0,
        output: 0,
      }
    }
    return {
      autoload: true,
      options: {
        apiKey: "",
        async fetch(input: any, init: any) {
          // Re-read the token on every request rather than reusing the one above.
          const access = await AuthAnthropic.access()
          // Fail with a clear message instead of sending "Bearer undefined".
          if (!access) throw new Error("Anthropic authentication expired")
          const headers = {
            ...init.headers,
            authorization: `Bearer ${access}`,
            "anthropic-beta": "oauth-2025-04-20",
          }
          delete headers["x-api-key"]
          return fetch(input, {
            ...init,
            headers,
          })
        },
      },
    }
  },
  /**
   * GitHub Copilot loader: autoloads only when the Copilot auth module is
   * available and an OAuth entry is stored under "github-copilot". Model costs
   * are zeroed, and requests go through a fetch wrapper that refreshes the
   * access token and adds the Copilot-specific headers.
   */
  "github-copilot": async (provider) => {
    const copilot = await AuthCopilot()
    if (!copilot) return { autoload: false }
    const stored = await Auth.get("github-copilot")
    if (!stored || stored.type !== "oauth") return { autoload: false }

    if (provider && provider.models) {
      for (const model of Object.values(provider.models)) {
        model.cost = {
          input: 0,
          output: 0,
        }
      }
    }

    return {
      autoload: true,
      options: {
        apiKey: "",
        async fetch(input: any, init: any) {
          const info = await Auth.get("github-copilot")
          // A fetch implementation must resolve to a Response or reject;
          // returning undefined only produced an obscure failure downstream.
          if (!info || info.type !== "oauth") {
            throw new Error("GitHub Copilot authentication missing")
          }
          // Refresh when there is no access token or it has expired.
          if (!info.access || info.expires < Date.now()) {
            const tokens = await copilot.access(info.refresh)
            if (!tokens) throw new Error("GitHub Copilot authentication expired")
            await Auth.set("github-copilot", {
              type: "oauth",
              ...tokens,
            })
            info.access = tokens.access
          }

          // Best-effort inspection of the request body; a body that cannot be
          // parsed simply leaves both flags false.
          let isAgentCall = false
          let isVisionRequest = false
          try {
            const body = typeof init.body === "string" ? JSON.parse(init.body) : init.body
            if (body?.messages) {
              isAgentCall = body.messages.some((msg: any) => msg.role && ["tool", "assistant"].includes(msg.role))
              isVisionRequest = body.messages.some(
                (msg: any) =>
                  Array.isArray(msg.content) && msg.content.some((part: any) => part.type === "image_url"),
              )
            }
          } catch {}

          const headers: Record<string, string> = {
            ...init.headers,
            ...copilot.HEADERS,
            Authorization: `Bearer ${info.access}`,
            "Openai-Intent": "conversation-edits",
            "X-Initiator": isAgentCall ? "agent" : "user",
          }
          if (isVisionRequest) {
            headers["Copilot-Vision-Request"] = "true"
          }
          delete headers["x-api-key"]
          // Header names are case-insensitive: drop a lowercase duplicate so it
          // cannot be combined with the Authorization value set above.
          delete headers["authorization"]
          return fetch(input, {
            ...init,
            headers,
          })
        },
      },
    }
  },
  /** OpenAI loader: never autoloads; resolves models through `sdk.responses`. */
  openai: async () => ({
    autoload: false,
    getModel: async (sdk: any, modelID: string) => sdk.responses(modelID),
    options: {},
  }),
  /**
   * Amazon Bedrock loader: autoloads when AWS_PROFILE or AWS_ACCESS_KEY_ID is
   * set. Credentials come from the AWS node provider chain, and some model IDs
   * are prefixed with a region group ("us", "eu", "apac") before lookup.
   */
  "amazon-bedrock": async () => {
    const hasCredentialsHint = Boolean(process.env["AWS_PROFILE"] || process.env["AWS_ACCESS_KEY_ID"])
    if (!hasCredentialsHint) return { autoload: false }

    const region = process.env["AWS_REGION"] ?? "us-east-1"
    const credentialsPackage = await BunProc.install("@aws-sdk/credential-providers")
    const { fromNodeProviderChain } = await import(credentialsPackage)

    return {
      autoload: true,
      options: {
        region,
        credentialProvider: fromNodeProviderChain(),
      },
      async getModel(sdk: any, modelID: string) {
        const mentionsAny = (fragments: string[]) => fragments.some((fragment) => modelID.includes(fragment))
        const regionGroup = region.split("-")[0]

        // Prefix to prepend, if this region/model combination needs one.
        let prefix: string | undefined
        if (regionGroup === "us") {
          if (mentionsAny(["claude", "deepseek"])) prefix = regionGroup
        } else if (regionGroup === "eu") {
          const prefixedRegions = ["eu-west-1", "eu-west-3", "eu-north-1", "eu-central-1", "eu-south-1", "eu-south-2"]
          const regionRequiresPrefix = prefixedRegions.some((candidate) => region.includes(candidate))
          const modelRequiresPrefix = mentionsAny(["claude", "nova-lite", "nova-micro", "llama3", "pixtral"])
          if (regionRequiresPrefix && modelRequiresPrefix) prefix = regionGroup
        } else if (regionGroup === "ap") {
          // "ap" regions use the "apac" prefix rather than the raw region group.
          if (mentionsAny(["claude", "nova-lite", "nova-micro", "nova-pro"])) prefix = "apac"
        }

        return sdk.languageModel(prefix === undefined ? modelID : `${prefix}.${modelID}`)
      },
    }
  },
  /** OpenRouter loader: never autoloads; contributes the referer and title headers. */
  openrouter: async () => {
    const headers = {
      "HTTP-Referer": "https://opencode.ai/",
      "X-Title": "opencode",
    }
    return { autoload: false, options: { headers } }
  },
  198. }
  /**
   * App-scoped provider registry.
   *
   * Builds the set of usable providers by layering, in order: provider
   * definitions from config onto the models database, environment variables,
   * stored API keys, custom loaders, and finally config options. Providers
   * listed in `disabled_providers` are skipped in every stage. Also holds the
   * (initially empty) caches for resolved models and SDK instances.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()

    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: (sdk: any, modelID: string) => Promise<any>
        options: Record<string, any>
      }
    } = {}
    const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
    const sdk = new Map<string, SDK>()

    log.info("init")

    /**
     * Registers provider `id` or, if already registered, deep-merges `options`
     * into it and records `source` as the latest contributor. Unknown IDs
     * (absent from the database) are ignored.
     */
    function mergeProvider(
      id: string,
      options: Record<string, any>,
      source: Source,
      getModel?: (sdk: any, modelID: string) => Promise<any>,
    ) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        providers[id] = {
          source,
          info,
          // Default baseURL to the database's API URL. Copy instead of writing
          // into `options`, which may be an object owned by the config.
          options: info.api && !options["baseURL"] ? { ...options, baseURL: info.api } : options,
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    // Overlay provider/model definitions from config onto the database.
    const configProviders = Object.entries(config.provider ?? {})
    for (const [providerID, provider] of configProviders) {
      const existing = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existing?.npm,
        name: provider.name ?? existing?.name ?? providerID,
        env: provider.env ?? existing?.env ?? [],
        api: provider.api ?? existing?.api,
        models: existing?.models ?? {},
      }
      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existing = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existing?.name ?? modelID,
          release_date: model.release_date ?? existing?.release_date,
          attachment: model.attachment ?? existing?.attachment ?? false,
          reasoning: model.reasoning ?? existing?.reasoning ?? false,
          temperature: model.temperature ?? existing?.temperature ?? false,
          tool_call: model.tool_call ?? existing?.tool_call ?? true,
          cost: {
            // Zeros are only defaults: they must come first so that database
            // and config prices override them rather than the reverse.
            input: 0,
            output: 0,
            cache_read: 0,
            cache_write: 0,
            ...existing?.cost,
            ...model.cost,
          },
          options: {
            ...existing?.options,
            ...model.options,
          },
          limit: model.limit ??
            existing?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      // Only the first declared environment variable is consulted.
      const apiKey = provider.env.map((item) => process.env[item]).at(0)
      if (!apiKey) continue
      mergeProvider(
        providerID,
        // only include apiKey if there's only one potential option
        provider.env.length === 1 ? { apiKey } : {},
        "env",
      )
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      // Without a database entry the provider could never be registered, so
      // skip the loader instead of handing it `undefined`.
      const info = database[providerID]
      if (!info) continue
      const result = await fn(info)
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
      }
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      // Honor disabled_providers here too, like every other stage.
      if (disabled.has(providerID)) continue
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // Drop providers that ended up with no models.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })
  /** Resolves to the map of loaded providers, keyed by provider ID. */
  export async function list() {
    const current = await state()
    return current.providers
  }
  /**
   * Returns the SDK instance for `provider`, creating and caching it on first use.
   *
   * The package (`provider.npm`, falling back to `provider.id`) is installed via
   * `BunProc.install`, imported, and its first export whose name starts with
   * "create" is called with the provider's merged options.
   *
   * @throws InitError wrapping any failure raised while loading the SDK.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    const load = async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const cached = s.sdk.get(provider.id)
      if (cached) return cached

      const packageName = provider.npm ?? provider.id
      const installedPath = await BunProc.install(packageName, "beta")
      const mod = await import(installedPath)
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))!
      const factory = mod[factoryName]
      const created = factory(s.providers[provider.id]?.options)
      s.sdk.set(provider.id, created)
      return created as SDK
    }

    try {
      return await load()
    } catch (e) {
      throw new InitError({ providerID: provider.id }, { cause: e })
    }
  }
  /**
   * Resolves a model's metadata and language-model instance, caching the pair
   * under `providerID/modelID`.
   *
   * @throws ModelNotFoundError when the provider or model is unknown, or when
   * the SDK reports `NoSuchModelError`.
   */
  export async function getModel(providerID: string, modelID: string) {
    const cacheKey = `${providerID}/${modelID}`
    const s = await state()
    const cached = s.models.get(cacheKey)
    if (cached) return cached

    log.info("getModel", {
      providerID,
      modelID,
    })

    const provider = s.providers[providerID]
    const info = provider?.info.models[modelID]
    if (!provider || !info) throw new ModelNotFoundError({ providerID, modelID })

    const sdk = await getSDK(provider.info)
    try {
      // Prefer the provider's custom resolver when one was registered.
      let language
      if (provider.getModel) language = await provider.getModel(sdk, modelID)
      else language = sdk.languageModel(modelID)

      log.info("found", { providerID, modelID })
      s.models.set(cacheKey, { info, language })
      return { info, language }
    } catch (e) {
      if (!(e instanceof NoSuchModelError)) throw e
      throw new ModelNotFoundError(
        {
          modelID: modelID,
          providerID,
        },
        { cause: e },
      )
    }
  }
  /**
   * Picks a small model for `providerID`: the first model whose ID contains one
   * of the preferred fragments, trying fragments in priority order. Resolves to
   * undefined when the provider is not loaded or nothing matches.
   */
  export async function getSmallModel(providerID: string) {
    const current = await state()
    const provider = current.providers[providerID]
    if (!provider) return

    const modelIDs = Object.keys(provider.info.models)
    for (const fragment of ["3-5-haiku", "3.5-haiku", "gemini-2.5-flash"]) {
      const match = modelIDs.find((id) => id.includes(fragment))
      if (match !== undefined) return getModel(providerID, match)
    }
  }
  /** ID fragments of preferred models; a later entry ranks higher in `sort`. */
  const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]

  /**
   * Orders models by: position of the first matching `priority` fragment
   * (descending, non-matching models last), then IDs containing "latest"
   * first, then ID descending.
   */
  export function sort(models: ModelsDev.Model[]) {
    const priorityRank = (model: ModelsDev.Model) => priority.findIndex((filter) => model.id.includes(filter))
    const latestRank = (model: ModelsDev.Model) => (model.id.includes("latest") ? 0 : 1)
    const byID = (model: ModelsDev.Model) => model.id
    return sortBy(models, [priorityRank, "desc"], [latestRank, "asc"], [byID, "desc"])
  }
  /**
   * Chooses the default provider/model pair: the configured `model` if set,
   * otherwise the top-sorted model of the first loaded provider (restricted to
   * providers named in config when a `provider` section exists).
   *
   * @throws Error when no provider, or no model for it, is available.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)

    const loaded = Object.values(await list())
    const provider = loaded.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id))
    if (!provider) throw new Error("no providers found")

    const best = sort(Object.values(provider.info.models))[0]
    if (!best) throw new Error("no models found")

    return {
      providerID: provider.info.id,
      modelID: best.id,
    }
  }
  /**
   * Splits a "provider/model" string at the first slash. Everything after it,
   * including further slashes, is the model ID; without a slash the model ID
   * is the empty string.
   */
  export function parseModel(model: string) {
    const segments = model.split("/")
    const providerID = segments[0]
    const modelID = segments.slice(1).join("/")
    return { providerID, modelID }
  }
  /** Every tool offered by default, in presentation order. */
  const TOOLS = [
    BashTool,
    EditTool,
    WebFetchTool,
    GlobTool,
    GrepTool,
    ListTool,
    // LspDiagnosticTool,
    // LspHoverTool,
    PatchTool,
    ReadTool,
    // MultiEditTool,
    WriteTool,
    TodoWriteTool,
    TodoReadTool,
    TaskTool,
  ]

  /** Copies of all tools whose parameter schemas use nullable instead of optional fields. */
  const withNullableParameters = () =>
    TOOLS.map((tool) => ({
      ...tool,
      parameters: optionalToNullable(tool.parameters),
    }))

  /** Per-provider tool sets; providers not listed here get `TOOLS` unchanged. */
  const TOOL_MAPPING: Record<string, Tool.Info[]> = {
    anthropic: TOOLS.filter((tool) => tool.id !== "patch"),
    openai: withNullableParameters(),
    azure: withNullableParameters(),
    google: TOOLS,
  }
  /** Resolves to the tool set for `providerID`, falling back to the full default list. */
  export async function tools(providerID: string) {
    /*
    const cfg = await Config.get()
    if (cfg.tool?.provider?.[providerID])
      return cfg.tool.provider[providerID].map(
        (id) => TOOLS.find((t) => t.id === id)!,
      )
    */
    const providerSpecific = TOOL_MAPPING[providerID]
    return providerSpecific ?? TOOLS
  }
  /**
   * Rewrites a zod schema so that optional object fields become required but
   * nullable, recursing through objects, arrays and unions. Other schema types
   * are returned unchanged.
   *
   * The schema inside an optional field is converted as well, so optionals
   * nested below an optional object do not survive, and descriptions attached
   * to a replaced schema are carried over to its replacement.
   *
   * @param schema Schema to convert.
   * @returns The converted schema (the same instance when nothing applies).
   */
  function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
    // Rebuilding a schema loses the description stored on the original wrapper.
    const keepDescription = (converted: z.ZodTypeAny, original: z.ZodTypeAny): z.ZodTypeAny =>
      original.description === undefined ? converted : converted.describe(original.description)

    if (schema instanceof z.ZodObject) {
      const newShape: Record<string, z.ZodTypeAny> = {}
      for (const [key, value] of Object.entries(schema.shape)) {
        const zodValue = value as z.ZodTypeAny
        if (zodValue instanceof z.ZodOptional) {
          // Convert the wrapped schema too, not just the wrapper.
          newShape[key] = keepDescription(optionalToNullable(zodValue.unwrap()).nullable(), zodValue)
        } else {
          newShape[key] = optionalToNullable(zodValue)
        }
      }
      return keepDescription(z.object(newShape), schema)
    }

    if (schema instanceof z.ZodArray) {
      return keepDescription(z.array(optionalToNullable(schema.element)), schema)
    }

    if (schema instanceof z.ZodUnion) {
      return keepDescription(
        z.union(
          schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
            z.ZodTypeAny,
            z.ZodTypeAny,
            ...z.ZodTypeAny[],
          ],
        ),
        schema,
      )
    }

    return schema
  }
  /**
   * Named error carrying the provider and model IDs of a model that could not
   * be resolved.
   */
  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({ providerID: z.string(), modelID: z.string() }),
  )
  /** Named error carrying the ID of a provider whose SDK failed to initialize. */
  export const InitError = NamedError.create("ProviderInitError", z.object({ providerID: z.string() }))
  501. }