model.ts 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
import { z } from "zod"
import { eq, and, isNull } from "drizzle-orm"
import { Resource } from "@opencode-ai/console-resource"
import { Database } from "./drizzle"
import { ModelTable } from "./schema/model.sql"
import { Identifier } from "./identifier"
import { fn } from "./util/fn"
import { Actor } from "./actor"
  9. export namespace ZenData {
  10. const ModelCostSchema = z.object({
  11. input: z.number(),
  12. output: z.number(),
  13. cacheRead: z.number().optional(),
  14. cacheWrite5m: z.number().optional(),
  15. cacheWrite1h: z.number().optional(),
  16. })
  17. const ModelSchema = z.object({
  18. name: z.string(),
  19. cost: ModelCostSchema,
  20. cost200K: ModelCostSchema.optional(),
  21. allowAnonymous: z.boolean().optional(),
  22. providers: z.array(
  23. z.object({
  24. id: z.string(),
  25. model: z.string(),
  26. weight: z.number().optional(),
  27. disabled: z.boolean().optional(),
  28. }),
  29. ),
  30. })
  31. const ProviderSchema = z.object({
  32. api: z.string(),
  33. apiKey: z.string(),
  34. headerMappings: z.record(z.string(), z.string()).optional(),
  35. })
  36. const ModelsSchema = z.object({
  37. models: z.record(z.string(), ModelSchema),
  38. providers: z.record(z.string(), ProviderSchema),
  39. })
  40. export const validate = fn(ModelsSchema, (input) => {
  41. return input
  42. })
  43. export const list = fn(z.void(), () => {
  44. const json = JSON.parse(Resource.ZEN_MODELS.value)
  45. return ModelsSchema.parse(json)
  46. })
  47. }
  48. export namespace Model {
  49. export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
  50. Actor.assertAdmin()
  51. return Database.use((db) =>
  52. db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
  53. )
  54. })
  55. export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
  56. Actor.assertAdmin()
  57. return Database.use((db) =>
  58. db
  59. .insert(ModelTable)
  60. .values({
  61. id: Identifier.create("model"),
  62. workspaceID: Actor.workspace(),
  63. model: model,
  64. })
  65. .onDuplicateKeyUpdate({
  66. set: {
  67. timeDeleted: null,
  68. },
  69. }),
  70. )
  71. })
  72. export const listDisabled = fn(z.void(), () => {
  73. return Database.use((db) =>
  74. db
  75. .select({ model: ModelTable.model })
  76. .from(ModelTable)
  77. .where(eq(ModelTable.workspaceID, Actor.workspace()))
  78. .then((rows) => rows.map((row) => row.model)),
  79. )
  80. })
  81. export const isDisabled = fn(
  82. z.object({
  83. model: z.string(),
  84. }),
  85. ({ model }) => {
  86. return Database.use(async (db) => {
  87. const result = await db
  88. .select()
  89. .from(ModelTable)
  90. .where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model)))
  91. .limit(1)
  92. return result.length > 0
  93. })
  94. },
  95. )
  96. }