model.ts 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import { z } from "zod"
  2. import { eq, and } from "drizzle-orm"
  3. import { Database } from "./drizzle"
  4. import { ModelTable } from "./schema/model.sql"
  5. import { Identifier } from "./identifier"
  6. import { fn } from "./util/fn"
  7. import { Actor } from "./actor"
  8. import { Resource } from "@opencode-ai/console-resource"
  /**
   * Zen model catalog.
   *
   * Holds the zod schemas describing the models/providers configuration, plus
   * helpers to validate a candidate catalog or load the deployed one from the
   * ZEN_MODELS1..ZEN_MODELS4 resources.
   */
  export namespace ZenData {
    /** Request/response format a provider endpoint is addressed with. */
    const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
    export type Format = z.infer<typeof FormatSchema>

    /** Pricing for a model: input/output are required, cache prices are optional. */
    const ModelCostSchema = z.object({
      input: z.number(),
      output: z.number(),
      cacheRead: z.number().optional(),
      cacheWrite5m: z.number().optional(),
      cacheWrite1h: z.number().optional(),
    })

    /** Optional trial settings attached to a model (a numeric limit and a provider name). */
    const TrialSchema = z.object({
      limit: z.number(),
      provider: z.string(),
    })

    /** One entry in a model's `providers` list. */
    const ModelProviderEntrySchema = z.object({
      id: z.string(),
      model: z.string(),
      weight: z.number().optional(),
      disabled: z.boolean().optional(),
      storeModel: z.string().optional(),
    })

    /** A single model definition in the catalog. */
    const ModelSchema = z.object({
      name: z.string(),
      cost: ModelCostSchema,
      cost200K: ModelCostSchema.optional(),
      allowAnonymous: z.boolean().optional(),
      trial: TrialSchema.optional(),
      rateLimit: z.number().optional(),
      fallbackProvider: z.string().optional(),
      providers: z.array(ModelProviderEntrySchema),
    })

    /** Connection details for a provider: endpoint, key, wire format, optional header mappings. */
    const ProviderSchema = z.object({
      api: z.string(),
      apiKey: z.string(),
      format: FormatSchema,
      headerMappings: z.record(z.string(), z.string()).optional(),
    })

    /** Whole catalog: models and providers, each keyed by a string id. */
    const ModelsSchema = z.object({
      models: z.record(z.string(), ModelSchema),
      providers: z.record(z.string(), ProviderSchema),
    })

    /**
     * Returns its input unchanged. `fn` is handed ModelsSchema as the input
     * schema — presumably it validates the argument before this identity
     * callback runs (confirm in ./util/fn).
     */
    export const validate = fn(ModelsSchema, (input) => input)

    /**
     * Loads the deployed catalog. The JSON text is stored across four resources,
     * which are concatenated in order (1, 2, 3, 4), parsed with JSON.parse, and
     * then checked with ModelsSchema.parse. Throws if the joined text is not
     * valid JSON or does not satisfy the schema.
     */
    export const list = fn(z.void(), () => {
      const raw = `${Resource.ZEN_MODELS1.value}${Resource.ZEN_MODELS2.value}${Resource.ZEN_MODELS3.value}${Resource.ZEN_MODELS4.value}`
      const parsed = JSON.parse(raw)
      return ModelsSchema.parse(parsed)
    })
  }
  /**
   * Per-workspace model toggles.
   *
   * A ModelTable row for (current workspace, model) marks that model as
   * disabled; removing the row re-enables it. The workspace always comes from
   * `Actor.workspace()`.
   *
   * NOTE(review): `disable` resets `timeDeleted` to null on a duplicate key, but
   * `listDisabled` / `isDisabled` do not filter on `timeDeleted`. Confirm whether
   * rows with `timeDeleted` set elsewhere should still count as disabled.
   */
  export namespace Model {
    /** Input shape shared by every operation that targets a single model. */
    const ModelInput = z.object({ model: z.string() })

    /** WHERE clause matching the current workspace's row for `model`. */
    const forWorkspaceModel = (model: string) =>
      and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))

    /** Re-enables `model` for the current workspace by deleting its row. Admin only. */
    export const enable = fn(ModelInput, ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) => db.delete(ModelTable).where(forWorkspaceModel(model)))
    })

    /**
     * Disables `model` for the current workspace by inserting a row with a fresh
     * "model" identifier. If the insert hits a duplicate key, the existing row's
     * `timeDeleted` is set back to null instead. Admin only.
     */
    export const disable = fn(ModelInput, ({ model }) => {
      Actor.assertAdmin()
      return Database.use((db) => {
        const insert = db.insert(ModelTable).values({
          id: Identifier.create("model"),
          workspaceID: Actor.workspace(),
          model,
        })
        return insert.onDuplicateKeyUpdate({ set: { timeDeleted: null } })
      })
    })

    /** Lists the `model` values of every ModelTable row in the current workspace. */
    export const listDisabled = fn(z.void(), () =>
      Database.use((db) =>
        db
          .select({ model: ModelTable.model })
          .from(ModelTable)
          .where(eq(ModelTable.workspaceID, Actor.workspace()))
          .then((rows) => rows.map(({ model }) => model)),
      ),
    )

    /** Resolves to true when a row exists for (current workspace, `model`), else false. */
    export const isDisabled = fn(ModelInput, ({ model }) =>
      Database.use(async (db) => {
        const [firstMatch] = await db.select().from(ModelTable).where(forWorkspaceModel(model)).limit(1)
        return firstMatch !== undefined
      }),
    )
  }