Frank 5 months ago
parent
commit
a55943e469
1 changed files with 36 additions and 46 deletions
  1. 36 46
      packages/console/app/src/routes/zen/handler.ts

+ 36 - 46
packages/console/app/src/routes/zen/handler.ts

@@ -45,6 +45,7 @@ export async function handler(
   const ModelSchema = z.object({
     cost: ModelCostSchema,
     cost200K: ModelCostSchema.optional(),
+    allowAnonymous: z.boolean().optional(),
     providers: z.array(
       z.object({
         id: z.string(),
@@ -52,7 +53,6 @@ export async function handler(
         apiKey: z.string(),
         model: z.string(),
         weight: z.number().optional(),
-        allowAnonymous: z.boolean().optional(),
         headerMappings: z.record(z.string(), z.string()).optional(),
         disabled: z.boolean().optional(),
       }),
@@ -85,10 +85,10 @@ export async function handler(
       session: input.request.headers.get("x-opencode-session"),
       request: input.request.headers.get("x-opencode-request"),
     })
-    const authInfo = await authenticate()
-    const modelInfo = validateModel(body.model, authInfo)
-    const providerInfo = selectProvider(modelInfo, authInfo)
-    if (authInfo && !providerInfo.allowAnonymous) validateBilling(authInfo)
+    const modelInfo = validateModel(body.model)
+    const providerInfo = selectProvider(modelInfo)
+    const authInfo = await authenticate(modelInfo)
+    validateBilling(modelInfo, authInfo)
     logger.metric({ provider: providerInfo.id })
 
     // Request to model provider
@@ -221,16 +221,41 @@ export async function handler(
     )
   }
 
-  async function authenticate() {
+  function validateModel(reqModel: string) {
+    const json = JSON.parse(Resource.ZEN_MODELS.value)
+
+    const allModels = z.record(z.string(), ModelSchema).parse(json)
+
+    if (!(reqModel in allModels)) {
+      throw new ModelError(`Model ${reqModel} not supported`)
+    }
+    const modelId = reqModel as keyof typeof allModels
+    const modelData = allModels[modelId]
+
+    logger.metric({ model: modelId })
+
+    return { id: modelId, ...modelData }
+  }
+
+  function selectProvider(model: Model) {
+    const providers = model.providers
+      .filter((provider) => !provider.disabled)
+      .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
+    return providers[Math.floor(Math.random() * providers.length)]
+  }
+
+  async function authenticate(model: Model) {
     const apiKey = opts.parseApiKey(input.request.headers)
-    if (!apiKey) return
+    if (!apiKey) {
+      if (model.allowAnonymous) return
+      throw new AuthError("Missing API key.")
+    }
 
     const data = await Database.use((tx) =>
       tx
         .select({
           apiKey: KeyTable.id,
           workspaceID: KeyTable.workspaceID,
-          dataShare: WorkspaceTable.dataShare,
           balance: BillingTable.balance,
           paymentMethodID: BillingTable.paymentMethodID,
           monthlyLimit: BillingTable.monthlyLimit,
@@ -255,7 +280,6 @@ export async function handler(
     return {
       apiKeyId: data.apiKey,
       workspaceID: data.workspaceID,
-      dataShare: data.dataShare,
       billing: {
         paymentMethodID: data.paymentMethodID,
         balance: data.balance,
@@ -267,8 +291,10 @@ export async function handler(
     }
   }
 
-  function validateBilling(authInfo: Awaited<ReturnType<typeof authenticate>>) {
+  function validateBilling(model: Model, authInfo: Awaited<ReturnType<typeof authenticate>>) {
     if (!authInfo || authInfo.isFree) return
+    if (model.allowAnonymous) return
+
     const billing = authInfo.billing
     if (!billing.paymentMethodID) throw new CreditsError("No payment method")
     if (billing.balance <= 0) throw new CreditsError("Insufficient balance")
@@ -288,42 +314,6 @@ export async function handler(
     }
   }
 
-  function validateModel(reqModel: string, authInfo: Awaited<ReturnType<typeof authenticate>>) {
-    const json = JSON.parse(Resource.ZEN_MODELS.value)
-
-    const allModels = z
-      .record(
-        z.string(),
-        z.object({
-          standard: ModelSchema,
-          dataShare: ModelSchema.optional(),
-        }),
-      )
-      .parse(json)
-
-    if (!(reqModel in allModels)) {
-      throw new ModelError(`Model ${reqModel} not supported`)
-    }
-    const modelId = reqModel as keyof typeof allModels
-    const modelData = authInfo?.dataShare
-      ? (allModels[modelId].dataShare ?? allModels[modelId].standard)
-      : allModels[modelId].standard
-    logger.metric({ model: modelId })
-    return { id: modelId, ...modelData }
-  }
-
-  function selectProvider(model: Model, authInfo: Awaited<ReturnType<typeof authenticate>>) {
-    let providers = model.providers.filter((provider) => !provider.disabled)
-
-    if (!authInfo) {
-      providers = providers.filter((provider) => provider.allowAnonymous)
-      if (providers.length === 0) throw new AuthError("Missing API key.")
-    }
-
-    const picks = providers.flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
-    return picks[Math.floor(Math.random() * picks.length)]
-  }
-
   async function trackUsage(
     authInfo: Awaited<ReturnType<typeof authenticate>>,
     modelInfo: ReturnType<typeof validateModel>,