model.ts 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  import { z } from "zod"
  import { and, eq, isNull } from "drizzle-orm"
  import { Resource } from "@opencode-ai/console-resource"
  import { Actor } from "./actor"
  import { Database } from "./drizzle"
  import { Identifier } from "./identifier"
  import { ModelTable } from "./schema/model.sql"
  import { fn } from "./util/fn"
  /**
   * Zen model catalog: zod schemas describing the model/provider configuration,
   * plus helpers to validate a config object or load it from linked resources.
   */
  export namespace ZenData {
    /** API wire format a provider speaks. */
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
    export type Format = z.infer<typeof FormatSchema>

    /**
     * Price fields for a model. Units are not defined in this file — NOTE(review):
     * confirm against the billing code. `cacheWrite5m` / `cacheWrite1h` presumably
     * price cache writes by cache TTL — TODO confirm.
     */
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    /**
     * One upstream route that can serve a model. `id` presumably keys into the
     * top-level `providers` record, and `weight` presumably biases selection — verify.
     */
    const ModelProviderSchema = z.object({
      id: z.string(),
      model: z.string(),
      weight: z.number().optional(),
      disabled: z.boolean().optional(),
    })

    /** A single catalog entry. `cost200K` presumably applies past a 200K-token threshold — confirm. */
    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      rateLimit: z.number().optional(),
      fallbackProvider: z.string().optional(),
      providers: z.array(ModelProviderSchema),
    })

    /** Connection details for an upstream provider. */
    const ProviderSchema = z.object({
      api: z.string(),
      apiKey: z.string(),
      format: FormatSchema,
      headerMappings: z.record(z.string(), z.string()).optional(),
    })

    /** Full catalog: models and providers, each keyed by string id. */
    const ModelsSchema = z.object({
      models: z.record(z.string(), ModelSchema),
      providers: z.record(z.string(), ProviderSchema),
    })

    /**
     * Identity handler wrapped by `fn` with `ModelsSchema` as its input schema.
     * Presumably `fn` parses the argument against that schema before calling, so
     * this returns the validated catalog — confirm in ./util/fn.
     */
    export const validate = fn(ModelsSchema, (config) => config)

    /**
     * Loads the catalog from the ZEN_MODELS1 + ZEN_MODELS2 resource values. Their
     * text is concatenated (presumably split across two resources for size — confirm),
     * JSON-parsed and validated. Throws if the JSON is malformed or the schema does not match.
     */
    export const list = fn(z.void(), () => {
      const serialized = `${Resource.ZEN_MODELS1.value}${Resource.ZEN_MODELS2.value}`
      const parsed: unknown = JSON.parse(serialized)
      return ModelsSchema.parse(parsed)
    })
  }
  /**
   * Per-workspace model opt-out list. A live row in `ModelTable` for
   * (current workspace, model) marks that model as disabled; no live row means enabled.
   *
   * "Live" means `timeDeleted IS NULL`: `disable` revives an existing row by resetting
   * `timeDeleted` to null, so the read paths filter on the same column to stay
   * consistent with that write.
   */
  export namespace Model {
    /**
     * Re-enables `model` for the current workspace by hard-deleting its row(s).
     * Calls `Actor.assertAdmin()` first (presumably throws for non-admins — confirm).
     */
    export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
      )
    })

    /**
     * Disables `model` for the current workspace by inserting a row. On a duplicate
     * key (presumably a unique (workspaceID, model) index — confirm in model.sql) the
     * existing row is revived by clearing `timeDeleted` instead of inserting.
     * Calls `Actor.assertAdmin()` first.
     */
    export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db
          .insert(ModelTable)
          .values({
            id: Identifier.create("model"),
            workspaceID: Actor.workspace(),
            model: model,
          })
          .onDuplicateKeyUpdate({
            set: {
              timeDeleted: null,
            },
          }),
      )
    })

    /** Resolves to the model ids that have a live disabled row in the current workspace. */
    export const listDisabled = fn(z.void(), () => {
      return Database.use((db) =>
        db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          // Ignore soft-deleted rows, matching the liveness rule used by `disable`.
          .where(and(eq(ModelTable.workspaceID, Actor.workspace()), isNull(ModelTable.timeDeleted)))
          .then((rows) => rows.map((row) => row.model)),
      )
    })

    /** Resolves to `true` when `model` has a live disabled row in the current workspace. */
    export const isDisabled = fn(
      z.object({
        model: z.string(),
      }),
      ({ model }) => {
        return Database.use(async (db) => {
          // Only existence matters, so fetch a single column and at most one row.
          const rows = await db
            .select({ id: ModelTable.id })
            .from(ModelTable)
            .where(
              and(
                eq(ModelTable.workspaceID, Actor.workspace()),
                eq(ModelTable.model, model),
                isNull(ModelTable.timeDeleted),
              ),
            )
            .limit(1)
          return rows.length > 0
        })
      },
    )
  }