model.ts 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
import { z } from "zod"
import { eq, and, isNull } from "drizzle-orm"
import { Database } from "./drizzle"
import { ModelTable } from "./schema/model.sql"
import { Identifier } from "./identifier"
import { fn } from "./util/fn"
import { Actor } from "./actor"
import { Resource } from "@opencode-ai/console-resource"
  9. export namespace ZenData {
  10. const FormatSchema = z.enum(["anthropic", "google", "openai", "oa-compat"])
  11. export type Format = z.infer<typeof FormatSchema>
  12. const ModelCostSchema = z.object({
  13. input: z.number(),
  14. output: z.number(),
  15. cacheRead: z.number().optional(),
  16. cacheWrite5m: z.number().optional(),
  17. cacheWrite1h: z.number().optional(),
  18. })
  19. const ModelSchema = z.object({
  20. name: z.string(),
  21. cost: ModelCostSchema,
  22. cost200K: ModelCostSchema.optional(),
  23. allowAnonymous: z.boolean().optional(),
  24. byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
  25. stickyProvider: z.boolean().optional(),
  26. trial: z
  27. .object({
  28. limit: z.number(),
  29. provider: z.string(),
  30. })
  31. .optional(),
  32. rateLimit: z.number().optional(),
  33. fallbackProvider: z.string().optional(),
  34. providers: z.array(
  35. z.object({
  36. id: z.string(),
  37. model: z.string(),
  38. weight: z.number().optional(),
  39. disabled: z.boolean().optional(),
  40. storeModel: z.string().optional(),
  41. }),
  42. ),
  43. })
  44. const ProviderSchema = z.object({
  45. api: z.string(),
  46. apiKey: z.string(),
  47. format: FormatSchema,
  48. headerMappings: z.record(z.string(), z.string()).optional(),
  49. })
  50. const ModelsSchema = z.object({
  51. models: z.record(z.string(), ModelSchema),
  52. providers: z.record(z.string(), ProviderSchema),
  53. })
  54. export const validate = fn(ModelsSchema, (input) => {
  55. return input
  56. })
  57. export const list = fn(z.void(), () => {
  58. const json = JSON.parse(
  59. Resource.ZEN_MODELS1.value + Resource.ZEN_MODELS2.value + Resource.ZEN_MODELS3.value + Resource.ZEN_MODELS4.value,
  60. )
  61. return ModelsSchema.parse(json)
  62. })
  63. }
  64. export namespace Model {
  65. export const enable = fn(z.object({ model: z.string() }), ({ model }) => {
  66. Actor.assertAdmin()
  67. return Database.use((db) =>
  68. db.delete(ModelTable).where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model))),
  69. )
  70. })
  71. export const disable = fn(z.object({ model: z.string() }), ({ model }) => {
  72. Actor.assertAdmin()
  73. return Database.use((db) =>
  74. db
  75. .insert(ModelTable)
  76. .values({
  77. id: Identifier.create("model"),
  78. workspaceID: Actor.workspace(),
  79. model: model,
  80. })
  81. .onDuplicateKeyUpdate({
  82. set: {
  83. timeDeleted: null,
  84. },
  85. }),
  86. )
  87. })
  88. export const listDisabled = fn(z.void(), () => {
  89. return Database.use((db) =>
  90. db
  91. .select({ model: ModelTable.model })
  92. .from(ModelTable)
  93. .where(eq(ModelTable.workspaceID, Actor.workspace()))
  94. .then((rows) => rows.map((row) => row.model)),
  95. )
  96. })
  97. export const isDisabled = fn(
  98. z.object({
  99. model: z.string(),
  100. }),
  101. ({ model }) => {
  102. return Database.use(async (db) => {
  103. const result = await db
  104. .select()
  105. .from(ModelTable)
  106. .where(and(eq(ModelTable.workspaceID, Actor.workspace()), eq(ModelTable.model, model)))
  107. .limit(1)
  108. return result.length > 0
  109. })
  110. },
  111. )
  112. }