Frank 4 месяцев назад
Родитель
Сommit
c93c0d402d
1 измененных файлов с 27 добавлено и 3 удалено
  1. 27 3
      packages/console/app/src/routes/zen/handler.ts

+ 27 - 3
packages/console/app/src/routes/zen/handler.ts

@@ -13,6 +13,7 @@ import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.j
 import { ZenModel } from "@opencode-ai/console-core/model.js"
 import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
 import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
+import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
 
 export async function handler(
   input: APIEvent,
@@ -67,9 +68,10 @@ export async function handler(
     })
     const modelInfo = validateModel(body.model)
     const providerInfo = selectProvider(modelInfo)
-    const authInfo = await authenticate(modelInfo)
+    const authInfo = await authenticate(modelInfo, providerInfo)
     validateBilling(modelInfo, authInfo)
     validateModelSettings(authInfo)
+    updateProviderKey(authInfo, providerInfo)
     logger.metric({ provider: providerInfo.id })
 
     // Request to model provider
@@ -232,7 +234,10 @@ export async function handler(
     return providers[Math.floor(Math.random() * providers.length)]
   }
 
-  async function authenticate(model: Awaited<ReturnType<typeof validateModel>>) {
+  async function authenticate(
+    model: Awaited<ReturnType<typeof validateModel>>,
+    providerInfo: Awaited<ReturnType<typeof selectProvider>>,
+  ) {
     const apiKey = opts.parseApiKey(input.request.headers)
     if (!apiKey) {
       if (model.allowAnonymous) return
@@ -257,6 +262,9 @@ export async function handler(
             monthlyUsage: UserTable.monthlyUsage,
             timeMonthlyUsageUpdated: UserTable.timeMonthlyUsageUpdated,
           },
+          provider: {
+            credentials: ProviderTable.credentials,
+          },
           timeDisabled: ModelTable.timeCreated,
         })
         .from(KeyTable)
@@ -264,6 +272,10 @@ export async function handler(
         .innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
         .innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
         .leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, model.id)))
+        .leftJoin(
+          ProviderTable,
+          and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, providerInfo.id)),
+        )
         .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
         .then((rows) => rows[0]),
     )
@@ -279,6 +291,7 @@ export async function handler(
       workspaceID: data.workspaceID,
       billing: data.billing,
       user: data.user,
+      provider: data.provider,
       isFree: FREE_WORKSPACES.includes(data.workspaceID),
       isDisabled: !!data.timeDisabled,
     }
@@ -327,6 +340,15 @@ export async function handler(
     if (authInfo.isDisabled) throw new ModelError("Model is disabled")
   }
 
+  function updateProviderKey(
+    authInfo: Awaited<ReturnType<typeof authenticate>>,
+    providerInfo: Awaited<ReturnType<typeof selectProvider>>,
+  ) {
+    if (!authInfo) return
+    if (!authInfo.provider?.credentials) return
+    providerInfo.apiKey = authInfo.provider.credentials
+  }
+
   async function trackUsage(
     authInfo: Awaited<ReturnType<typeof authenticate>>,
     modelInfo: ReturnType<typeof validateModel>,
@@ -389,7 +411,7 @@ export async function handler(
 
     if (!authInfo) return
 
-    const cost = authInfo.isFree ? 0 : centsToMicroCents(totalCostInCent)
+    const cost = authInfo.isFree || authInfo.provider?.credentials ? 0 : centsToMicroCents(totalCostInCent)
     await Database.transaction(async (tx) => {
       await tx.insert(UsageTable).values({
         workspaceID: authInfo.workspaceID,
@@ -441,6 +463,8 @@ export async function handler(
 
   async function reload(authInfo: Awaited<ReturnType<typeof authenticate>>) {
     if (!authInfo) return
+    if (authInfo.isFree) return
+    if (authInfo.provider?.credentials) return
 
     const lock = await Database.use((tx) =>
       tx