model.ts 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
import { z } from "zod"
import { eq, and, isNull } from "drizzle-orm"
import { Resource } from "@opencode-ai/console-resource"
import { Database } from "./drizzle"
import { ModelTable } from "./schema/model.sql"
import { Identifier } from "./identifier"
import { fn } from "./util/fn"
import { Actor } from "./actor"
  9. export namespace ZenData {
  10. const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
  11. const TrialSchema = z.object({
  12. provider: z.string(),
  13. limits: z.array(
  14. z.object({
  15. limit: z.number(),
  16. client: z.enum(["cli", "desktop"]).optional(),
  17. }),
  18. ),
  19. })
  20. export type Format = z.infer<typeof FormatSchema>
  21. export type Trial = z.infer<typeof TrialSchema>
  22. const ModelCostSchema = z.object({
  23. input: z.number(),
  24. output: z.number(),
  25. cacheRead: z.number().optional(),
  26. cacheWrite5m: z.number().optional(),
  27. cacheWrite1h: z.number().optional(),
  28. })
  29. const ModelSchema = z.object({
  30. name: z.string(),
  31. cost: ModelCostSchema,
  32. cost200K: ModelCostSchema.optional(),
  33. allowAnonymous: z.boolean().optional(),
  34. byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
  35. stickyProvider: z.boolean().optional(),
  36. trial: TrialSchema.optional(),
  37. rateLimit: z.number().optional(),
  38. fallbackProvider: z.string().optional(),
  39. providers: z.array(
  40. z.object({
  41. id: z.string(),
  42. model: z.string(),
  43. weight: z.number().optional(),
  44. disabled: z.boolean().optional(),
  45. storeModel: z.string().optional(),
  46. }),
  47. ),
  48. })
  49. const ProviderSchema = z.object({
  50. api: z.string(),
  51. apiKey: z.string(),
  52. format: FormatSchema,
  53. headerMappings: z.record(z.string(), z.string()).optional(),
  54. })
  55. const ModelsSchema = z.object({
  56. models: z.record(z.string(), z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))])),
  57. providers: z.record(z.string(), ProviderSchema),
  58. })
  59. export const validate = fn(ModelsSchema, (input) => {
  60. return input
  61. })
  62. export const list = fn(z.void(), () => {
  63. const json = JSON.parse(
  64. Resource.ZEN_MODELS1.value +
  65. Resource.ZEN_MODELS2.value +
  66. Resource.ZEN_MODELS3.value +
  67. Resource.ZEN_MODELS4.value +
  68. Resource.ZEN_MODELS5.value +
  69. Resource.ZEN_MODELS6.value,
  70. )
  71. return ModelsSchema.parse(json)
  72. })
  73. }
  74. export namespace Model {
  75. export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
  76. Actor.assertAdmin()
  77. return Database.use((db) =>
  78. db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
  79. )
  80. })
  81. export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
  82. Actor.assertAdmin()
  83. return Database.use((db) =>
  84. db
  85. .insert(ModelTable)
  86. .values({
  87. id: Identifier.create("model"),
  88. workspaceID: Actor.workspace(),
  89. model: model,
  90. })
  91. .onDuplicateKeyUpdate({
  92. set: {
  93. timeDeleted: null,
  94. },
  95. }),
  96. )
  97. })
  98. export const listDisabled = fn(z.void(), () => {
  99. return Database.use((db) =>
  100. db
  101. .select({ model: ModelTable.model })
  102. .from(ModelTable)
  103. .where(eq(ModelTable.workspaceID, Actor.workspace()))
  104. .then((rows) => rows.map((row) => row.model)),
  105. )
  106. })
  107. export const isDisabled = fn(
  108. z.object({
  109. model: z.string(),
  110. }),
  111. ({ model }) => {
  112. return Database.use(async (db) => {
  113. const result = await db
  114. .select()
  115. .from(ModelTable)
  116. .where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model)))
  117. .limit(1)
  118. return result.length > 0
  119. })
  120. },
  121. )
  122. }