model.ts 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  import { z } from "zod"
  import { and, eq, isNull } from "drizzle-orm"
  import { Resource } from "@opencode-ai/console-resource"
  import { Database } from "./drizzle"
  import { ModelTable } from "./schema/model.sql"
  import { Identifier } from "./identifier"
  import { fn } from "./util/fn"
  import { Actor } from "./actor"
  /**
   * Zen model catalog: zod schemas for the models/providers configuration, plus
   * helpers to validate it and to load it from the ZEN_MODELS* resources.
   */
  export namespace ZenData {
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
    /** One of the provider formats accepted by the configuration. */
    export type Format = z.infer<typeof FormatSchema>

    /** Cost figures attached to a model (units are defined by the config, not here). */
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    /**
     * A single model entry. Each `providers[].id` refers to a key of the
     * top-level `providers` map; `weight` defaults to 1 when normalized by `list`.
     */
    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
      stickyProvider: z.enum(["strict", "prefer"]).optional(),
      trialProvider: z.string().optional(),
      trialEnded: z.boolean().optional(),
      fallbackProvider: z.string().optional(),
      rateLimit: z.number().optional(),
      providers: z.array(
        z.object({
          id: z.string(),
          model: z.string(),
          weight: z.number().optional(),
          disabled: z.boolean().optional(),
          storeModel: z.string().optional(),
          payloadModifier: z.record(z.string(), z.any()).optional(),
          safetyIdentifier: z.boolean().optional(),
        }),
      ),
    })

    /**
     * Provider configuration. `apiKey` is either a single key or a map of
     * key-id -> key; `list` expands a map into one provider entry per key.
     */
    const ProviderSchema = z.object({
      displayName: z.string().optional(),
      api: z.string(),
      apiKey: z.union([z.string(), z.record(z.string(), z.string())]),
      format: FormatSchema.optional(),
      headerMappings: z.record(z.string(), z.string()).optional(),
      payloadModifier: z.record(z.string(), z.any()).optional(),
      payloadMappings: z.record(z.string(), z.string()).optional(),
      adjustCacheUsage: z.boolean().optional(),
    })

    /**
     * Whole configuration document. A model entry is either a single model or
     * an array of variants tagged with `formatFilter` (presumably selected by
     * request format — confirm with callers).
     */
    const ModelsSchema = z.object({
      zenModels: z.record(
        z.string(),
        z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]),
      ),
      liteModels: z.record(
        z.string(),
        z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]),
      ),
      providers: z.record(z.string(), ProviderSchema),
    })

    /**
     * Returns `input` typed as the configuration document. Schema validation is
     * presumably done by the `fn` wrapper — confirm in ./util/fn.
     */
    export const validate = fn(ModelsSchema, (input) => {
      return input
    })

    /** One API key of a provider, with the provider id it is exposed under. */
    type ProviderKey = { id: string; key: string }

    /** Greatest common divisor of two positive integers. */
    function gcd(a: number, b: number): number {
      return b === 0 ? a : gcd(b, a % b)
    }

    /** Least common multiple of two positive integers. */
    function lcm(a: number, b: number): number {
      return (a / gcd(a, b)) * b
    }

    /**
     * Loads the Zen configuration and returns it with multi-key providers flattened.
     *
     * The JSON document is split across the ZEN_MODELS1..ZEN_MODELS30 resources;
     * the chunks are concatenated, parsed and schema-validated (throws on invalid
     * JSON or a schema mismatch).
     *
     * A provider whose `apiKey` is a map of N keys becomes N providers with ids
     * `${providerId}.${keyId}`. In each model, every entry referring to such a
     * provider is expanded into one entry per key, and all of the model's entries
     * get weights scaled by a common integer factor so each original provider
     * keeps its relative weight, split evenly across its keys. Models that
     * reference no multi-key provider keep their provider list untouched.
     * A `trialProvider` naming a multi-key provider expands to all its key ids;
     * otherwise it is wrapped in a one-element array.
     *
     * @param modelList `"lite"` selects `liteModels`, `"full"` selects `zenModels`.
     * @returns `{ providers, models }` keyed by (expanded) provider id and model id.
     */
    export const list = fn(z.enum(["lite", "full"]), (modelList) => {
      const json = JSON.parse(
        Resource.ZEN_MODELS1.value +
          Resource.ZEN_MODELS2.value +
          Resource.ZEN_MODELS3.value +
          Resource.ZEN_MODELS4.value +
          Resource.ZEN_MODELS5.value +
          Resource.ZEN_MODELS6.value +
          Resource.ZEN_MODELS7.value +
          Resource.ZEN_MODELS8.value +
          Resource.ZEN_MODELS9.value +
          Resource.ZEN_MODELS10.value +
          Resource.ZEN_MODELS11.value +
          Resource.ZEN_MODELS12.value +
          Resource.ZEN_MODELS13.value +
          Resource.ZEN_MODELS14.value +
          Resource.ZEN_MODELS15.value +
          Resource.ZEN_MODELS16.value +
          Resource.ZEN_MODELS17.value +
          Resource.ZEN_MODELS18.value +
          Resource.ZEN_MODELS19.value +
          Resource.ZEN_MODELS20.value +
          Resource.ZEN_MODELS21.value +
          Resource.ZEN_MODELS22.value +
          Resource.ZEN_MODELS23.value +
          Resource.ZEN_MODELS24.value +
          Resource.ZEN_MODELS25.value +
          Resource.ZEN_MODELS26.value +
          Resource.ZEN_MODELS27.value +
          Resource.ZEN_MODELS28.value +
          Resource.ZEN_MODELS29.value +
          Resource.ZEN_MODELS30.value,
      )
      const { zenModels, liteModels, providers } = ModelsSchema.parse(json)

      // provider id -> its API keys. A string key keeps the provider id; a key
      // map yields one `${providerId}.${keyId}` entry per key.
      const providerKeys = new Map<string, ProviderKey[]>(
        Object.entries(providers).map(([id, provider]): [string, ProviderKey[]] => [
          id,
          typeof provider.apiKey === "string"
            ? [{ id, key: provider.apiKey }]
            : Object.entries(provider.apiKey).map(([kid, key]) => ({ id: `${id}.${kid}`, key })),
        ]),
      )

      // Number of keys behind a provider id. Ids missing from the providers map
      // count as single-key, so they pass through unchanged instead of crashing.
      const keyCount = (id: string) => providerKeys.get(id)?.length ?? 1

      // The ids a reference to `id` resolves to once multi-key providers are flattened.
      const expandId = (id: string): string[] => {
        const keys = providerKeys.get(id)
        return keys && keys.length > 1 ? keys.map((k) => k.id) : [id]
      }

      const normalize = (model: z.infer<typeof ModelSchema>) => {
        const trialProvider = model.trialProvider ? expandId(model.trialProvider) : undefined
        const counts = model.providers.map((p) => keyCount(p.id)).filter((count) => count > 1)
        // No multi-key provider: keep the provider list (and its weights) as-is.
        if (counts.length === 0) return { trialProvider }
        // Divisible by every key count, so scale / count is always an exact integer.
        const scale = counts.reduce((acc, count) => lcm(acc, count), 1)
        return {
          trialProvider,
          providers: model.providers.flatMap((p) => {
            // A provider with `count` keys gets weight * scale / count per key,
            // i.e. weight * scale in total — the same factor for every provider.
            const weight = (p.weight ?? 1) * (scale / keyCount(p.id))
            return expandId(p.id).map((id) => ({ ...p, id, weight }))
          }),
        }
      }

      const selected = modelList === "lite" ? liteModels : zenModels
      return {
        providers: Object.fromEntries(
          Object.entries(providers).flatMap(([providerId, provider]) =>
            (providerKeys.get(providerId) ?? []).map((p) => [p.id, { ...provider, apiKey: p.key }]),
          ),
        ),
        models: Object.fromEntries(
          Object.entries(selected).map(([modelId, model]) => {
            const normalized = Array.isArray(model)
              ? model.map((m) => ({ ...m, ...normalize(m) }))
              : { ...model, ...normalize(model) }
            return [modelId, normalized]
          }),
        ),
      }
    })
  }
  /**
   * Per-workspace model disable list backed by ModelTable: a row for
   * (workspace, model) whose `timeDeleted` is null marks the model as disabled
   * in that workspace.
   */
  export namespace Model {
    /**
     * Re-enables `model` in the current workspace by hard-deleting its row(s).
     * Calls `Actor.assertAdmin()` first.
     */
    export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
      )
    })

    /**
     * Disables `model` in the current workspace. Calls `Actor.assertAdmin()` first.
     * If the insert collides with an existing row (presumably a unique key on
     * (workspaceID, model) — confirm in the schema), that row is revived by
     * clearing its `timeDeleted`.
     */
    export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db
          .insert(ModelTable)
          .values({
            id: Identifier.create("model"),
            workspaceID: Actor.workspace(),
            model: model,
          })
          .onDuplicateKeyUpdate({
            set: {
              timeDeleted: null,
            },
          }),
      )
    })

    /**
     * Lists the models disabled in the current workspace. Rows with a non-null
     * `timeDeleted` are skipped, matching `disable`, which clears `timeDeleted`
     * when it (re)disables a model.
     */
    export const listDisabled = fn(z.void(), () => {
      return Database.use((db) =>
        db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          .where(and(eq(ModelTable.workspaceID, Actor.workspace()), isNull(ModelTable.timeDeleted)))
          .then((rows) => rows.map((row) => row.model)),
      )
    })

    /**
     * Whether `model` is disabled in the current workspace, using the same
     * `timeDeleted IS NULL` rule as `listDisabled`.
     */
    export const isDisabled = fn(
      z.object({
        model: z.string(),
      }),
      ({ model }) => {
        return Database.use(async (db) => {
          // Only existence matters, so fetch a single column and at most one row.
          const result = await db
            .select({ model: ModelTable.model })
            .from(ModelTable)
            .where(
              and(
                eq(ModelTable.workspaceID, Actor.workspace()),
                eq(ModelTable.model, model),
                isNull(ModelTable.timeDeleted),
              ),
            )
            .limit(1)
          return result.length > 0
        })
      },
    )
  }