model.ts 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  import { z } from "zod"
  import { eq, and, isNull } from "drizzle-orm"
  import { Database } from "./drizzle"
  import { ModelTable } from "./schema/model.sql"
  import { Identifier } from "./identifier"
  import { fn } from "./util/fn"
  import { Actor } from "./actor"
  import { Resource } from "@opencode-ai/console-resource"
  export namespace ZenData {
    /** Wire-format families a provider endpoint may speak. */
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
    export type Format = z.infer<typeof FormatSchema>

    /** Price table for a model; the cache-related entries may be absent. */
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    /**
     * One upstream route for a model.
     * NOTE(review): `id` presumably keys into the top-level `providers` map — confirm with callers.
     */
    const ModelRouteSchema = z.object({
      id: z.string(),
      model: z.string(),
      weight: z.number().optional(),
      disabled: z.boolean().optional(),
      storeModel: z.string().optional(),
    })

    /** A catalog model: display name, pricing, optional access flags, and its candidate routes. */
    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      // Optional second price table; the name suggests it applies beyond 200K tokens — TODO confirm.
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      rateLimit: z.number().optional(),
      fallbackProvider: z.string().optional(),
      providers: z.array(ModelRouteSchema),
    })

    /** Connection details for an upstream provider endpoint. */
    const ProviderSchema = z.object({
      api: z.string(),
      apiKey: z.string(),
      format: FormatSchema,
      headerMappings: z.record(z.string(), z.string()).optional(),
    })

    /** The whole catalog: models and providers, each keyed by a string id. */
    const ModelsSchema = z.object({
      models: z.record(z.string(), ModelSchema),
      providers: z.record(z.string(), ProviderSchema),
    })

    /**
     * Returns its argument unchanged; checking against ModelsSchema is left to the `fn` wrapper.
     * NOTE(review): assumes `fn` parses its input with the supplied schema — confirm in ./util/fn.
     */
    export const validate = fn(ModelsSchema, (catalog) => catalog)

    /**
     * Loads the model catalog.
     * The values of ZEN_MODELS1..ZEN_MODELS4 are joined in that order and parsed as one JSON
     * document, which is then validated with ModelsSchema.
     * Throws if the joined text is not valid JSON or does not match ModelsSchema.
     */
    export const list = fn(z.void(), () => {
      const text = `${Resource.ZEN_MODELS1.value}${Resource.ZEN_MODELS2.value}${Resource.ZEN_MODELS3.value}${Resource.ZEN_MODELS4.value}`
      const parsed = JSON.parse(text)
      return ModelsSchema.parse(parsed)
    })
  }
  export namespace Model {
    /**
     * Filter matching the current workspace's row for `model`.
     * Reads Actor.workspace() when called, so it is invoked inside each Database.use callback,
     * exactly where the inline filters used to be evaluated.
     */
    const forWorkspaceModel = (model: string) =>
      and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))

    /**
     * Re-enables `model` for the current workspace by hard-deleting its "disabled" row.
     * Requires an admin actor (Actor.assertAdmin is called before touching the database).
     */
    export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) => db.delete(ModelTable).where(forWorkspaceModel(model)))
    })

    /**
     * Disables `model` for the current workspace by inserting a row for it.
     * On a duplicate key the existing row is kept and its timeDeleted is cleared, so a row
     * whose timeDeleted was set becomes active again.
     * NOTE(review): relies on a unique key covering (workspaceID, model) — confirm in model.sql.
     * Requires an admin actor.
     */
    export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db
          .insert(ModelTable)
          .values({
            id: Identifier.create("model"),
            workspaceID: Actor.workspace(),
            model: model,
          })
          .onDuplicateKeyUpdate({
            set: {
              timeDeleted: null,
            },
          }),
      )
    })

    /**
     * Lists the model ids that are disabled for the current workspace.
     * Rows with timeDeleted set are skipped. This matches `disable`, which re-activates a row
     * by clearing timeDeleted, so a row with timeDeleted set is not an active "disabled" marker.
     */
    export const listDisabled = fn(z.void(), () => {
      return Database.use((db) =>
        db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          .where(and(eq(ModelTable.workspaceID, Actor.workspace()), isNull(ModelTable.timeDeleted)))
          .then((rows) => rows.map((row) => row.model)),
      )
    })

    /**
     * Resolves to true when `model` has an active "disabled" row (timeDeleted unset) in the
     * current workspace. This is an existence check, so only the id column is fetched, with
     * at most one row.
     */
    export const isDisabled = fn(
      z.object({
        model: z.string(),
      }),
      ({ model }) => {
        return Database.use(async (db) => {
          const rows = await db
            .select({ id: ModelTable.id })
            .from(ModelTable)
            .where(and(forWorkspaceModel(model), isNull(ModelTable.timeDeleted)))
            .limit(1)
          return rows.length > 0
        })
      },
    )
  }