provider.ts 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import z from "zod"
  2. import { App } from "../app/app"
  3. import { Config } from "../config/config"
  4. import { PROVIDER_DATABASE } from "./database"
  5. import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
  6. import { Log } from "../util/log"
  7. import path from "path"
  8. import { Global } from "../global"
  9. import { BunProc } from "../bun"
  10. import { BashTool } from "../tool/bash"
  11. import { EditTool } from "../tool/edit"
  12. import { FetchTool } from "../tool/fetch"
  13. import { GlobTool } from "../tool/glob"
  14. import { GrepTool } from "../tool/grep"
  15. import { ListTool } from "../tool/ls"
  16. import { LspDiagnosticTool } from "../tool/lsp-diagnostics"
  17. import { LspHoverTool } from "../tool/lsp-hover"
  18. import { PatchTool } from "../tool/patch"
  19. import { ViewTool } from "../tool/view"
  20. import type { Tool } from "../tool/tool"
  /**
   * Provider registry: discovers which model providers are usable, lazily
   * loads their `@ai-sdk/*` packages, and resolves language models and the
   * toolset to expose for a given provider.
   */
  export namespace Provider {
    const log = Log.create({ service: "provider" })

    /** Schema for a single model entry listed under a provider. */
    export const Model = z
      .object({
        id: z.string(),
        name: z.string().optional(),
        cost: z.object({
          input: z.number(),
          inputCached: z.number(),
          output: z.number(),
          outputCached: z.number(),
        }),
        contextWindow: z.number(),
        maxOutputTokens: z.number().optional(),
        attachment: z.boolean(),
        reasoning: z.boolean().optional(),
      })
      .openapi({
        ref: "Provider.Model",
      })
    export type Model = z.output<typeof Model>

    /** Schema for a provider: its id, display name, SDK options and models. */
    export const Info = z
      .object({
        id: z.string(),
        name: z.string(),
        options: z.record(z.string(), z.any()).optional(),
        models: Model.array(),
      })
      .openapi({
        ref: "Provider.Info",
      })
    export type Info = z.output<typeof Info>

    /**
     * Environment variables whose presence enables a built-in provider.
     * A provider is enabled when at least one of its variables is set.
     */
    const AUTODETECT: Record<string, string[]> = {
      anthropic: ["ANTHROPIC_API_KEY"],
      openai: ["OPENAI_API_KEY"],
      google: ["GOOGLE_GENERATIVE_AI_API_KEY", "GEMINI_API_KEY"],
    }

    /**
     * Thrown when a provider/model pair cannot be resolved.
     * `provider` is optional so existing single-argument construction still works.
     */
    class ModelNotFoundError extends Error {
      constructor(
        public readonly model: string,
        public readonly provider?: string,
      ) {
        super(`model not found: ${provider ? `${provider}/${model}` : model}`)
        this.name = "ModelNotFoundError"
      }
    }

    /**
     * Per-app provider state: the enabled providers plus caches for resolved
     * models (keyed by `providerID/modelID`) and loaded SDK instances.
     */
    const state = App.state("provider", async () => {
      const config = await Config.get()
      const providers = new Map<string, Info>()
      const models = new Map<string, { info: Model; language: LanguageModel }>()
      const sdk = new Map<string, SDK>()

      for (const item of PROVIDER_DATABASE) {
        // Optional chaining: a database entry with no AUTODETECT row is
        // skipped instead of throwing on `undefined.some`.
        const detected = AUTODETECT[item.id]?.some((env) => process.env[env])
        if (!detected) continue
        providers.set(item.id, item)
      }

      // Providers from config are applied last, so they replace an
      // auto-detected provider that has the same id.
      for (const item of config.provider ?? []) {
        providers.set(item.id, item)
      }

      return {
        models,
        providers,
        sdk,
      }
    })

    /** Returns the map of currently enabled providers, keyed by provider id. */
    export async function active() {
      const s = await state()
      return s.providers
    }

    /**
     * Loads (installing on first use) the `@ai-sdk/<providerID>` package and
     * instantiates it with the provider's configured options. Instances are
     * cached per provider id.
     *
     * @throws Error when the package exposes no `create*` factory function.
     */
    async function getSDK(providerID: string) {
      const s = await state()
      const cached = s.sdk.get(providerID)
      if (cached) return cached

      const dir = path.join(Global.cache(), "node_modules", "@ai-sdk", providerID)
      if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
        log.info("installing", {
          providerID,
        })
        // Awaited so the import below cannot race the install.
        // NOTE(review): assumes BunProc.run returns a promise (or finishes
        // synchronously) - confirm against ../bun.
        await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
          cwd: Global.cache(),
        })
      }

      const mod = await import(dir)
      // The factory is located by name: the first export starting with "create".
      const factory = Object.keys(mod).find(
        (key) => key.startsWith("create") && typeof mod[key] === "function",
      )
      if (!factory) {
        throw new Error(`@ai-sdk/${providerID} does not export a create* factory`)
      }

      const loaded = mod[factory](s.providers.get(providerID)?.options) as SDK
      s.sdk.set(providerID, loaded)
      return loaded
    }

    /**
     * Resolves a model's metadata and its language-model handle, caching the
     * result under `providerID/modelID`.
     *
     * @throws ModelNotFoundError when the provider is not enabled, the model is
     *   not listed for it, or the SDK reports `NoSuchModelError`.
     */
    export async function getModel(providerID: string, modelID: string) {
      const key = `${providerID}/${modelID}`
      const s = await state()
      const cached = s.models.get(key)
      if (cached) return cached

      log.info("loading", {
        providerID,
        modelID,
      })

      const provider = s.providers.get(providerID)
      if (!provider) throw new ModelNotFoundError(modelID, providerID)
      const info = provider.models.find((m) => m.id === modelID)
      if (!info) throw new ModelNotFoundError(modelID, providerID)
      const sdk = await getSDK(providerID)
      if (!sdk) throw new ModelNotFoundError(modelID, providerID)

      try {
        const language = sdk.languageModel(modelID)
        log.info("found", { providerID, modelID })
        const resolved = { info, language }
        s.models.set(key, resolved)
        return resolved
      } catch (e) {
        // Translate the SDK's error into this module's own; rethrow the rest.
        if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID, providerID)
        throw e
      }
    }

    /**
     * Picks the first model of the first enabled provider.
     *
     * @throws Error when no provider is enabled or the provider lists no models.
     */
    export async function defaultModel() {
      const providers = await active()
      // Destructuring pulls only the first value from the map's iterator.
      const [provider] = providers.values()
      if (!provider) throw new Error("no providers found")
      const model = provider.models[0]
      if (!model) throw new Error("no models found")
      return {
        providerID: provider.id,
        modelID: model.id,
      }
    }

    /** Every tool this module can expose; each tool appears exactly once. */
    const TOOLS = [
      BashTool,
      EditTool,
      FetchTool,
      GlobTool,
      GrepTool,
      ListTool,
      LspDiagnosticTool,
      LspHoverTool,
      PatchTool,
      ViewTool,
    ]

    /** Default toolset per built-in provider id. */
    const TOOL_MAPPING: Record<string, Tool.Info[]> = {
      anthropic: TOOLS,
      openai: TOOLS,
      google: TOOLS,
    }

    /**
     * Returns the tools to expose for `providerID`. A per-provider list of tool
     * ids in config takes precedence; otherwise the provider's default mapping
     * is used, falling back to the full toolset.
     */
    export async function tools(providerID: string) {
      const cfg = await Config.get()
      const configured = cfg.tool?.provider?.[providerID]
      if (!configured) return TOOL_MAPPING[providerID] ?? TOOLS

      // Unknown ids are dropped (and logged) so the result never holds undefined.
      return configured.flatMap((id) => {
        const match = TOOLS.find((t) => t.id === id)
        if (!match) {
          log.info("unknown tool id in config, skipping", { providerID, id })
          return []
        }
        return [match]
      })
    }
  }