Frank 4 месяцев назад
Родитель
Сommit
88fef05923

+ 2 - 2
packages/console/app/src/routes/workspace/[id]/model-section.tsx

@@ -2,7 +2,7 @@ import { Model } from "@opencode-ai/console-core/model.js"
 import { query, action, useParams, createAsync, json } from "@solidjs/router"
 import { query, action, useParams, createAsync, json } from "@solidjs/router"
 import { createMemo, For, Show } from "solid-js"
 import { createMemo, For, Show } from "solid-js"
 import { withActor } from "~/context/auth.withActor"
 import { withActor } from "~/context/auth.withActor"
-import { ZenModel } from "@opencode-ai/console-core/model.js"
+import { ZenData } from "@opencode-ai/console-core/model.js"
 import styles from "./model-section.module.css"
 import styles from "./model-section.module.css"
 import { querySessionInfo } from "../common"
 import { querySessionInfo } from "../common"
 import { IconAlibaba, IconAnthropic, IconMoonshotAI, IconOpenAI, IconStealth, IconXai, IconZai } from "~/component/icon"
 import { IconAlibaba, IconAnthropic, IconMoonshotAI, IconOpenAI, IconStealth, IconXai, IconZai } from "~/component/icon"
@@ -21,7 +21,7 @@ const getModelsInfo = query(async (workspaceID: string) => {
   "use server"
   "use server"
   return withActor(async () => {
   return withActor(async () => {
     return {
     return {
-      all: Object.entries(ZenModel.list())
+      all: Object.entries(ZenData.list().models)
         .filter(([id, _model]) => !["claude-3-5-haiku"].includes(id))
         .filter(([id, _model]) => !["claude-3-5-haiku"].includes(id))
         .filter(([id, _model]) => !id.startsWith("an-"))
         .filter(([id, _model]) => !id.startsWith("an-"))
         .sort(([_idA, modelA], [_idB, modelB]) => modelA.name.localeCompare(modelB.name))
         .sort(([_idA, modelA], [_idB, modelB]) => modelA.name.localeCompare(modelB.name))

+ 18 - 14
packages/console/app/src/routes/zen/handler.ts

@@ -10,7 +10,7 @@ import { Resource } from "@opencode-ai/console-resource"
 import { Billing } from "../../../../core/src/billing"
 import { Billing } from "../../../../core/src/billing"
 import { Actor } from "@opencode-ai/console-core/actor.js"
 import { Actor } from "@opencode-ai/console-core/actor.js"
 import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js"
 import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js"
-import { ZenModel } from "@opencode-ai/console-core/model.js"
+import { ZenData } from "@opencode-ai/console-core/model.js"
 import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
 import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
 import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
 import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
 import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
 import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
@@ -39,7 +39,8 @@ export async function handler(
   class UserLimitError extends Error {}
   class UserLimitError extends Error {}
   class ModelError extends Error {}
   class ModelError extends Error {}
 
 
-  type Model = z.infer<typeof ZenModel.ModelSchema>
+  type ZenData = Awaited<ReturnType<typeof ZenData.list>>
+  type Model = ZenData["models"][string]
 
 
   const FREE_WORKSPACES = [
   const FREE_WORKSPACES = [
     "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
     "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
@@ -66,8 +67,9 @@ export async function handler(
       session: input.request.headers.get("x-opencode-session"),
       session: input.request.headers.get("x-opencode-session"),
       request: input.request.headers.get("x-opencode-request"),
       request: input.request.headers.get("x-opencode-request"),
     })
     })
-    const modelInfo = validateModel(body.model)
-    const providerInfo = selectProvider(modelInfo)
+    const zenData = ZenData.list()
+    const modelInfo = validateModel(zenData, body.model)
+    const providerInfo = selectProvider(zenData, modelInfo)
     const authInfo = await authenticate(modelInfo, providerInfo)
     const authInfo = await authenticate(modelInfo, providerInfo)
     validateBilling(modelInfo, authInfo)
     validateBilling(modelInfo, authInfo)
     validateModelSettings(authInfo)
     validateModelSettings(authInfo)
@@ -211,27 +213,29 @@ export async function handler(
     )
     )
   }
   }
 
 
-  function validateModel(reqModel: string) {
-    const json = JSON.parse(Resource.ZEN_MODELS.value)
-
-    const allModels = ZenModel.ModelsSchema.parse(json)
-
-    if (!(reqModel in allModels)) {
+  function validateModel(zenData: ZenData, reqModel: string) {
+    if (!(reqModel in zenData.models)) {
       throw new ModelError(`Model ${reqModel} not supported`)
       throw new ModelError(`Model ${reqModel} not supported`)
     }
     }
-    const modelId = reqModel as keyof typeof allModels
-    const modelData = allModels[modelId]
+    const modelId = reqModel as keyof typeof zenData.models
+    const modelData = zenData.models[modelId]
 
 
     logger.metric({ model: modelId })
     logger.metric({ model: modelId })
 
 
     return { id: modelId, ...modelData }
     return { id: modelId, ...modelData }
   }
   }
 
 
-  function selectProvider(model: Awaited<ReturnType<typeof validateModel>>) {
+  function selectProvider(zenData: ZenData, model: Awaited<ReturnType<typeof validateModel>>) {
     const providers = model.providers
     const providers = model.providers
       .filter((provider) => !provider.disabled)
       .filter((provider) => !provider.disabled)
       .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
       .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
-    return providers[Math.floor(Math.random() * providers.length)]
+    const provider = providers[Math.floor(Math.random() * providers.length)]
+
+    if (!(provider.id in zenData.providers)) {
+      throw new ModelError(`Provider ${provider.id} not supported`)
+    }
+
+    return { ...provider, ...zenData.providers[provider.id] }
   }
   }
 
 
   async function authenticate(
   async function authenticate(

+ 2 - 2
packages/console/core/script/promote-models.ts

@@ -2,7 +2,7 @@
 
 
 import { $ } from "bun"
 import { $ } from "bun"
 import path from "path"
 import path from "path"
-import { ZenModel } from "../src/model"
+import { ZenData } from "../src/model"
 
 
 const stage = process.argv[2]
 const stage = process.argv[2]
 if (!stage) throw new Error("Stage is required")
 if (!stage) throw new Error("Stage is required")
@@ -18,7 +18,7 @@ const value = ret
 if (!value) throw new Error("ZEN_MODELS not found")
 if (!value) throw new Error("ZEN_MODELS not found")
 
 
 // validate value
 // validate value
-ZenModel.ModelsSchema.parse(JSON.parse(value))
+ZenData.validate(JSON.parse(value))
 
 
 // update the secret
 // update the secret
 await $`bun sst secret set ZEN_MODELS ${value} --stage ${stage}`
 await $`bun sst secret set ZEN_MODELS ${value} --stage ${stage}`

+ 2 - 2
packages/console/core/script/update-models.ts

@@ -3,7 +3,7 @@
 import { $ } from "bun"
 import { $ } from "bun"
 import path from "path"
 import path from "path"
 import os from "os"
 import os from "os"
-import { ZenModel } from "../src/model"
+import { ZenData } from "../src/model"
 
 
 const root = path.resolve(process.cwd(), "..", "..", "..")
 const root = path.resolve(process.cwd(), "..", "..", "..")
 const models = await $`bun sst secret list`.cwd(root).text()
 const models = await $`bun sst secret list`.cwd(root).text()
@@ -26,7 +26,7 @@ console.log("tempFile", tempFile.name)
 // open temp file in vim and read the file on close
 // open temp file in vim and read the file on close
 await $`vim ${tempFile.name}`
 await $`vim ${tempFile.name}`
 const newValue = JSON.parse(await tempFile.text())
 const newValue = JSON.parse(await tempFile.text())
-ZenModel.ModelsSchema.parse(newValue)
+ZenData.validate(newValue)
 
 
 // update the secret
 // update the secret
 await $`bun sst secret set ZEN_MODELS ${JSON.stringify(newValue)}`
 await $`bun sst secret set ZEN_MODELS ${JSON.stringify(newValue)}`

+ 20 - 7
packages/console/core/src/model.ts

@@ -7,7 +7,7 @@ import { fn } from "./util/fn"
 import { Actor } from "./actor"
 import { Actor } from "./actor"
 import { Resource } from "@opencode-ai/console-resource"
 import { Resource } from "@opencode-ai/console-resource"
 
 
-export namespace ZenModel {
+export namespace ZenData {
   const ModelCostSchema = z.object({
   const ModelCostSchema = z.object({
     input: z.number(),
     input: z.number(),
     output: z.number(),
     output: z.number(),
@@ -16,7 +16,7 @@ export namespace ZenModel {
     cacheWrite1h: z.number().optional(),
     cacheWrite1h: z.number().optional(),
   })
   })
 
 
-  export const ModelSchema = z.object({
+  const ModelSchema = z.object({
     name: z.string(),
     name: z.string(),
     cost: ModelCostSchema,
     cost: ModelCostSchema,
     cost200K: ModelCostSchema.optional(),
     cost200K: ModelCostSchema.optional(),
@@ -24,19 +24,32 @@ export namespace ZenModel {
     providers: z.array(
     providers: z.array(
       z.object({
       z.object({
         id: z.string(),
         id: z.string(),
-        api: z.string(),
-        apiKey: z.string(),
         model: z.string(),
         model: z.string(),
         weight: z.number().optional(),
         weight: z.number().optional(),
-        headerMappings: z.record(z.string(), z.string()).optional(),
         disabled: z.boolean().optional(),
         disabled: z.boolean().optional(),
       }),
       }),
     ),
     ),
   })
   })
 
 
-  export const ModelsSchema = z.record(z.string(), ModelSchema)
+  const ProviderSchema = z.object({
+    api: z.string(),
+    apiKey: z.string(),
+    headerMappings: z.record(z.string(), z.string()).optional(),
+  })
+
+  const ModelsSchema = z.object({
+    models: z.record(z.string(), ModelSchema),
+    providers: z.record(z.string(), ProviderSchema),
+  })
 
 
-  export const list = fn(z.void(), () => ModelsSchema.parse(JSON.parse(Resource.ZEN_MODELS.value)))
+  export const validate = fn(ModelsSchema, (input) => {
+    return input
+  })
+
+  export const list = fn(z.void(), () => {
+    const json = JSON.parse(Resource.ZEN_MODELS.value)
+    return ModelsSchema.parse(json)
+  })
 }
 }
 
 
 export namespace Model {
 export namespace Model {