// llm.ts
import { App } from "../app";
import { Log } from "../util/log";
import { mergeDeep } from "remeda";
import path from "node:path";
import type { LanguageModel, Provider } from "ai";
import { NoSuchModelError } from "ai";
import type { Config } from "../app/config";
import { BunProc } from "../bun";
import { Global } from "../global";
  10. export namespace LLM {
  11. const log = Log.create({ service: "llm" });
  12. export class ModelNotFoundError extends Error {
  13. constructor(public readonly model: string) {
  14. super();
  15. }
  16. }
  17. const NATIVE_PROVIDERS: Record<string, Config.Provider> = {
  18. anthropic: {
  19. models: {
  20. "claude-sonnet-4-20250514": {
  21. name: "Claude 4 Sonnet",
  22. cost: {
  23. input: 3.0,
  24. inputCached: 3.75,
  25. output: 15.0,
  26. outputCached: 0.3,
  27. },
  28. contextWindow: 200000,
  29. maxTokens: 50000,
  30. attachment: true,
  31. },
  32. },
  33. },
  34. };
  35. const AUTODETECT: Record<string, string[]> = {
  36. anthropic: ["ANTHROPIC_API_KEY"],
  37. };
  38. const state = App.state("llm", async (app) => {
  39. const providers: Record<
  40. string,
  41. {
  42. info: Config.Provider;
  43. instance: Provider;
  44. }
  45. > = {};
  46. const list = mergeDeep(NATIVE_PROVIDERS, app.config.providers ?? {});
  47. for (const [providerID, providerInfo] of Object.entries(list)) {
  48. if (
  49. !app.config.providers?.[providerID] &&
  50. !AUTODETECT[providerID]?.some((env) => process.env[env])
  51. )
  52. continue;
  53. const dir = path.join(
  54. Global.cache(),
  55. `node_modules`,
  56. `@ai-sdk`,
  57. providerID,
  58. );
  59. if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
  60. BunProc.run(["add", "--exact", `@ai-sdk/${providerID}@alpha`], {
  61. cwd: Global.cache(),
  62. });
  63. }
  64. const mod = await import(
  65. path.join(Global.cache(), `node_modules`, `@ai-sdk`, providerID)
  66. );
  67. const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!];
  68. const loaded = fn(providerInfo.options);
  69. log.info("loaded", { provider: providerID });
  70. providers[providerID] = {
  71. info: providerInfo,
  72. instance: loaded,
  73. };
  74. }
  75. return {
  76. models: new Map<string, LanguageModel>(),
  77. providers,
  78. };
  79. });
  80. export async function providers() {
  81. return state().then((state) => state.providers);
  82. }
  83. export async function findModel(providerID: string, modelID: string) {
  84. const key = `${providerID}/${modelID}`;
  85. const s = await state();
  86. if (s.models.has(key)) return s.models.get(key)!;
  87. const provider = s.providers[providerID];
  88. if (!provider) throw new ModelNotFoundError(modelID);
  89. log.info("loading", {
  90. providerID,
  91. modelID,
  92. });
  93. try {
  94. const match = provider.instance.languageModel(modelID);
  95. log.info("found", { providerID, modelID });
  96. s.models.set(key, match);
  97. return match;
  98. } catch (e) {
  99. if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID);
  100. throw e;
  101. }
  102. }
  103. }