provider.ts 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { PROVIDER_DATABASE } from "./database"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import path from "path"
  8. import { Global } from "../global"
  9. import { BunProc } from "../bun"
  10. import { BashTool } from "../tool/bash"
  11. import { EditTool } from "../tool/edit"
  12. import { WebFetchTool } from "../tool/webfetch"
  13. import { GlobTool } from "../tool/glob"
  14. import { GrepTool } from "../tool/grep"
  15. import { ListTool } from "../tool/ls"
  16. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  17. import { LspHoverTool } from "../tool/lsp-hover"
  18. import { PatchTool } from "../tool/patch"
  19. import { ReadTool } from "../tool/read"
  20. import type { Tool } from "../tool/tool"
  21. import { MultiEditTool } from "../tool/multiedit"
  22. import { WriteTool } from "../tool/write"
  export namespace Provider {
    const log = Log.create({ service: "provider" })

    /**
     * Metadata for a single model offered by a provider.
     * NOTE(review): units for `cost.*` and `contextWindow` are not defined in this
     * file (presumably per-token pricing and a token count) — confirm against
     * PROVIDER_DATABASE.
     */
    export const Model = z
      .object({
        id: z.string(),
        name: z.string().optional(),
        cost: z.object({
          input: z.number(),
          inputCached: z.number(),
          output: z.number(),
          outputCached: z.number(),
        }),
        contextWindow: z.number(),
        maxOutputTokens: z.number().optional(),
        attachment: z.boolean(),
        reasoning: z.boolean().optional(),
      })
      .openapi({
        ref: "Provider.Model",
      })
    export type Model = z.output<typeof Model>

    /**
     * A provider: its id, display name, the `options` passed to its SDK factory
     * (see getSDK), and the models it offers.
     */
    export const Info = z
      .object({
        id: z.string(),
        name: z.string(),
        options: z.record(z.string(), z.any()).optional(),
        models: Model.array(),
      })
      .openapi({
        ref: "Provider.Info",
      })
    export type Info = z.output<typeof Info>

    /**
     * Environment variables that enable a built-in provider from PROVIDER_DATABASE.
     * The provider is loaded when ANY of its listed variables is set (non-empty).
     */
    const AUTODETECT: Record<string, string[]> = {
      anthropic: ["ANTHROPIC_API_KEY"],
      openai: ["OPENAI_API_KEY"],
      google: ["GOOGLE_GENERATIVE_AI_API_KEY"], // TODO: support GEMINI_API_KEY?
    }

    /**
     * Per-app provider state: active providers plus caches of loaded SDKs and
     * resolved language models.
     */
    const state = App.state("provider", async () => {
      log.info("loading config")
      const config = await Config.get()
      log.info("loading providers")
      const providers = new Map<string, Info>()
      const models = new Map<string, { info: Model; language: LanguageModel }>()
      const sdk = new Map<string, SDK>()

      log.info("loading")
      for (const item of PROVIDER_DATABASE) {
        // Entries without an AUTODETECT rule are skipped rather than crashing the
        // whole state initialisation on `undefined.some(...)`.
        const envs = AUTODETECT[item.id] ?? []
        if (!envs.some((env) => process.env[env])) continue
        log.info("found", { providerID: item.id })
        providers.set(item.id, item)
      }
      // Config-declared providers are applied last, so they replace any
      // autodetected provider with the same id.
      for (const item of config.provider ?? []) {
        log.info("found", { providerID: item.id })
        providers.set(item.id, item)
      }

      return {
        models,
        providers,
        sdk,
      }
    })

    /** Returns the active providers, keyed by provider id. */
    export async function active() {
      const s = await state()
      return s.providers
    }

    /**
     * Lazily loads (and caches) the AI SDK instance for `providerID`.
     *
     * If `@ai-sdk/<providerID>` has no package.json under the global cache's
     * node_modules, it is first installed with Bun (`@alpha` tag). The package's
     * first export whose name starts with `create` is then called with the
     * provider's configured `options`.
     *
     * @throws Error if the package exports no `create*` factory.
     */
    async function getSDK(providerID: string): Promise<SDK> {
      const s = await state()
      if (s.sdk.has(providerID)) return s.sdk.get(providerID)!

      const dir = path.join(Global.Path.cache, "node_modules", "@ai-sdk", providerID)
      if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
        log.info("installing", {
          providerID,
        })
        // The install must complete before the dynamic import below; without
        // `await` the import races the install on first use and install
        // failures go unhandled.
        await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
          cwd: Global.Path.cache,
        })
      }

      const mod = await import(path.join(dir))
      // NOTE(review): assumes each @ai-sdk package exposes one `create*` factory;
      // if several exist, whichever Object.keys lists first is used.
      const factory = Object.keys(mod).find((key) => key.startsWith("create"))
      if (!factory) {
        throw new Error(`@ai-sdk/${providerID} does not export a create* factory`)
      }
      const loaded = mod[factory](s.providers.get(providerID)?.options) as SDK
      s.sdk.set(providerID, loaded)
      return loaded
    }

    /**
     * Resolves a language model for `providerID`/`modelID`, caching the result
     * under the key "providerID/modelID".
     *
     * @throws ModelNotFoundError if the provider is not active, the model is not
     *   listed in the provider's `models`, or the SDK raises NoSuchModelError.
     */
    export async function getModel(providerID: string, modelID: string) {
      const key = `${providerID}/${modelID}`
      const s = await state()
      if (s.models.has(key)) return s.models.get(key)!

      log.info("loading", {
        providerID,
        modelID,
      })
      const provider = s.providers.get(providerID)
      if (!provider) throw new ModelNotFoundError(modelID, providerID)
      const info = provider.models.find((m) => m.id === modelID)
      if (!info) throw new ModelNotFoundError(modelID, providerID)
      const sdk = await getSDK(providerID)
      if (!sdk) throw new ModelNotFoundError(modelID, providerID)

      try {
        const language = sdk.languageModel(modelID)
        log.info("found", { providerID, modelID })
        const resolved = { info, language }
        s.models.set(key, resolved)
        return resolved
      } catch (e) {
        if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID, providerID)
        throw e
      }
    }

    /**
     * Picks the first model of the first active provider (Map insertion order).
     *
     * @throws Error when no provider is active or that provider lists no models.
     */
    export async function defaultModel() {
      const providers = await active()
      const [provider] = providers.values()
      if (!provider) throw new Error("no providers found")
      const model = provider.models[0]
      if (!model) throw new Error("no models found")
      return {
        providerID: provider.id,
        modelID: model.id,
      }
    }

    /** Every available tool; each one is listed exactly once. */
    const TOOLS = [
      BashTool,
      EditTool,
      WebFetchTool,
      GlobTool,
      GrepTool,
      ListTool,
      LspDiagnosticTool,
      LspHoverTool,
      PatchTool,
      ReadTool,
      MultiEditTool,
      WriteTool,
    ]

    /**
     * Default tool set per provider id. Anthropic gets every tool except the one
     * whose id is "opencode.patch" (presumably PatchTool — confirm).
     */
    const TOOL_MAPPING: Record<string, Tool.Info[]> = {
      anthropic: TOOLS.filter((t) => t.id !== "opencode.patch"),
      openai: TOOLS,
      google: TOOLS,
    }

    /**
     * Returns the tools to expose for `providerID`.
     *
     * A `tool.provider[providerID]` list of tool ids in config takes precedence;
     * otherwise TOOL_MAPPING is used, falling back to all tools for providers
     * without a mapping.
     *
     * @throws Error if the config names a tool id that does not exist.
     */
    export async function tools(providerID: string) {
      const cfg = await Config.get()
      const configured = cfg.tool?.provider?.[providerID]
      if (configured) {
        return configured.map((id) => {
          const tool = TOOLS.find((t) => t.id === id)
          // Fail loudly instead of returning `undefined` entries that break later.
          if (!tool) {
            throw new Error(`unknown tool "${id}" configured for provider "${providerID}"`)
          }
          return tool
        })
      }
      return TOOL_MAPPING[providerID] ?? TOOLS
    }

    /** Raised when a requested model (or the provider serving it) cannot be resolved. */
    class ModelNotFoundError extends Error {
      constructor(
        public readonly model: string,
        public readonly providerID?: string,
      ) {
        super(providerID ? `model not found: ${providerID}/${model}` : `model not found: ${model}`)
        this.name = "ModelNotFoundError"
      }
    }
  }