model.ts 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import { z } from "zod"
  2. import { eq, and } from "drizzle-orm"
  3. import { Database } from "./drizzle"
  4. import { ModelTable } from "./schema/model.sql"
  5. import { Identifier } from "./identifier"
  6. import { fn } from "./util/fn"
  7. import { Actor } from "./actor"
  8. import { Resource } from "@opencode-ai/console-resource"
  /**
   * Schemas and loaders for the model/provider configuration.
   *
   * The configuration is read from thirty `Resource.ZEN_MODELS<n>` values that,
   * concatenated in order, form a single JSON document.
   */
  export namespace ZenData {
    // Identifiers accepted for a provider's `format` and a model variant's `formatFilter`.
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])

    // One entry of a trial's `limits` list; `client`, when given, is "cli" or "desktop".
    const TrialLimitSchema = z.object({
      limit: z.number(),
      client: z.enum(["cli", "desktop"]).optional(),
    })

    const TrialSchema = z.object({
      provider: z.string(),
      limits: z.array(TrialLimitSchema),
    })

    const RateLimitSchema = z.object({
      period: z.enum(["day", "rolling"]),
      value: z.number().int(),
      checkHeader: z.string().optional(),
      fallbackValue: z.number().int().optional(),
    })

    // Cost table; only `input` and `output` are required.
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    // One entry of a model's `providers` list.
    const ModelProviderSchema = z.object({
      id: z.string(),
      model: z.string(),
      weight: z.number().optional(),
      disabled: z.boolean().optional(),
      storeModel: z.string().optional(),
    })

    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
      stickyProvider: z.enum(["strict", "prefer"]).optional(),
      trial: TrialSchema.optional(),
      rateLimit: RateLimitSchema.optional(),
      fallbackProvider: z.string().optional(),
      providers: z.array(ModelProviderSchema),
    })

    // A model definition that additionally carries a required `formatFilter`.
    const FilteredModelSchema = ModelSchema.extend({ formatFilter: FormatSchema })

    const ProviderSchema = z.object({
      api: z.string(),
      apiKey: z.string(),
      format: FormatSchema.optional(),
      headerMappings: z.record(z.string(), z.string()).optional(),
      payloadModifier: z.record(z.string(), z.any()).optional(),
      family: z.string().optional(),
    })

    // Fields that `list` spreads onto every provider naming this family.
    const ProviderFamilySchema = z.object({
      headers: z.record(z.string(), z.string()).optional(),
      responseModifier: z.record(z.string(), z.string()).optional(),
    })

    // Whole document: each `models` entry is either one model or an array of
    // filtered variants; `providers` and `providerFamilies` are keyed records.
    const ModelsSchema = z.object({
      models: z.record(z.string(), z.union([ModelSchema, z.array(FilteredModelSchema)])),
      providers: z.record(z.string(), ProviderSchema),
      providerFamilies: z.record(z.string(), ProviderFamilySchema),
    })

    export type Format = z.infer<typeof FormatSchema>
    export type Trial = z.infer<typeof TrialSchema>
    export type RateLimit = z.infer<typeof RateLimitSchema>

    /**
     * Passes a models document through `fn` with `ModelsSchema` and returns the
     * value the callback receives, unchanged.
     *
     * NOTE(review): presumably `fn` validates the input against the schema
     * before invoking the callback; confirm in ./util/fn.
     */
    export const validate = fn(ModelsSchema, (document) => document)

    /**
     * Loads the models document from the `Resource.ZEN_MODELS1` ..
     * `Resource.ZEN_MODELS30` values.
     *
     * The thirty values are concatenated in numeric order, parsed as JSON and
     * checked with `ModelsSchema.parse`, so malformed JSON or a schema mismatch
     * throws.
     *
     * @returns `models` as parsed, and `providers` where every provider that
     *   names a `family` has that family's fields spread over its own. A family
     *   name with no entry in `providerFamilies` leaves the provider as is.
     *   `providerFamilies` itself is not part of the result.
     */
    export const list = fn(z.void(), () => {
      // Listed explicitly and in order; each chunk is one slice of the JSON text.
      const chunks = [
        Resource.ZEN_MODELS1.value,
        Resource.ZEN_MODELS2.value,
        Resource.ZEN_MODELS3.value,
        Resource.ZEN_MODELS4.value,
        Resource.ZEN_MODELS5.value,
        Resource.ZEN_MODELS6.value,
        Resource.ZEN_MODELS7.value,
        Resource.ZEN_MODELS8.value,
        Resource.ZEN_MODELS9.value,
        Resource.ZEN_MODELS10.value,
        Resource.ZEN_MODELS11.value,
        Resource.ZEN_MODELS12.value,
        Resource.ZEN_MODELS13.value,
        Resource.ZEN_MODELS14.value,
        Resource.ZEN_MODELS15.value,
        Resource.ZEN_MODELS16.value,
        Resource.ZEN_MODELS17.value,
        Resource.ZEN_MODELS18.value,
        Resource.ZEN_MODELS19.value,
        Resource.ZEN_MODELS20.value,
        Resource.ZEN_MODELS21.value,
        Resource.ZEN_MODELS22.value,
        Resource.ZEN_MODELS23.value,
        Resource.ZEN_MODELS24.value,
        Resource.ZEN_MODELS25.value,
        Resource.ZEN_MODELS26.value,
        Resource.ZEN_MODELS27.value,
        Resource.ZEN_MODELS28.value,
        Resource.ZEN_MODELS29.value,
        Resource.ZEN_MODELS30.value,
      ]
      const text = chunks.reduce((joined, chunk) => joined + chunk, "")
      const parsed = ModelsSchema.parse(JSON.parse(text))

      return {
        models: parsed.models,
        providers: Object.fromEntries(
          Object.entries(parsed.providers).map(([id, provider]) => {
            // Family fields come last in the spread, so they win on a key clash.
            const familyFields = provider.family ? parsed.providerFamilies[provider.family] : {}
            return [id, { ...provider, ...familyFields }]
          }),
        ),
      }
    })
  }
  /**
   * Per-workspace model switches backed by `ModelTable`.
   *
   * A row for (workspace, model) is what `isDisabled` reports as disabled:
   * `disable` inserts such a row and `enable` deletes it.
   */
  export namespace Model {
    // Input shared by the operations that target a single model.
    const ModelInput = z.object({ model: z.string() })

    // Condition selecting the current workspace's row for the given model.
    const rowFor = (model: string) => and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))

    /**
     * Deletes the current workspace's `ModelTable` row for `model`.
     *
     * Calls `Actor.assertAdmin()` first (presumably throws for non-admin
     * actors; confirm in ./actor).
     */
    export const enable = fn(ModelInput, (input) => {
      Actor.assertAdmin()
      return Database.use((db) => db.delete(ModelTable).where(rowFor(input.model)))
    })

    /**
     * Inserts a `ModelTable` row for `model` in the current workspace, with a
     * freshly created "model" identifier. On a duplicate key the existing row
     * is kept and its `timeDeleted` is set to null.
     *
     * Calls `Actor.assertAdmin()` first (presumably throws for non-admin
     * actors; confirm in ./actor).
     */
    export const disable = fn(ModelInput, (input) => {
      Actor.assertAdmin()
      return Database.use((db) => {
        const insertion = db.insert(ModelTable).values({
          id: Identifier.create("model"),
          workspaceID: Actor.workspace(),
          model: input.model,
        })
        return insertion.onDuplicateKeyUpdate({ set: { timeDeleted: null } })
      })
    })

    /**
     * Returns the `model` value of every `ModelTable` row belonging to the
     * current workspace.
     */
    export const listDisabled = fn(z.void(), () =>
      Database.use((db) => {
        const query = db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          .where(eq(ModelTable.workspaceID, Actor.workspace()))
        return query.then((rows) => rows.map(({ model }) => model))
      }),
    )

    /**
     * Resolves to true when the current workspace has a `ModelTable` row for
     * `model`, false otherwise. At most one row is fetched.
     */
    export const isDisabled = fn(ModelInput, (input) =>
      Database.use(async (db) => {
        const matches = await db.select().from(ModelTable).where(rowFor(input.model)).limit(1)
        return matches.length !== 0
      }),
    )
  }