|
@@ -1,7 +1,7 @@
|
|
|
import z from "zod"
|
|
import z from "zod"
|
|
|
import { App } from "../app/app"
|
|
import { App } from "../app/app"
|
|
|
import { Config } from "../config/config"
|
|
import { Config } from "../config/config"
|
|
|
-import { PROVIDER_DATABASE } from "./database"
|
|
|
|
|
|
|
+import { mapValues, sortBy } from "remeda"
|
|
|
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
|
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
|
|
import { Log } from "../util/log"
|
|
import { Log } from "../util/log"
|
|
|
import path from "path"
|
|
import path from "path"
|
|
@@ -22,6 +22,7 @@ import type { Tool } from "../tool/tool"
|
|
|
import { WriteTool } from "../tool/write"
|
|
import { WriteTool } from "../tool/write"
|
|
|
import { TodoReadTool, TodoWriteTool } from "../tool/todo"
|
|
import { TodoReadTool, TodoWriteTool } from "../tool/todo"
|
|
|
import { AuthAnthropic } from "../auth/anthropic"
|
|
import { AuthAnthropic } from "../auth/anthropic"
|
|
|
|
|
+import { ModelsDev } from "./models"
|
|
|
|
|
|
|
|
export namespace Provider {
|
|
export namespace Provider {
|
|
|
const log = Log.create({ service: "provider" })
|
|
const log = Log.create({ service: "provider" })
|
|
@@ -30,16 +31,18 @@ export namespace Provider {
|
|
|
.object({
|
|
.object({
|
|
|
id: z.string(),
|
|
id: z.string(),
|
|
|
name: z.string().optional(),
|
|
name: z.string().optional(),
|
|
|
|
|
+ attachment: z.boolean(),
|
|
|
|
|
+ reasoning: z.boolean().optional(),
|
|
|
cost: z.object({
|
|
cost: z.object({
|
|
|
input: z.number(),
|
|
input: z.number(),
|
|
|
inputCached: z.number(),
|
|
inputCached: z.number(),
|
|
|
output: z.number(),
|
|
output: z.number(),
|
|
|
outputCached: z.number(),
|
|
outputCached: z.number(),
|
|
|
}),
|
|
}),
|
|
|
- contextWindow: z.number(),
|
|
|
|
|
- maxOutputTokens: z.number().optional(),
|
|
|
|
|
- attachment: z.boolean(),
|
|
|
|
|
- reasoning: z.boolean().optional(),
|
|
|
|
|
|
|
+ limit: z.object({
|
|
|
|
|
+ context: z.number(),
|
|
|
|
|
+ output: z.number(),
|
|
|
|
|
+ }),
|
|
|
})
|
|
})
|
|
|
.openapi({
|
|
.openapi({
|
|
|
ref: "Provider.Model",
|
|
ref: "Provider.Model",
|
|
@@ -50,23 +53,27 @@ export namespace Provider {
|
|
|
.object({
|
|
.object({
|
|
|
id: z.string(),
|
|
id: z.string(),
|
|
|
name: z.string(),
|
|
name: z.string(),
|
|
|
- options: z.record(z.string(), z.any()).optional(),
|
|
|
|
|
- models: Model.array(),
|
|
|
|
|
|
|
+ models: z.record(z.string(), Model),
|
|
|
})
|
|
})
|
|
|
.openapi({
|
|
.openapi({
|
|
|
ref: "Provider.Info",
|
|
ref: "Provider.Info",
|
|
|
})
|
|
})
|
|
|
export type Info = z.output<typeof Info>
|
|
export type Info = z.output<typeof Info>
|
|
|
|
|
|
|
|
- const AUTODETECT: Record<string, string[]> = {
|
|
|
|
|
- anthropic: ["ANTHROPIC_API_KEY"],
|
|
|
|
|
- openai: ["OPENAI_API_KEY"],
|
|
|
|
|
- google: ["GOOGLE_GENERATIVE_AI_API_KEY"], // TODO: support GEMINI_API_KEY?
|
|
|
|
|
|
|
+ type Autodetector = (provider: Info) => Promise<Record<string, any> | false>
|
|
|
|
|
+
|
|
|
|
|
+ function env(...keys: string[]): Autodetector {
|
|
|
|
|
+ return async () => {
|
|
|
|
|
+ for (const key of keys) {
|
|
|
|
|
+ if (process.env[key]) return {}
|
|
|
|
|
+ }
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const AUTODETECT2: Record<
|
|
|
|
|
|
|
+ const AUTODETECT: Record<
|
|
|
string,
|
|
string,
|
|
|
- () => Promise<Record<string, any> | false>
|
|
|
|
|
|
|
+ (provider: Info) => Promise<Record<string, any> | false>
|
|
|
> = {
|
|
> = {
|
|
|
anthropic: async () => {
|
|
anthropic: async () => {
|
|
|
const result = await AuthAnthropic.load()
|
|
const result = await AuthAnthropic.load()
|
|
@@ -78,44 +85,53 @@ export namespace Provider {
|
|
|
"anthropic-beta": "oauth-2025-04-20",
|
|
"anthropic-beta": "oauth-2025-04-20",
|
|
|
},
|
|
},
|
|
|
}
|
|
}
|
|
|
- if (process.env["ANTHROPIC_API_KEY"]) return {}
|
|
|
|
|
- return false
|
|
|
|
|
|
|
+ return env("ANTHROPIC_API_KEY")
|
|
|
},
|
|
},
|
|
|
|
|
+ google: env("GOOGLE_GENERATIVE_AI_API_KEY"),
|
|
|
|
|
+ openai: env("OPENAI_API_KEY"),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const state = App.state("provider", async () => {
|
|
const state = App.state("provider", async () => {
|
|
|
log.info("loading config")
|
|
log.info("loading config")
|
|
|
const config = await Config.get()
|
|
const config = await Config.get()
|
|
|
log.info("loading providers")
|
|
log.info("loading providers")
|
|
|
- const providers = new Map<string, Info>()
|
|
|
|
|
|
|
+ const database: Record<string, Provider.Info> = await ModelsDev.get()
|
|
|
|
|
+
|
|
|
|
|
+ const providers: {
|
|
|
|
|
+ [providerID: string]: {
|
|
|
|
|
+ info: Provider.Info
|
|
|
|
|
+ options: Record<string, any>
|
|
|
|
|
+ }
|
|
|
|
|
+ } = {}
|
|
|
const models = new Map<string, { info: Model; language: LanguageModel }>()
|
|
const models = new Map<string, { info: Model; language: LanguageModel }>()
|
|
|
const sdk = new Map<string, SDK>()
|
|
const sdk = new Map<string, SDK>()
|
|
|
|
|
|
|
|
log.info("loading")
|
|
log.info("loading")
|
|
|
|
|
|
|
|
- for (const [providerID, fn] of Object.entries(AUTODETECT2)) {
|
|
|
|
|
- const provider = PROVIDER_DATABASE.find((x) => x.id === providerID)
|
|
|
|
|
|
|
+ for (const [providerID, fn] of Object.entries(AUTODETECT)) {
|
|
|
|
|
+ const provider = database[providerID]
|
|
|
if (!provider) continue
|
|
if (!provider) continue
|
|
|
- const result = await fn()
|
|
|
|
|
- if (!result) continue
|
|
|
|
|
- providers.set(providerID, {
|
|
|
|
|
- ...provider,
|
|
|
|
|
- options: {
|
|
|
|
|
- ...provider.options,
|
|
|
|
|
- ...result,
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- for (const item of PROVIDER_DATABASE) {
|
|
|
|
|
- if (!AUTODETECT[item.id].some((env) => process.env[env])) continue
|
|
|
|
|
- log.info("found", { providerID: item.id })
|
|
|
|
|
- providers.set(item.id, item)
|
|
|
|
|
|
|
+ const options = await fn(provider)
|
|
|
|
|
+ if (!options) continue
|
|
|
|
|
+ providers[providerID] = {
|
|
|
|
|
+ info: provider,
|
|
|
|
|
+ options,
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- for (const item of config.provider ?? []) {
|
|
|
|
|
- log.info("found", { providerID: item.id })
|
|
|
|
|
- providers.set(item.id, item)
|
|
|
|
|
|
|
+ for (const [providerID, options] of Object.entries(config.provider ?? {})) {
|
|
|
|
|
+ const existing = providers[providerID]
|
|
|
|
|
+ if (existing) {
|
|
|
|
|
+ existing.options = {
|
|
|
|
|
+ ...existing.options,
|
|
|
|
|
+ ...options,
|
|
|
|
|
+ }
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ providers[providerID] = {
|
|
|
|
|
+ info: database[providerID],
|
|
|
|
|
+ options,
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
@@ -126,7 +142,9 @@ export namespace Provider {
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
export async function active() {
|
|
export async function active() {
|
|
|
- return state().then((state) => state.providers)
|
|
|
|
|
|
|
+ return state().then((state) =>
|
|
|
|
|
+ mapValues(state.providers, (item) => item.info),
|
|
|
|
|
+ )
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
async function getSDK(providerID: string) {
|
|
async function getSDK(providerID: string) {
|
|
@@ -149,7 +167,7 @@ export namespace Provider {
|
|
|
}
|
|
}
|
|
|
const mod = await import(path.join(dir))
|
|
const mod = await import(path.join(dir))
|
|
|
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
|
|
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
|
|
|
- const loaded = fn(s.providers.get(providerID)?.options)
|
|
|
|
|
|
|
+ const loaded = fn(s.providers[providerID]?.options)
|
|
|
s.sdk.set(providerID, loaded)
|
|
s.sdk.set(providerID, loaded)
|
|
|
return loaded as SDK
|
|
return loaded as SDK
|
|
|
}
|
|
}
|
|
@@ -164,9 +182,9 @@ export namespace Provider {
|
|
|
modelID,
|
|
modelID,
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
- const provider = s.providers.get(providerID)
|
|
|
|
|
|
|
+ const provider = s.providers[providerID]
|
|
|
if (!provider) throw new ModelNotFoundError(modelID)
|
|
if (!provider) throw new ModelNotFoundError(modelID)
|
|
|
- const info = provider.models.find((m) => m.id === modelID)
|
|
|
|
|
|
|
+ const info = provider.info.models[modelID]
|
|
|
if (!info) throw new ModelNotFoundError(modelID)
|
|
if (!info) throw new ModelNotFoundError(modelID)
|
|
|
|
|
|
|
|
const sdk = await getSDK(providerID)
|
|
const sdk = await getSDK(providerID)
|
|
@@ -189,10 +207,20 @@ export namespace Provider {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ const priority = ["claude-sonnet-4", "gemini-2.5-pro-preview", "codex-mini"]
|
|
|
|
|
+ export function sort(models: Model[]) {
|
|
|
|
|
+ return sortBy(
|
|
|
|
|
+ models,
|
|
|
|
|
+ [(model) => priority.indexOf(model.id), "desc"],
|
|
|
|
|
+ [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
|
|
|
|
|
+ [(model) => model.id, "desc"],
|
|
|
|
|
+ )
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
export async function defaultModel() {
|
|
export async function defaultModel() {
|
|
|
- const [provider] = await active().then((val) => val.values().toArray())
|
|
|
|
|
|
|
+ const [provider] = await active().then((val) => Object.values(val))
|
|
|
if (!provider) throw new Error("no providers found")
|
|
if (!provider) throw new Error("no providers found")
|
|
|
- const model = provider.models[0]
|
|
|
|
|
|
|
+ const [model] = sort(Object.values(provider.models))
|
|
|
if (!model) throw new Error("no models found")
|
|
if (!model) throw new Error("no models found")
|
|
|
return {
|
|
return {
|
|
|
providerID: provider.id,
|
|
providerID: provider.id,
|