model.ts 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  import { z } from "zod"
  import { eq, and, isNull } from "drizzle-orm"
  import { Database } from "./drizzle"
  import { ModelTable } from "./schema/model.sql"
  import { Identifier } from "./identifier"
  import { fn } from "./util/fn"
  import { Actor } from "./actor"
  import { Resource } from "@opencode-ai/console-resource"
  /**
   * Zen model catalog: zod schemas for model and provider entries, plus helpers
   * that validate a catalog object and load the deployed catalog.
   */
  export namespace ZenData {
    // Format tag attached to each entry of the catalog's `providers` record.
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
    export type Format = z.infer<typeof FormatSchema>

    // Pricing numbers for a model. Units are not defined in this file.
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    // Shape of a model's optional `trial` settings.
    const ModelTrialSchema = z.object({
      limit: z.number(),
      provider: z.string(),
    })

    // One entry of a model's `providers` list. `id` presumably refers to a key of
    // the top-level `providers` record — TODO confirm against the consumer.
    const ModelProviderSchema = z.object({
      id: z.string(),
      model: z.string(),
      weight: z.number().optional(),
      disabled: z.boolean().optional(),
      storeModel: z.string().optional(),
    })

    // A single model entry of the catalog.
    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      stickyProvider: z.boolean().optional(),
      trial: ModelTrialSchema.optional(),
      rateLimit: z.number().optional(),
      fallbackProvider: z.string().optional(),
      providers: z.array(ModelProviderSchema),
    })

    // Connection settings for an upstream provider.
    const ProviderSchema = z.object({
      api: z.string(),
      apiKey: z.string(),
      format: FormatSchema,
      headerMappings: z.record(z.string(), z.string()).optional(),
    })

    // The whole catalog: models and providers, each keyed by a string.
    const ModelsSchema = z.object({
      models: z.record(z.string(), ModelSchema),
      providers: z.record(z.string(), ProviderSchema),
    })

    /**
     * Hands back a catalog object unchanged. The callback itself is an identity;
     * the checking is assumed to happen inside `fn` using ModelsSchema — confirm
     * in ./util/fn.
     */
    export const validate = fn(ModelsSchema, (catalog) => catalog)

    /**
     * Loads the deployed catalog. Its JSON text is split across the resources
     * ZEN_MODELS1..ZEN_MODELS4; the four values are joined in that order, parsed
     * as one JSON document and validated with ModelsSchema.
     * Throws if the text is not valid JSON or does not match the schema.
     */
    export const list = fn(z.void(), () => {
      const text =
        `${Resource.ZEN_MODELS1.value}${Resource.ZEN_MODELS2.value}` +
        `${Resource.ZEN_MODELS3.value}${Resource.ZEN_MODELS4.value}`
      return ModelsSchema.parse(JSON.parse(text))
    })
  }
  /**
   * Per-workspace model availability. A ModelTable row for (workspace, model)
   * marks that model as disabled in the workspace; with no such row the model
   * is enabled.
   */
  export namespace Model {
    /**
     * Condition matching the current workspace's rows that are not soft-deleted.
     * `disable` revives an existing row by clearing `timeDeleted`, so a row whose
     * `timeDeleted` is set must not count as "disabled" when reading.
     */
    const activeInWorkspace = () =>
      and(eq(ModelTable.workspaceID, Actor.workspace()), isNull(ModelTable.timeDeleted))

    /**
     * Re-enables `model` for the current workspace by hard-deleting its ModelTable row.
     * Calls Actor.assertAdmin() first (presumably throws for non-admin actors).
     */
    export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
      )
    })

    /**
     * Disables `model` for the current workspace by inserting a ModelTable row.
     * On a duplicate key (presumably a unique index on workspaceID + model — TODO
     * confirm in schema/model.sql) the existing row's `timeDeleted` is cleared instead.
     * Calls Actor.assertAdmin() first.
     */
    export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) =>
        db
          .insert(ModelTable)
          .values({
            id: Identifier.create("model"),
            workspaceID: Actor.workspace(),
            model,
          })
          .onDuplicateKeyUpdate({
            set: {
              timeDeleted: null,
            },
          }),
      )
    })

    /** Returns the `model` values of all non-soft-deleted rows in the current workspace. */
    export const listDisabled = fn(z.void(), () =>
      Database.use((db) =>
        db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          .where(activeInWorkspace())
          .then((rows) => rows.map((row) => row.model)),
      ),
    )

    /** Resolves to true when a non-soft-deleted row exists for `model` in the current workspace. */
    export const isDisabled = fn(
      z.object({
        model: z.string(),
      }),
      ({ model }) =>
        Database.use(async (db) => {
          // Only existence matters, so fetch a single column of at most one row.
          const rows = await db
            .select({ model: ModelTable.model })
            .from(ModelTable)
            .where(and(activeInWorkspace(), eq(ModelTable.model, model)))
            .limit(1)
          return rows.length > 0
        }),
    )
  }